From d591b140c8039ceecb589e4d3e9cf67881d16bc1 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Mon, 16 Mar 2026 05:19:31 -0600 Subject: [PATCH 1/5] feat: Add configurable sandbox reuse with grouping strategies (#11922) Co-authored-by: openhands --- .../100_add_sandbox_grouping_strategy.py | 33 ++++ enterprise/storage/org.py | 1 + enterprise/storage/saas_settings_store.py | 3 + enterprise/storage/user.py | 1 + enterprise/storage/user_settings.py | 1 + frontend/__tests__/utils/get-git-path.test.ts | 22 ++- frontend/src/hooks/query/use-settings.ts | 3 + .../query/use-unified-get-git-changes.ts | 2 +- .../src/hooks/query/use-unified-git-diff.ts | 2 +- frontend/src/i18n/declaration.ts | 6 + frontend/src/i18n/translation.json | 96 +++++++++ frontend/src/routes/app-settings.tsx | 50 +++++ frontend/src/services/settings.ts | 1 + frontend/src/types/settings.ts | 12 ++ frontend/src/utils/feature-flags.ts | 2 + frontend/src/utils/get-git-path.ts | 5 +- .../app_conversation_models.py | 4 + .../live_status_app_conversation_service.py | 185 ++++++++++++++++-- .../server/routes/manage_conversations.py | 34 +++- openhands/storage/data_models/settings.py | 18 ++ ...st_live_status_app_conversation_service.py | 131 +++++++++++-- .../server/data_models/test_conversation.py | 3 + 22 files changed, 569 insertions(+), 46 deletions(-) create mode 100644 enterprise/migrations/versions/100_add_sandbox_grouping_strategy.py diff --git a/enterprise/migrations/versions/100_add_sandbox_grouping_strategy.py b/enterprise/migrations/versions/100_add_sandbox_grouping_strategy.py new file mode 100644 index 0000000000..e58f6fbcf6 --- /dev/null +++ b/enterprise/migrations/versions/100_add_sandbox_grouping_strategy.py @@ -0,0 +1,33 @@ +"""Add sandbox_grouping_strategy column to user, org, and user_settings tables. 
+ +Revision ID: 100 +Revises: 099 +Create Date: 2025-03-12 +""" + +import sqlalchemy as sa +from alembic import op + +revision = '100' +down_revision = '099' + + +def upgrade() -> None: + op.add_column( + 'user', + sa.Column('sandbox_grouping_strategy', sa.String, nullable=True), + ) + op.add_column( + 'org', + sa.Column('sandbox_grouping_strategy', sa.String, nullable=True), + ) + op.add_column( + 'user_settings', + sa.Column('sandbox_grouping_strategy', sa.String, nullable=True), + ) + + +def downgrade() -> None: + op.drop_column('user_settings', 'sandbox_grouping_strategy') + op.drop_column('org', 'sandbox_grouping_strategy') + op.drop_column('user', 'sandbox_grouping_strategy') diff --git a/enterprise/storage/org.py b/enterprise/storage/org.py index b0ec98b0a2..3b0b898fd1 100644 --- a/enterprise/storage/org.py +++ b/enterprise/storage/org.py @@ -47,6 +47,7 @@ class Org(Base): # type: ignore conversation_expiration = Column(Integer, nullable=True) condenser_max_size = Column(Integer, nullable=True) byor_export_enabled = Column(Boolean, nullable=False, default=False) + sandbox_grouping_strategy = Column(String, nullable=True) # Relationships org_members = relationship('OrgMember', back_populates='org') diff --git a/enterprise/storage/saas_settings_store.py b/enterprise/storage/saas_settings_store.py index d062ff6e7e..b2fbdac2bd 100644 --- a/enterprise/storage/saas_settings_store.py +++ b/enterprise/storage/saas_settings_store.py @@ -117,6 +117,9 @@ class SaasSettingsStore(SettingsStore): kwargs['llm_base_url'] = org_member.llm_base_url if org.v1_enabled is None: kwargs['v1_enabled'] = True + # Apply default if sandbox_grouping_strategy is None in the database + if kwargs.get('sandbox_grouping_strategy') is None: + kwargs.pop('sandbox_grouping_strategy', None) settings = Settings(**kwargs) return settings diff --git a/enterprise/storage/user.py b/enterprise/storage/user.py index adedf85366..2df86a7039 100644 --- a/enterprise/storage/user.py +++ 
b/enterprise/storage/user.py @@ -33,6 +33,7 @@ class User(Base): # type: ignore email_verified = Column(Boolean, nullable=True) git_user_name = Column(String, nullable=True) git_user_email = Column(String, nullable=True) + sandbox_grouping_strategy = Column(String, nullable=True) # Relationships role = relationship('Role', back_populates='users') diff --git a/enterprise/storage/user_settings.py b/enterprise/storage/user_settings.py index 96ccc9653e..3e62c3e930 100644 --- a/enterprise/storage/user_settings.py +++ b/enterprise/storage/user_settings.py @@ -27,6 +27,7 @@ class UserSettings(Base): # type: ignore ) sandbox_base_container_image = Column(String, nullable=True) sandbox_runtime_container_image = Column(String, nullable=True) + sandbox_grouping_strategy = Column(String, nullable=True) user_version = Column(Integer, nullable=False, default=0) accepted_tos = Column(DateTime, nullable=True) mcp_config = Column(JSON, nullable=True) diff --git a/frontend/__tests__/utils/get-git-path.test.ts b/frontend/__tests__/utils/get-git-path.test.ts index d1507228fc..2adfc232d4 100644 --- a/frontend/__tests__/utils/get-git-path.test.ts +++ b/frontend/__tests__/utils/get-git-path.test.ts @@ -2,27 +2,29 @@ import { describe, it, expect } from "vitest"; import { getGitPath } from "#/utils/get-git-path"; describe("getGitPath", () => { - it("should return /workspace/project when no repository is selected", () => { - expect(getGitPath(null)).toBe("/workspace/project"); - expect(getGitPath(undefined)).toBe("/workspace/project"); + const conversationId = "abc123"; + + it("should return /workspace/project/{conversationId} when no repository is selected", () => { + expect(getGitPath(conversationId, null)).toBe(`/workspace/project/${conversationId}`); + expect(getGitPath(conversationId, undefined)).toBe(`/workspace/project/${conversationId}`); }); it("should handle standard owner/repo format (GitHub)", () => { - 
expect(getGitPath("OpenHands/OpenHands")).toBe("/workspace/project/OpenHands"); - expect(getGitPath("facebook/react")).toBe("/workspace/project/react"); + expect(getGitPath(conversationId, "OpenHands/OpenHands")).toBe(`/workspace/project/${conversationId}/OpenHands`); + expect(getGitPath(conversationId, "facebook/react")).toBe(`/workspace/project/${conversationId}/react`); }); it("should handle nested group paths (GitLab)", () => { - expect(getGitPath("modernhealth/frontend-guild/pan")).toBe("/workspace/project/pan"); - expect(getGitPath("group/subgroup/repo")).toBe("/workspace/project/repo"); - expect(getGitPath("a/b/c/d/repo")).toBe("/workspace/project/repo"); + expect(getGitPath(conversationId, "modernhealth/frontend-guild/pan")).toBe(`/workspace/project/${conversationId}/pan`); + expect(getGitPath(conversationId, "group/subgroup/repo")).toBe(`/workspace/project/${conversationId}/repo`); + expect(getGitPath(conversationId, "a/b/c/d/repo")).toBe(`/workspace/project/${conversationId}/repo`); }); it("should handle single segment paths", () => { - expect(getGitPath("repo")).toBe("/workspace/project/repo"); + expect(getGitPath(conversationId, "repo")).toBe(`/workspace/project/${conversationId}/repo`); }); it("should handle empty string", () => { - expect(getGitPath("")).toBe("/workspace/project"); + expect(getGitPath(conversationId, "")).toBe(`/workspace/project/${conversationId}`); }); }); diff --git a/frontend/src/hooks/query/use-settings.ts b/frontend/src/hooks/query/use-settings.ts index ce01e4f69b..6c6d766b69 100644 --- a/frontend/src/hooks/query/use-settings.ts +++ b/frontend/src/hooks/query/use-settings.ts @@ -18,6 +18,9 @@ const getSettingsQueryFn = async (): Promise => { git_user_email: settings.git_user_email || DEFAULT_SETTINGS.git_user_email, is_new_user: false, v1_enabled: settings.v1_enabled ?? DEFAULT_SETTINGS.v1_enabled, + sandbox_grouping_strategy: + settings.sandbox_grouping_strategy ?? 
+ DEFAULT_SETTINGS.sandbox_grouping_strategy, }; }; diff --git a/frontend/src/hooks/query/use-unified-get-git-changes.ts b/frontend/src/hooks/query/use-unified-get-git-changes.ts index 6b0856031c..70bc5f451f 100644 --- a/frontend/src/hooks/query/use-unified-get-git-changes.ts +++ b/frontend/src/hooks/query/use-unified-get-git-changes.ts @@ -27,7 +27,7 @@ export const useUnifiedGetGitChanges = () => { // Calculate git path based on selected repository const gitPath = React.useMemo( - () => getGitPath(selectedRepository), + () => getGitPath(conversationId, selectedRepository), [selectedRepository], ); diff --git a/frontend/src/hooks/query/use-unified-git-diff.ts b/frontend/src/hooks/query/use-unified-git-diff.ts index 33fedb497b..26bca16fce 100644 --- a/frontend/src/hooks/query/use-unified-git-diff.ts +++ b/frontend/src/hooks/query/use-unified-git-diff.ts @@ -32,7 +32,7 @@ export const useUnifiedGitDiff = (config: UseUnifiedGitDiffConfig) => { const absoluteFilePath = React.useMemo(() => { if (!isV1Conversation) return config.filePath; - const gitPath = getGitPath(selectedRepository); + const gitPath = getGitPath(conversationId, selectedRepository); return `${gitPath}/${config.filePath}`; }, [isV1Conversation, selectedRepository, config.filePath]); diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index 10e9d885fd..648143fc2e 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -175,6 +175,12 @@ export enum I18nKey { SETTINGS$MAX_BUDGET_PER_CONVERSATION = "SETTINGS$MAX_BUDGET_PER_CONVERSATION", SETTINGS$PROACTIVE_CONVERSATION_STARTERS = "SETTINGS$PROACTIVE_CONVERSATION_STARTERS", SETTINGS$SOLVABILITY_ANALYSIS = "SETTINGS$SOLVABILITY_ANALYSIS", + SETTINGS$SANDBOX_GROUPING_STRATEGY = "SETTINGS$SANDBOX_GROUPING_STRATEGY", + SETTINGS$SANDBOX_GROUPING_NO_GROUPING = "SETTINGS$SANDBOX_GROUPING_NO_GROUPING", + SETTINGS$SANDBOX_GROUPING_GROUP_BY_NEWEST = "SETTINGS$SANDBOX_GROUPING_GROUP_BY_NEWEST", + 
SETTINGS$SANDBOX_GROUPING_LEAST_RECENTLY_USED = "SETTINGS$SANDBOX_GROUPING_LEAST_RECENTLY_USED", + SETTINGS$SANDBOX_GROUPING_FEWEST_CONVERSATIONS = "SETTINGS$SANDBOX_GROUPING_FEWEST_CONVERSATIONS", + SETTINGS$SANDBOX_GROUPING_ADD_TO_ANY = "SETTINGS$SANDBOX_GROUPING_ADD_TO_ANY", SETTINGS$SEARCH_API_KEY = "SETTINGS$SEARCH_API_KEY", SETTINGS$SEARCH_API_KEY_OPTIONAL = "SETTINGS$SEARCH_API_KEY_OPTIONAL", SETTINGS$SEARCH_API_KEY_INSTRUCTIONS = "SETTINGS$SEARCH_API_KEY_INSTRUCTIONS", diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index abeba30110..d3f91ceec7 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -2799,6 +2799,102 @@ "tr": "Çözünürlük Analizini Etkinleştir", "uk": "Увімкнути аналіз розв'язності" }, + "SETTINGS$SANDBOX_GROUPING_STRATEGY": { + "en": "Sandbox Grouping Strategy", + "ja": "サンドボックスグループ化戦略", + "zh-CN": "沙盒分组策略", + "zh-TW": "沙盒分組策略", + "ko-KR": "샌드박스 그룹화 전략", + "de": "Sandbox-Gruppierungsstrategie", + "no": "Sandkasse-grupperingsstrategi", + "it": "Strategia di raggruppamento sandbox", + "pt": "Estratégia de agrupamento de sandbox", + "es": "Estrategia de agrupación de sandbox", + "ar": "استراتيجية تجميع صندوق الرمل", + "fr": "Stratégie de regroupement sandbox", + "tr": "Sandbox Gruplama Stratejisi", + "uk": "Стратегія групування пісочниці" + }, + "SETTINGS$SANDBOX_GROUPING_NO_GROUPING": { + "en": "No Grouping (new sandbox per conversation)", + "ja": "グループ化なし (会話ごとに新しいサンドボックス)", + "zh-CN": "不分组 (每个对话使用新沙盒)", + "zh-TW": "不分組 (每個對話使用新沙盒)", + "ko-KR": "그룹화 없음 (대화마다 새 샌드박스)", + "de": "Keine Gruppierung (neue Sandbox pro Gespräch)", + "no": "Ingen gruppering (ny sandkasse per samtale)", + "it": "Nessun raggruppamento (nuova sandbox per conversazione)", + "pt": "Sem agrupamento (novo sandbox por conversa)", + "es": "Sin agrupación (nuevo sandbox por conversación)", + "ar": "بدون تجميع (صندوق رمل جديد لكل محادثة)", + "fr": "Pas de regroupement (nouveau sandbox par conversation)", 
+ "tr": "Gruplama Yok (konuşma başına yeni sandbox)", + "uk": "Без групування (нова пісочниця для кожної розмови)" + }, + "SETTINGS$SANDBOX_GROUPING_GROUP_BY_NEWEST": { + "en": "Group by Newest (add to most recent sandbox)", + "ja": "最新でグループ化 (最新のサンドボックスに追加)", + "zh-CN": "按最新分组 (添加到最近的沙盒)", + "zh-TW": "按最新分組 (添加到最近的沙盒)", + "ko-KR": "최신으로 그룹화 (가장 최근 샌드박스에 추가)", + "de": "Nach neuester gruppieren (zur neuesten Sandbox hinzufügen)", + "no": "Grupper etter nyeste (legg til i nyeste sandkasse)", + "it": "Raggruppa per più recente (aggiungi alla sandbox più recente)", + "pt": "Agrupar por mais recente (adicionar ao sandbox mais recente)", + "es": "Agrupar por más reciente (agregar al sandbox más reciente)", + "ar": "التجميع حسب الأحدث (إضافة إلى أحدث صندوق رمل)", + "fr": "Regrouper par le plus récent (ajouter au sandbox le plus récent)", + "tr": "En Yeniye Göre Grupla (en yeni sandbox'a ekle)", + "uk": "Групувати за найновішим (додати до найновішої пісочниці)" + }, + "SETTINGS$SANDBOX_GROUPING_LEAST_RECENTLY_USED": { + "en": "Least Recently Used (add to oldest sandbox)", + "ja": "最も古い (最も古いサンドボックスに追加)", + "zh-CN": "最近最少使用 (添加到最旧的沙盒)", + "zh-TW": "最近最少使用 (添加到最舊的沙盒)", + "ko-KR": "가장 오래된 것 (가장 오래된 샌드박스에 추가)", + "de": "Am ältesten (zur ältesten Sandbox hinzufügen)", + "no": "Eldst (legg til i eldste sandkasse)", + "it": "Meno usato di recente (aggiungi alla sandbox più vecchia)", + "pt": "Menos usado recentemente (adicionar ao sandbox mais antigo)", + "es": "Menos usado recientemente (agregar al sandbox más antiguo)", + "ar": "الأقل استخدامًا مؤخرًا (إضافة إلى أقدم صندوق رمل)", + "fr": "Le moins récemment utilisé (ajouter au sandbox le plus ancien)", + "tr": "En Az Kullanılan (en eski sandbox'a ekle)", + "uk": "Найменш нещодавно використана (додати до найстаршої пісочниці)" + }, + "SETTINGS$SANDBOX_GROUPING_FEWEST_CONVERSATIONS": { + "en": "Fewest Conversations (add to least busy sandbox)", + "ja": "会話数が最少 (最も空いているサンドボックスに追加)", + "zh-CN": "最少对话 (添加到最空闲的沙盒)", + "zh-TW": "最少對話 
(添加到最空閒的沙盒)", + "ko-KR": "대화 수가 가장 적은 (가장 한가한 샌드박스에 추가)", + "de": "Wenigste Gespräche (zur am wenigsten beschäftigten Sandbox hinzufügen)", + "no": "Færrest samtaler (legg til i minst opptatt sandkasse)", + "it": "Meno conversazioni (aggiungi alla sandbox meno occupata)", + "pt": "Menos conversas (adicionar ao sandbox menos ocupado)", + "es": "Menos conversaciones (agregar al sandbox menos ocupado)", + "ar": "أقل محادثات (إضافة إلى صندوق الرمل الأقل انشغالاً)", + "fr": "Moins de conversations (ajouter au sandbox le moins occupé)", + "tr": "En Az Konuşma (en az meşgul sandbox'a ekle)", + "uk": "Найменше розмов (додати до найменш зайнятої пісочниці)" + }, + "SETTINGS$SANDBOX_GROUPING_ADD_TO_ANY": { + "en": "Add to Any (use first available sandbox)", + "ja": "任意に追加 (最初に利用可能なサンドボックスを使用)", + "zh-CN": "添加到任意 (使用第一个可用的沙盒)", + "zh-TW": "添加到任意 (使用第一個可用的沙盒)", + "ko-KR": "아무 곳에나 추가 (첫 번째 사용 가능한 샌드박스 사용)", + "de": "Zu beliebig hinzufügen (erste verfügbare Sandbox verwenden)", + "no": "Legg til i hvilken som helst (bruk første tilgjengelige sandkasse)", + "it": "Aggiungi a qualsiasi (usa la prima sandbox disponibile)", + "pt": "Adicionar a qualquer (usar o primeiro sandbox disponível)", + "es": "Agregar a cualquiera (usar el primer sandbox disponible)", + "ar": "إضافة إلى أي (استخدام أول صندوق رمل متاح)", + "fr": "Ajouter à n'importe lequel (utiliser le premier sandbox disponible)", + "tr": "Herhangi Birine Ekle (ilk uygun sandbox'ı kullan)", + "uk": "Додати до будь-якої (використовувати першу доступну пісочницю)" + }, "SETTINGS$SEARCH_API_KEY": { "en": "Search API Key (Tavily)", "ja": "検索APIキー (Tavily)", diff --git a/frontend/src/routes/app-settings.tsx b/frontend/src/routes/app-settings.tsx index 8226488468..43753fbec7 100644 --- a/frontend/src/routes/app-settings.tsx +++ b/frontend/src/routes/app-settings.tsx @@ -8,6 +8,7 @@ import { DEFAULT_SETTINGS } from "#/services/settings"; import { BrandButton } from "#/components/features/settings/brand-button"; import { 
SettingsSwitch } from "#/components/features/settings/settings-switch"; import { SettingsInput } from "#/components/features/settings/settings-input"; +import { SettingsDropdownInput } from "#/components/features/settings/settings-dropdown-input"; import { I18nKey } from "#/i18n/declaration"; import { LanguageInput } from "#/components/features/settings/app-settings/language-input"; import { handleCaptureConsent } from "#/utils/handle-capture-consent"; @@ -19,6 +20,11 @@ import { retrieveAxiosErrorMessage } from "#/utils/retrieve-axios-error-message" import { AppSettingsInputsSkeleton } from "#/components/features/settings/app-settings/app-settings-inputs-skeleton"; import { useConfig } from "#/hooks/query/use-config"; import { parseMaxBudgetPerTask } from "#/utils/settings-utils"; +import { + SandboxGroupingStrategy, + SandboxGroupingStrategyOptions, +} from "#/types/settings"; +import { ENABLE_SANDBOX_GROUPING } from "#/utils/feature-flags"; import { createPermissionGuard } from "#/utils/org/permission-guard"; export const clientLoader = createPermissionGuard( @@ -49,6 +55,12 @@ function AppSettingsScreen() { solvabilityAnalysisSwitchHasChanged, setSolvabilityAnalysisSwitchHasChanged, ] = React.useState(false); + const [ + sandboxGroupingStrategyHasChanged, + setSandboxGroupingStrategyHasChanged, + ] = React.useState(false); + const [selectedSandboxGroupingStrategy, setSelectedSandboxGroupingStrategy] = + React.useState(null); const [maxBudgetPerTaskHasChanged, setMaxBudgetPerTaskHasChanged] = React.useState(false); const [gitUserNameHasChanged, setGitUserNameHasChanged] = @@ -75,6 +87,11 @@ function AppSettingsScreen() { const enableSolvabilityAnalysis = formData.get("enable-solvability-analysis-switch")?.toString() === "on"; + const sandboxGroupingStrategy = + selectedSandboxGroupingStrategy || + settings?.sandbox_grouping_strategy || + DEFAULT_SETTINGS.sandbox_grouping_strategy; + const maxBudgetPerTaskValue = formData .get("max-budget-per-task-input") 
?.toString(); @@ -94,6 +111,7 @@ function AppSettingsScreen() { enable_sound_notifications: enableSoundNotifications, enable_proactive_conversation_starters: enableProactiveConversations, enable_solvability_analysis: enableSolvabilityAnalysis, + sandbox_grouping_strategy: sandboxGroupingStrategy, max_budget_per_task: maxBudgetPerTask, git_user_name: gitUserName, git_user_email: gitUserEmail, @@ -112,6 +130,8 @@ function AppSettingsScreen() { setAnalyticsSwitchHasChanged(false); setSoundNotificationsSwitchHasChanged(false); setProactiveConversationsSwitchHasChanged(false); + setSandboxGroupingStrategyHasChanged(false); + setSelectedSandboxGroupingStrategy(null); setMaxBudgetPerTaskHasChanged(false); setGitUserNameHasChanged(false); setGitUserEmailHasChanged(false); @@ -159,6 +179,15 @@ function AppSettingsScreen() { ); }; + const handleSandboxGroupingStrategyChange = (key: React.Key | null) => { + const newStrategy = key?.toString() as SandboxGroupingStrategy | undefined; + setSelectedSandboxGroupingStrategy(newStrategy || null); + const currentStrategy = + settings?.sandbox_grouping_strategy || + DEFAULT_SETTINGS.sandbox_grouping_strategy; + setSandboxGroupingStrategyHasChanged(newStrategy !== currentStrategy); + }; + const checkIfMaxBudgetPerTaskHasChanged = (value: string) => { const newValue = parseMaxBudgetPerTask(value); const currentValue = settings?.max_budget_per_task; @@ -181,6 +210,7 @@ function AppSettingsScreen() { !soundNotificationsSwitchHasChanged && !proactiveConversationsSwitchHasChanged && !solvabilityAnalysisSwitchHasChanged && + !sandboxGroupingStrategyHasChanged && !maxBudgetPerTaskHasChanged && !gitUserNameHasChanged && !gitUserEmailHasChanged; @@ -244,6 +274,26 @@ function AppSettingsScreen() { )} + {ENABLE_SANDBOX_GROUPING() && ( + ({ + key, + label: t(`SETTINGS$SANDBOX_GROUPING_${key}` as I18nKey), + }))} + selectedKey={ + selectedSandboxGroupingStrategy || + settings.sandbox_grouping_strategy || + DEFAULT_SETTINGS.sandbox_grouping_strategy 
+ } + isClearable={false} + onSelectionChange={handleSandboxGroupingStrategyChange} + wrapperClassName="w-full max-w-[680px]" + /> + )} + {!settings?.v1_enabled && ( loadFeatureFlag("VSCODE_IN_NEW_TAB"); export const ENABLE_TRAJECTORY_REPLAY = () => loadFeatureFlag("TRAJECTORY_REPLAY"); export const ENABLE_ONBOARDING = () => loadFeatureFlag("ENABLE_ONBOARDING"); +export const ENABLE_SANDBOX_GROUPING = () => + loadFeatureFlag("SANDBOX_GROUPING"); diff --git a/frontend/src/utils/get-git-path.ts b/frontend/src/utils/get-git-path.ts index 15c8ff947e..39292b819f 100644 --- a/frontend/src/utils/get-git-path.ts +++ b/frontend/src/utils/get-git-path.ts @@ -7,10 +7,11 @@ * @returns The git path to use */ export function getGitPath( + conversationId: string, selectedRepository: string | null | undefined, ): string { if (!selectedRepository) { - return "/workspace/project"; + return `/workspace/project/${conversationId}`; } // Extract the repository name from the path @@ -18,5 +19,5 @@ export function getGitPath( const parts = selectedRepository.split("/"); const repoName = parts[parts.length - 1]; - return `/workspace/project/${repoName}`; + return `/workspace/project/${conversationId}/${repoName}`; } diff --git a/openhands/app_server/app_conversation/app_conversation_models.py b/openhands/app_server/app_conversation/app_conversation_models.py index a30b40e56c..b7a4cc4dce 100644 --- a/openhands/app_server/app_conversation/app_conversation_models.py +++ b/openhands/app_server/app_conversation/app_conversation_models.py @@ -16,6 +16,10 @@ from openhands.sdk.conversation.state import ConversationExecutionStatus from openhands.sdk.llm import MetricsSnapshot from openhands.sdk.plugin import PluginSource from openhands.storage.data_models.conversation_metadata import ConversationTrigger +from openhands.storage.data_models.settings import SandboxGroupingStrategy + +# Re-export SandboxGroupingStrategy for backward compatibility +__all__ = ['SandboxGroupingStrategy'] class 
AgentType(Enum): diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 94b5740329..7bb59fedbf 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -88,6 +88,7 @@ from openhands.sdk.utils.paging import page_iterator from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace from openhands.server.types import AppMode from openhands.storage.data_models.conversation_metadata import ConversationTrigger +from openhands.storage.data_models.settings import SandboxGroupingStrategy from openhands.tools.preset.default import ( get_default_tools, ) @@ -128,6 +129,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): jwt_service: JwtService sandbox_startup_timeout: int sandbox_startup_poll_frequency: int + max_num_conversations_per_sandbox: int httpx_client: httpx.AsyncClient web_url: str | None openhands_provider_base_url: str | None @@ -135,6 +137,11 @@ class LiveStatusAppConversationService(AppConversationServiceBase): app_mode: str | None = None tavily_api_key: str | None = None + async def _get_sandbox_grouping_strategy(self) -> SandboxGroupingStrategy: + """Get the sandbox grouping strategy from user settings.""" + user_info = await self.user_context.get_user_info() + return user_info.sandbox_grouping_strategy + async def search_app_conversations( self, title__contains: str | None = None, @@ -255,11 +262,20 @@ class LiveStatusAppConversationService(AppConversationServiceBase): ) assert sandbox_spec is not None + # Set up conversation id + conversation_id = request.conversation_id or uuid4() + + # Setup working dir based on grouping + working_dir = sandbox_spec.working_dir + sandbox_grouping_strategy = await self._get_sandbox_grouping_strategy() + if sandbox_grouping_strategy != 
SandboxGroupingStrategy.NO_GROUPING: + working_dir = f'{working_dir}/{conversation_id.hex}' + # Run setup scripts remote_workspace = AsyncRemoteWorkspace( host=agent_server_url, api_key=sandbox.session_api_key, - working_dir=sandbox_spec.working_dir, + working_dir=working_dir, ) async for updated_task in self.run_setup_scripts( task, sandbox, remote_workspace, agent_server_url @@ -270,13 +286,13 @@ class LiveStatusAppConversationService(AppConversationServiceBase): start_conversation_request = ( await self._build_start_conversation_request_for_user( sandbox, + conversation_id, request.initial_message, request.system_message_suffix, request.git_provider, - sandbox_spec.working_dir, + working_dir, request.agent_type, request.llm_model, - request.conversation_id, remote_workspace=remote_workspace, selected_repository=request.selected_repository, plugins=request.plugins, @@ -495,21 +511,157 @@ class LiveStatusAppConversationService(AppConversationServiceBase): result[stored_conversation.sandbox_id].append(stored_conversation.id) return result + async def _find_running_sandbox_for_user(self) -> SandboxInfo | None: + """Find a running sandbox for the current user based on the grouping strategy. + + Returns: + SandboxInfo if a running sandbox is found, None otherwise. 
+ """ + try: + user_id = await self.user_context.get_user_id() + sandbox_grouping_strategy = await self._get_sandbox_grouping_strategy() + + # If no grouping, return None to force creation of a new sandbox + if sandbox_grouping_strategy == SandboxGroupingStrategy.NO_GROUPING: + return None + + # Collect all running sandboxes for this user + running_sandboxes = [] + page_id = None + while True: + page = await self.sandbox_service.search_sandboxes( + page_id=page_id, limit=100 + ) + + for sandbox in page.items: + if ( + sandbox.status == SandboxStatus.RUNNING + and sandbox.created_by_user_id == user_id + ): + running_sandboxes.append(sandbox) + + if page.next_page_id is None: + break + page_id = page.next_page_id + + if not running_sandboxes: + return None + + # Apply the grouping strategy + return await self._select_sandbox_by_strategy( + running_sandboxes, sandbox_grouping_strategy + ) + + except Exception as e: + _logger.warning( + f'Error finding running sandbox for user: {e}', exc_info=True + ) + return None + + async def _select_sandbox_by_strategy( + self, + running_sandboxes: list[SandboxInfo], + sandbox_grouping_strategy: SandboxGroupingStrategy, + ) -> SandboxInfo | None: + """Select a sandbox from the list based on the configured grouping strategy. 
+ + Args: + running_sandboxes: List of running sandboxes for the user + sandbox_grouping_strategy: The strategy to use for selection + + Returns: + Selected sandbox based on the strategy, or None if no sandbox is available + (e.g., all sandboxes have reached max_num_conversations_per_sandbox) + """ + # Get conversation counts for filtering by max_num_conversations_per_sandbox + sandbox_conversation_counts = await self._get_conversation_counts_by_sandbox( + [s.id for s in running_sandboxes] + ) + + # Filter out sandboxes that have reached the max number of conversations + available_sandboxes = [ + s + for s in running_sandboxes + if sandbox_conversation_counts.get(s.id, 0) + < self.max_num_conversations_per_sandbox + ] + + if not available_sandboxes: + # All sandboxes have reached the max - need to create a new one + return None + + if sandbox_grouping_strategy == SandboxGroupingStrategy.ADD_TO_ANY: + # Return the first available sandbox + return available_sandboxes[0] + + elif sandbox_grouping_strategy == SandboxGroupingStrategy.GROUP_BY_NEWEST: + # Return the most recently created sandbox + return max(available_sandboxes, key=lambda s: s.created_at) + + elif sandbox_grouping_strategy == SandboxGroupingStrategy.LEAST_RECENTLY_USED: + # Return the least recently created sandbox (oldest) + return min(available_sandboxes, key=lambda s: s.created_at) + + elif sandbox_grouping_strategy == SandboxGroupingStrategy.FEWEST_CONVERSATIONS: + # Return the one with fewest conversations + return min( + available_sandboxes, + key=lambda s: sandbox_conversation_counts.get(s.id, 0), + ) + + else: + # Default fallback - return first sandbox + return available_sandboxes[0] + + async def _get_conversation_counts_by_sandbox( + self, sandbox_ids: list[str] + ) -> dict[str, int]: + """Get the count of conversations for each sandbox. 
+ + Args: + sandbox_ids: List of sandbox IDs to count conversations for + + Returns: + Dictionary mapping sandbox_id to conversation count + """ + try: + # Query count for each sandbox individually + # This is efficient since there are at most ~8 running sandboxes per user + counts: dict[str, int] = {} + for sandbox_id in sandbox_ids: + count = await self.app_conversation_info_service.count_app_conversation_info( + sandbox_id__eq=sandbox_id + ) + counts[sandbox_id] = count + return counts + except Exception as e: + _logger.warning( + f'Error counting conversations by sandbox: {e}', exc_info=True + ) + # Return empty counts on error - will default to first sandbox + return {} + async def _wait_for_sandbox_start( self, task: AppConversationStartTask ) -> AsyncGenerator[AppConversationStartTask, None]: """Wait for sandbox to start and return info.""" # Get or create the sandbox if not task.request.sandbox_id: - # Convert conversation_id to hex string if present - sandbox_id_str = ( - task.request.conversation_id.hex - if task.request.conversation_id is not None - else None - ) - sandbox = await self.sandbox_service.start_sandbox( - sandbox_id=sandbox_id_str - ) + # First try to find a running sandbox for the current user + sandbox = await self._find_running_sandbox_for_user() + if sandbox is None: + # No running sandbox found, start a new one + + # Convert conversation_id to hex string if present + sandbox_id_str = ( + task.request.conversation_id.hex + if task.request.conversation_id is not None + else None + ) + + sandbox = await self.sandbox_service.start_sandbox( + sandbox_id=sandbox_id_str + ) task.sandbox_id = sandbox.id else: sandbox_info = await self.sandbox_service.get_sandbox( @@ -1133,7 +1285,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): async def _finalize_conversation_request( self, agent: Agent, - conversation_id: UUID | None, + conversation_id: UUID, user: UserInfo, workspace: LocalWorkspace, initial_message: 
SendMessageRequest | None, @@ -1211,13 +1363,13 @@ class LiveStatusAppConversationService(AppConversationServiceBase): async def _build_start_conversation_request_for_user( self, sandbox: SandboxInfo, + conversation_id: UUID, initial_message: SendMessageRequest | None, system_message_suffix: str | None, git_provider: ProviderType | None, working_dir: str, agent_type: AgentType = AgentType.DEFAULT, llm_model: str | None = None, - conversation_id: UUID | None = None, remote_workspace: AsyncRemoteWorkspace | None = None, selected_repository: str | None = None, plugins: list[PluginSpec] | None = None, @@ -1614,6 +1766,10 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): sandbox_startup_poll_frequency: int = Field( default=2, description='The frequency to poll for sandbox readiness' ) + max_num_conversations_per_sandbox: int = Field( + default=20, + description='The maximum number of conversations allowed per sandbox', + ) init_git_in_empty_workspace: bool = Field( default=True, description='Whether to initialize a git repo when the workspace is empty', @@ -1705,6 +1861,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): jwt_service=jwt_service, sandbox_startup_timeout=self.sandbox_startup_timeout, sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency, + max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox, httpx_client=httpx_client, web_url=web_url, openhands_provider_base_url=config.openhands_provider_base_url, diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 547ca6e252..fa73aa4d52 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -41,7 +41,7 @@ from openhands.app_server.config import ( depends_httpx_client, depends_sandbox_service, ) -from openhands.app_server.sandbox.sandbox_models import SandboxStatus +from 
openhands.app_server.sandbox.sandbox_models import AGENT_SERVER, SandboxStatus from openhands.app_server.sandbox.sandbox_service import SandboxService from openhands.app_server.services.db_session_injector import set_db_session_keep_open from openhands.app_server.services.httpx_client_injector import ( @@ -614,7 +614,7 @@ async def _try_delete_v1_conversation( # Delete the sandbox in the background asyncio.create_task( - _delete_sandbox_and_close_connections( + _finalize_delete_and_close_connections( sandbox_service, app_conversation_info.sandbox_id, db_session, @@ -628,14 +628,18 @@ async def _try_delete_v1_conversation( return result -async def _delete_sandbox_and_close_connections( +async def _finalize_delete_and_close_connections( sandbox_service: SandboxService, sandbox_id: str, db_session: AsyncSession, httpx_client: httpx.AsyncClient, ): try: - await sandbox_service.delete_sandbox(sandbox_id) + num_conversations_in_sandbox = await _get_num_conversations_in_sandbox( + sandbox_service, sandbox_id, httpx_client + ) + if num_conversations_in_sandbox == 0: + await sandbox_service.delete_sandbox(sandbox_id) await db_session.commit() finally: await asyncio.gather( @@ -646,6 +650,28 @@ async def _delete_sandbox_and_close_connections( ) +async def _get_num_conversations_in_sandbox( + sandbox_service: SandboxService, + sandbox_id: str, + httpx_client: httpx.AsyncClient, +) -> int: + try: + sandbox = await sandbox_service.get_sandbox(sandbox_id) + if not sandbox or not sandbox.exposed_urls: + return 0 + agent_server_url = next( + u for u in sandbox.exposed_urls if u.name == AGENT_SERVER + ) + response = await httpx_client.get( + f'{agent_server_url.url}/api/conversations/count', + headers={'X-Session-API-Key': sandbox.session_api_key}, + ) + result = int(response.content) + return result + except Exception: + return 0 + + async def _delete_v0_conversation(conversation_id: str, user_id: str | None) -> bool: """Delete a V0 conversation using the legacy logic.""" 
conversation_store = await ConversationStoreImpl.get_instance(config, user_id) diff --git a/openhands/storage/data_models/settings.py b/openhands/storage/data_models/settings.py index 1600acd3ad..a27c4e1b20 100644 --- a/openhands/storage/data_models/settings.py +++ b/openhands/storage/data_models/settings.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import Enum from typing import Annotated from pydantic import ( @@ -19,6 +20,20 @@ from openhands.core.config.utils import load_openhands_config from openhands.storage.data_models.secrets import Secrets +class SandboxGroupingStrategy(str, Enum): + """Strategy for grouping conversations within sandboxes.""" + + NO_GROUPING = 'NO_GROUPING' # Default - each conversation gets its own sandbox + GROUP_BY_NEWEST = 'GROUP_BY_NEWEST' # Add to the most recently created sandbox + LEAST_RECENTLY_USED = ( + 'LEAST_RECENTLY_USED' # Add to the least recently used sandbox + ) + FEWEST_CONVERSATIONS = ( + 'FEWEST_CONVERSATIONS' # Add to sandbox with fewest conversations + ) + ADD_TO_ANY = 'ADD_TO_ANY' # Add to any available sandbox (first found) + + class Settings(BaseModel): """Persisted settings for OpenHands sessions""" @@ -54,6 +69,9 @@ class Settings(BaseModel): git_user_name: str | None = None git_user_email: str | None = None v1_enabled: bool = True + sandbox_grouping_strategy: SandboxGroupingStrategy = ( + SandboxGroupingStrategy.NO_GROUPING + ) model_config = ConfigDict( validate_assignment=True, diff --git a/tests/unit/app_server/test_live_status_app_conversation_service.py b/tests/unit/app_server/test_live_status_app_conversation_service.py index ad9b4edb46..cf32cfaf05 100644 --- a/tests/unit/app_server/test_live_status_app_conversation_service.py +++ b/tests/unit/app_server/test_live_status_app_conversation_service.py @@ -6,7 +6,7 @@ import os import zipfile from datetime import datetime from unittest.mock import AsyncMock, Mock, patch -from uuid import UUID, uuid4 +from uuid import uuid4 import pytest 
from pydantic import SecretStr @@ -29,6 +29,7 @@ from openhands.app_server.sandbox.sandbox_models import ( AGENT_SERVER, ExposedUrl, SandboxInfo, + SandboxPage, SandboxStatus, ) from openhands.app_server.sandbox.sandbox_spec_models import SandboxSpecInfo @@ -42,6 +43,7 @@ from openhands.sdk.workspace import LocalWorkspace from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace from openhands.server.types import AppMode from openhands.storage.data_models.conversation_metadata import ConversationTrigger +from openhands.storage.data_models.settings import SandboxGroupingStrategy # Env var used by openhands SDK LLM to skip context-window validation (e.g. for gpt-4 in tests) _ALLOW_SHORT_CONTEXT_WINDOWS = 'ALLOW_SHORT_CONTEXT_WINDOWS' @@ -92,6 +94,7 @@ class TestLiveStatusAppConversationService: jwt_service=self.mock_jwt_service, sandbox_startup_timeout=30, sandbox_startup_poll_frequency=1, + max_num_conversations_per_sandbox=20, httpx_client=self.mock_httpx_client, web_url='https://test.example.com', openhands_provider_base_url='https://provider.example.com', @@ -105,6 +108,8 @@ class TestLiveStatusAppConversationService: self.mock_user.llm_model = 'gpt-4' self.mock_user.llm_base_url = 'https://api.openai.com/v1' self.mock_user.llm_api_key = 'test_api_key' + # Use ADD_TO_ANY for tests to maintain old behavior + self.mock_user.sandbox_grouping_strategy = SandboxGroupingStrategy.ADD_TO_ANY self.mock_user.confirmation_mode = False self.mock_user.search_api_key = None # Default to None self.mock_user.condenser_max_size = None # Default to None @@ -1091,11 +1096,12 @@ class TestLiveStatusAppConversationService: workspace = LocalWorkspace(working_dir='/test') secrets = {'test': StaticSecret(value='secret')} + test_conversation_id = uuid4() # Act result = await self.service._finalize_conversation_request( mock_agent, - None, + test_conversation_id, self.mock_user, workspace, None, @@ -1108,7 +1114,7 @@ class TestLiveStatusAppConversationService: 
# Assert assert isinstance(result, StartConversationRequest) - assert isinstance(result.conversation_id, UUID) + assert result.conversation_id == test_conversation_id @pytest.mark.asyncio async def test_finalize_conversation_request_skills_loading_fails(self): @@ -1179,13 +1185,13 @@ class TestLiveStatusAppConversationService: # Act result = await self.service._build_start_conversation_request_for_user( sandbox=self.mock_sandbox, + conversation_id=uuid4(), initial_message=None, system_message_suffix='Test suffix', git_provider=ProviderType.GITHUB, working_dir='/test/dir', agent_type=AgentType.DEFAULT, llm_model='gpt-4', - conversation_id=None, remote_workspace=None, selected_repository='test/repo', ) @@ -1215,6 +1221,98 @@ class TestLiveStatusAppConversationService: self.service._finalize_conversation_request.assert_called_once() @pytest.mark.asyncio + async def test_find_running_sandbox_for_user_found(self): + """Test _find_running_sandbox_for_user when a running sandbox is found.""" + # Arrange + user_id = 'test_user_123' + self.mock_user_context.get_user_id.return_value = user_id + + # Create mock sandboxes + running_sandbox = Mock(spec=SandboxInfo) + running_sandbox.id = 'sandbox_1' + running_sandbox.status = SandboxStatus.RUNNING + running_sandbox.created_by_user_id = user_id + + other_user_sandbox = Mock(spec=SandboxInfo) + other_user_sandbox.id = 'sandbox_2' + other_user_sandbox.status = SandboxStatus.RUNNING + other_user_sandbox.created_by_user_id = 'other_user' + + paused_sandbox = Mock(spec=SandboxInfo) + paused_sandbox.id = 'sandbox_3' + paused_sandbox.status = SandboxStatus.PAUSED + paused_sandbox.created_by_user_id = user_id + + # Mock sandbox service search + mock_page = Mock(spec=SandboxPage) + mock_page.items = [other_user_sandbox, running_sandbox, paused_sandbox] + mock_page.next_page_id = None + self.mock_sandbox_service.search_sandboxes = AsyncMock(return_value=mock_page) + + # Act + result = await self.service._find_running_sandbox_for_user() + 
+ # Assert + assert result == running_sandbox + self.mock_user_context.get_user_id.assert_called_once() + self.mock_sandbox_service.search_sandboxes.assert_called_once_with( + page_id=None, limit=100 + ) + + @pytest.mark.asyncio + async def test_find_running_sandbox_for_user_not_found(self): + """Test _find_running_sandbox_for_user when no running sandbox is found.""" + # Arrange + user_id = 'test_user_123' + self.mock_user_context.get_user_id.return_value = user_id + + # Create mock sandboxes (none running for this user) + other_user_sandbox = Mock(spec=SandboxInfo) + other_user_sandbox.id = 'sandbox_1' + other_user_sandbox.status = SandboxStatus.RUNNING + other_user_sandbox.created_by_user_id = 'other_user' + + paused_sandbox = Mock(spec=SandboxInfo) + paused_sandbox.id = 'sandbox_2' + paused_sandbox.status = SandboxStatus.PAUSED + paused_sandbox.created_by_user_id = user_id + + # Mock sandbox service search + mock_page = Mock(spec=SandboxPage) + mock_page.items = [other_user_sandbox, paused_sandbox] + mock_page.next_page_id = None + self.mock_sandbox_service.search_sandboxes = AsyncMock(return_value=mock_page) + + # Act + result = await self.service._find_running_sandbox_for_user() + + # Assert + assert result is None + self.mock_user_context.get_user_id.assert_called_once() + self.mock_sandbox_service.search_sandboxes.assert_called_once_with( + page_id=None, limit=100 + ) + + @pytest.mark.asyncio + async def test_find_running_sandbox_for_user_exception_handling(self): + """Test _find_running_sandbox_for_user handles exceptions gracefully.""" + # Arrange + self.mock_user_context.get_user_id.side_effect = Exception('User context error') + + # Act + with patch( + 'openhands.app_server.app_conversation.live_status_app_conversation_service._logger' + ) as mock_logger: + result = await self.service._find_running_sandbox_for_user() + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + assert ( + 'Error finding running sandbox for user' + in 
mock_logger.warning.call_args[0][0] + ) + async def test_export_conversation_success(self): """Test successful download of conversation trajectory.""" # Arrange @@ -2052,6 +2150,7 @@ class TestLiveStatusAppConversationService: await self.service._build_start_conversation_request_for_user( sandbox=self.mock_sandbox, + conversation_id=uuid4(), initial_message=None, system_message_suffix=None, git_provider=None, @@ -2088,6 +2187,7 @@ class TestLiveStatusAppConversationService: await self.service._build_start_conversation_request_for_user( sandbox=self.mock_sandbox, + conversation_id=uuid4(), initial_message=None, system_message_suffix=None, git_provider=None, @@ -2243,6 +2343,7 @@ class TestPluginHandling: jwt_service=self.mock_jwt_service, sandbox_startup_timeout=30, sandbox_startup_poll_frequency=1, + max_num_conversations_per_sandbox=20, httpx_client=self.mock_httpx_client, web_url='https://test.example.com', openhands_provider_base_url='https://provider.example.com', @@ -2726,11 +2827,12 @@ class TestPluginHandling: # Act await self.service._build_start_conversation_request_for_user( - self.mock_sandbox, - None, - None, - None, - '/workspace', + sandbox=self.mock_sandbox, + conversation_id=uuid4(), + initial_message=None, + system_message_suffix=None, + git_provider=None, + working_dir='/workspace', plugins=plugins, ) @@ -2754,11 +2856,12 @@ class TestPluginHandling: # Act await self.service._build_start_conversation_request_for_user( - self.mock_sandbox, - None, - None, - None, - '/workspace', + sandbox=self.mock_sandbox, + conversation_id=uuid4(), + initial_message=None, + system_message_suffix=None, + git_provider=None, + working_dir='/workspace', ) # Assert diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index fc305d170e..7fa64ab12a 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -2189,6 +2189,7 @@ async def 
test_delete_v1_conversation_with_sub_conversations(): jwt_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, + max_num_conversations_per_sandbox=20, httpx_client=mock_httpx_client, web_url=None, openhands_provider_base_url=None, @@ -2312,6 +2313,7 @@ async def test_delete_v1_conversation_with_no_sub_conversations(): jwt_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, + max_num_conversations_per_sandbox=20, httpx_client=mock_httpx_client, web_url=None, openhands_provider_base_url=None, @@ -2465,6 +2467,7 @@ async def test_delete_v1_conversation_sub_conversation_deletion_error(): jwt_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, + max_num_conversations_per_sandbox=20, httpx_client=mock_httpx_client, web_url=None, openhands_provider_base_url=None, From aec95ecf3b39f54de1637b83e5949feb9493e994 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Mon, 16 Mar 2026 05:20:10 -0600 Subject: [PATCH 2/5] feat(frontend): update stop sandbox dialog to display conversations in sandbox (#13388) Co-authored-by: openhands --- .../conversation/conversation-name.test.tsx | 4 +- .../v1-conversation-service.api.ts | 23 +++++ .../v1-conversation-service.types.ts | 5 ++ .../conversation-panel/confirm-stop-modal.tsx | 60 +++++++++++++ .../conversation-panel/conversation-panel.tsx | 7 ++ .../conversation/conversation-name.tsx | 1 + .../query/use-conversations-in-sandbox.ts | 15 ++++ frontend/src/i18n/translation.json | 84 +++++++++---------- 8 files changed, 155 insertions(+), 44 deletions(-) create mode 100644 frontend/src/hooks/query/use-conversations-in-sandbox.ts diff --git a/frontend/__tests__/components/features/conversation/conversation-name.test.tsx b/frontend/__tests__/components/features/conversation/conversation-name.test.tsx index 45716775cc..3152fde699 100644 --- a/frontend/__tests__/components/features/conversation/conversation-name.test.tsx +++ 
b/frontend/__tests__/components/features/conversation/conversation-name.test.tsx @@ -72,7 +72,7 @@ vi.mock("react-i18next", async () => { CONVERSATION$SHOW_SKILLS: "Show Skills", BUTTON$DISPLAY_COST: "Display Cost", COMMON$CLOSE_CONVERSATION_STOP_RUNTIME: - "Close Conversation (Stop Runtime)", + "Close Conversation (Stop Sandbox)", COMMON$DELETE_CONVERSATION: "Delete Conversation", CONVERSATION$SHARE_PUBLICLY: "Share Publicly", CONVERSATION$LINK_COPIED: "Link copied to clipboard", @@ -565,7 +565,7 @@ describe("ConversationNameContextMenu", () => { "Delete Conversation", ); expect(screen.getByTestId("stop-button")).toHaveTextContent( - "Close Conversation (Stop Runtime)", + "Close Conversation (Stop Sandbox)", ); expect(screen.getByTestId("display-cost-button")).toHaveTextContent( "Display Cost", diff --git a/frontend/src/api/conversation-service/v1-conversation-service.api.ts b/frontend/src/api/conversation-service/v1-conversation-service.api.ts index 17cbb24cdf..30fdeb9369 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.api.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.api.ts @@ -12,6 +12,7 @@ import type { V1AppConversationStartTask, V1AppConversationStartTaskPage, V1AppConversation, + V1AppConversationPage, GetSkillsResponse, V1RuntimeConversationInfo, } from "./v1-conversation-service.types"; @@ -424,6 +425,28 @@ class V1ConversationService { }); return data; } + + /** + * Search for V1 conversations by sandbox ID + * + * @param sandboxId The sandbox ID to filter by + * @param limit Maximum number of results (default: 100) + * @returns Array of conversations in the specified sandbox + */ + static async searchConversationsBySandboxId( + sandboxId: string, + limit: number = 100, + ): Promise { + const params = new URLSearchParams(); + params.append("sandbox_id__eq", sandboxId); + params.append("limit", limit.toString()); + + const { data } = await openHands.get( + 
`/api/v1/app-conversations/search?${params.toString()}`, + ); + + return data.items; + } } export default V1ConversationService; diff --git a/frontend/src/api/conversation-service/v1-conversation-service.types.ts b/frontend/src/api/conversation-service/v1-conversation-service.types.ts index fb59623372..b437e17bf1 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.types.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.types.ts @@ -119,6 +119,11 @@ export interface V1AppConversation { public?: boolean; } +export interface V1AppConversationPage { + items: V1AppConversation[]; + next_page_id: string | null; +} + export interface Skill { name: string; type: "repo" | "knowledge" | "agentskills"; diff --git a/frontend/src/components/features/conversation-panel/confirm-stop-modal.tsx b/frontend/src/components/features/conversation-panel/confirm-stop-modal.tsx index d841211ace..acf30f4b09 100644 --- a/frontend/src/components/features/conversation-panel/confirm-stop-modal.tsx +++ b/frontend/src/components/features/conversation-panel/confirm-stop-modal.tsx @@ -7,17 +7,71 @@ import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop"; import { ModalBody } from "#/components/shared/modals/modal-body"; import { BrandButton } from "../settings/brand-button"; import { I18nKey } from "#/i18n/declaration"; +import { useConversationsInSandbox } from "#/hooks/query/use-conversations-in-sandbox"; interface ConfirmStopModalProps { onConfirm: () => void; onCancel: () => void; + sandboxId: string | null; +} + +function ConversationsList({ + conversations, + isLoading, + isError, + t, +}: { + conversations: { id: string; title: string | null }[] | undefined; + isLoading: boolean; + isError: boolean; + t: (key: string) => string; +}) { + if (isLoading) { + return ( +
+ {t(I18nKey.HOME$LOADING)} +
+ ); + } + + if (isError) { + return ( +
+ {t(I18nKey.COMMON$ERROR)} +
+ ); + } + + if (conversations && conversations.length > 0) { + return ( +
    + {conversations.map((conv) => ( +
  • {conv.title || conv.id}
  • + ))} +
+ ); + } + + return null; } export function ConfirmStopModal({ onConfirm, onCancel, + sandboxId, }: ConfirmStopModalProps) { const { t } = useTranslation(); + const { + data: conversations, + isLoading, + isError, + } = useConversationsInSandbox(sandboxId); return ( @@ -29,6 +83,12 @@ export function ConfirmStopModal({ +
(null); const [selectedConversationVersion, setSelectedConversationVersion] = React.useState<"V0" | "V1" | undefined>(undefined); + const [selectedSandboxId, setSelectedSandboxId] = React.useState< + string | null + >(null); const [openContextMenuId, setOpenContextMenuId] = React.useState< string | null >(null); @@ -85,10 +88,12 @@ export function ConversationPanel({ onClose }: ConversationPanelProps) { const handleStopConversation = ( conversationId: string, version?: "V0" | "V1", + sandboxId?: string | null, ) => { setConfirmStopModalVisible(true); setSelectedConversationId(conversationId); setSelectedConversationVersion(version); + setSelectedSandboxId(sandboxId ?? null); }; const handleConversationTitleChange = async ( @@ -185,6 +190,7 @@ export function ConversationPanel({ onClose }: ConversationPanelProps) { handleStopConversation( project.conversation_id, project.conversation_version, + project.sandbox_id, ) } onChangeTitle={(title) => @@ -238,6 +244,7 @@ export function ConversationPanel({ onClose }: ConversationPanelProps) { setConfirmStopModalVisible(false); }} onCancel={() => setConfirmStopModalVisible(false)} + sandboxId={selectedSandboxId} /> )} diff --git a/frontend/src/components/features/conversation/conversation-name.tsx b/frontend/src/components/features/conversation/conversation-name.tsx index b7a26aad30..664c583839 100644 --- a/frontend/src/components/features/conversation/conversation-name.tsx +++ b/frontend/src/components/features/conversation/conversation-name.tsx @@ -233,6 +233,7 @@ export function ConversationName() { setConfirmStopModalVisible(false)} + sandboxId={conversation?.sandbox_id ?? 
null} /> )} diff --git a/frontend/src/hooks/query/use-conversations-in-sandbox.ts b/frontend/src/hooks/query/use-conversations-in-sandbox.ts new file mode 100644 index 0000000000..f41edb7a54 --- /dev/null +++ b/frontend/src/hooks/query/use-conversations-in-sandbox.ts @@ -0,0 +1,15 @@ +import { useQuery } from "@tanstack/react-query"; +import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api"; + +export const useConversationsInSandbox = (sandboxId: string | null) => + useQuery({ + queryKey: ["conversations", "sandbox", sandboxId], + queryFn: () => + sandboxId + ? V1ConversationService.searchConversationsBySandboxId(sandboxId) + : Promise.resolve([]), + enabled: !!sandboxId, + staleTime: 0, // Always consider data stale for confirmation dialogs + gcTime: 1000 * 60, // 1 minute + refetchOnMount: true, // Always fetch fresh data when modal opens + }); diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index d3f91ceec7..3437fde4a6 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -5856,36 +5856,36 @@ "uk": "Ви впевнені, що хочете призупинити цю розмову?" 
}, "CONVERSATION$CONFIRM_CLOSE_CONVERSATION": { - "en": "Confirm Close Conversation", - "ja": "会話終了の確認", - "zh-CN": "确认关闭对话", - "zh-TW": "確認關閉對話", - "ko-KR": "대화 종료 확인", - "no": "Bekreft avslutt samtale", - "it": "Conferma chiusura conversazione", - "pt": "Confirmar encerrar conversa", - "es": "Confirmar cerrar conversación", - "ar": "تأكيد إغلاق المحادثة", - "fr": "Confirmer la fermeture de la conversation", - "tr": "Konuşmayı Kapatmayı Onayla", - "de": "Gespräch schließen bestätigen", - "uk": "Підтвердити закриття розмови" + "en": "Confirm Stop Sandbox", + "ja": "サンドボックス停止の確認", + "zh-CN": "确认停止沙盒", + "zh-TW": "確認停止沙盒", + "ko-KR": "샌드박스 중지 확인", + "no": "Bekreft stopp sandkasse", + "it": "Conferma arresto sandbox", + "pt": "Confirmar parar sandbox", + "es": "Confirmar detener sandbox", + "ar": "تأكيد إيقاف صندوق الحماية", + "fr": "Confirmer l'arrêt du sandbox", + "tr": "Sandbox'ı Durdurmayı Onayla", + "de": "Sandbox-Stopp bestätigen", + "uk": "Підтвердити зупинку пісочниці" }, "CONVERSATION$CLOSE_CONVERSATION_WARNING": { - "en": "Are you sure you want to close this conversation and stop the runtime?", - "ja": "この会話を終了してランタイムを停止してもよろしいですか?", - "zh-CN": "您确定要关闭此对话并停止运行时吗?", - "zh-TW": "您確定要關閉此對話並停止執行時嗎?", - "ko-KR": "이 대화를 종료하고 런타임을 중지하시겠습니까?", - "no": "Er du sikker på at du vil avslutte denne samtalen og stoppe kjøretiden?", - "it": "Sei sicuro di voler chiudere questa conversazione e fermare il runtime?", - "pt": "Tem certeza de que deseja encerrar esta conversa e parar o runtime?", - "es": "¿Está seguro de que desea cerrar esta conversación y detener el runtime?", - "ar": "هل أنت متأكد أنك تريد إغلاق هذه المحادثة وإيقاف وقت التشغيل؟", - "fr": "Êtes-vous sûr de vouloir fermer cette conversation et arrêter le runtime ?", - "tr": "Bu konuşmayı kapatmak ve çalışma zamanını durdurmak istediğinizden emin misiniz?", - "de": "Sind Sie sicher, dass Sie dieses Gespräch schließen und die Laufzeit stoppen möchten?", - "uk": "Ви впевнені, що хочете закрити цю розмову та 
зупинити час виконання?" + "en": "This will stop the sandbox, and pause the following conversations:", + "ja": "サンドボックスを停止し、以下の会話を一時停止します:", + "zh-CN": "这将停止沙盒,并暂停以下对话:", + "zh-TW": "這將停止沙盒,並暫停以下對話:", + "ko-KR": "샌드박스를 중지하고 다음 대화를 일시 중지합니다:", + "no": "Dette vil stoppe sandkassen og pause følgende samtaler:", + "it": "Questo fermerà la sandbox e metterà in pausa le seguenti conversazioni:", + "pt": "Isso irá parar o sandbox e pausar as seguintes conversas:", + "es": "Esto detendrá el sandbox y pausará las siguientes conversaciones:", + "ar": "سيؤدي هذا إلى إيقاف صندوق الحماية وإيقاف المحادثات التالية مؤقتًا:", + "fr": "Cela arrêtera le sandbox et mettra en pause les conversations suivantes :", + "tr": "Bu, sandbox'ı durduracak ve aşağıdaki konuşmaları duraklatacaktır:", + "de": "Dies wird die Sandbox stoppen und die folgenden Gespräche pausieren:", + "uk": "Це зупинить пісочницю та призупинить наступні розмови:" }, "CONVERSATION$STOP_WARNING": { "en": "Are you sure you want to pause this conversation?", @@ -14964,20 +14964,20 @@ "uk": "Натисніть тут" }, "COMMON$CLOSE_CONVERSATION_STOP_RUNTIME": { - "en": "Close Conversation (Stop Runtime)", - "ja": "会話を閉じる(ランタイム停止)", - "zh-CN": "关闭对话(停止运行时)", - "zh-TW": "關閉對話(停止執行時)", - "ko-KR": "대화 닫기(런타임 중지)", - "no": "Lukk samtale (stopp kjøring)", - "it": "Chiudi conversazione (Interrompi runtime)", - "pt": "Fechar conversa (Parar execução)", - "es": "Cerrar conversación (Detener ejecución)", - "ar": "إغلاق المحادثة (إيقاف وقت التشغيل)", - "fr": "Fermer la conversation (Arrêter l'exécution)", - "tr": "Konuşmayı Kapat (Çalışma Zamanını Durdur)", - "de": "Gespräch schließen (Laufzeit beenden)", - "uk": "Закрити розмову (зупинити виконання)" + "en": "Close Conversation (Stop Sandbox)", + "ja": "会話を閉じる(サンドボックス停止)", + "zh-CN": "关闭对话(停止沙盒)", + "zh-TW": "關閉對話(停止沙盒)", + "ko-KR": "대화 닫기(샌드박스 중지)", + "no": "Lukk samtale (stopp sandkasse)", + "it": "Chiudi conversazione (Interrompi sandbox)", + "pt": "Fechar conversa (Parar sandbox)", + 
"es": "Cerrar conversación (Detener sandbox)", + "ar": "إغلاق المحادثة (إيقاف صندوق الحماية)", + "fr": "Fermer la conversation (Arrêter le sandbox)", + "tr": "Konuşmayı Kapat (Sandbox'ı Durdur)", + "de": "Gespräch schließen (Sandbox beenden)", + "uk": "Закрити розмову (зупинити пісочницю)" }, "COMMON$CODE": { "en": "Code", From 238cab4d08ebb176bf972e37ef55fa40615dbf82 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Mon, 16 Mar 2026 22:25:44 +0700 Subject: [PATCH 3/5] fix(frontend): prevent chat message loss during websocket disconnections or page refresh (#13380) --- .../101_add_pending_messages_table.py | 39 ++ .../utils/saas_pending_message_injector.py | 172 +++++ .../components/interactive-chat-box.test.tsx | 4 +- .../conversation-local-storage.test.ts | 227 +++++++ .../conversation-websocket-handler.test.tsx | 241 ++++++- .../hooks/use-draft-persistence.test.tsx | 594 ++++++++++++++++++ .../hooks/use-handle-plan-click.test.tsx | 3 + .../pending-message-service.api.ts | 40 ++ .../pending-message-service.types.ts | 22 + .../features/chat/chat-interface.tsx | 10 +- .../features/chat/custom-chat-input.tsx | 2 + .../features/chat/interactive-chat-box.tsx | 3 +- .../conversation-websocket-context.tsx | 50 +- .../src/hooks/chat/use-chat-input-logic.ts | 11 + .../src/hooks/chat/use-draft-persistence.ts | 179 ++++++ frontend/src/hooks/use-send-message.ts | 21 +- .../src/utils/conversation-local-storage.ts | 4 + .../live_status_app_conversation_service.py | 99 +++ .../app_lifespan/alembic/versions/007.py | 39 ++ openhands/app_server/config.py | 26 + .../app_server/pending_messages/__init__.py | 21 + .../pending_message_models.py | 32 + .../pending_message_router.py | 104 +++ .../pending_message_service.py | 200 ++++++ openhands/app_server/v1_router.py | 4 + ...st_live_status_app_conversation_service.py | 4 + .../app_server/test_pending_message_router.py | 227 +++++++ .../test_pending_message_service.py | 309 +++++++++ 
.../server/data_models/test_conversation.py | 3 + 29 files changed, 2668 insertions(+), 22 deletions(-) create mode 100644 enterprise/migrations/versions/101_add_pending_messages_table.py create mode 100644 enterprise/server/utils/saas_pending_message_injector.py create mode 100644 frontend/__tests__/hooks/use-draft-persistence.test.tsx create mode 100644 frontend/src/api/pending-message-service/pending-message-service.api.ts create mode 100644 frontend/src/api/pending-message-service/pending-message-service.types.ts create mode 100644 frontend/src/hooks/chat/use-draft-persistence.ts create mode 100644 openhands/app_server/app_lifespan/alembic/versions/007.py create mode 100644 openhands/app_server/pending_messages/__init__.py create mode 100644 openhands/app_server/pending_messages/pending_message_models.py create mode 100644 openhands/app_server/pending_messages/pending_message_router.py create mode 100644 openhands/app_server/pending_messages/pending_message_service.py create mode 100644 tests/unit/app_server/test_pending_message_router.py create mode 100644 tests/unit/app_server/test_pending_message_service.py diff --git a/enterprise/migrations/versions/101_add_pending_messages_table.py b/enterprise/migrations/versions/101_add_pending_messages_table.py new file mode 100644 index 0000000000..cbe97a955b --- /dev/null +++ b/enterprise/migrations/versions/101_add_pending_messages_table.py @@ -0,0 +1,39 @@ +"""Add pending_messages table for server-side message queuing + +Revision ID: 101 +Revises: 100 +Create Date: 2025-03-15 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '101' +down_revision: Union[str, None] = '100' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create pending_messages table for storing messages before conversation is ready. 
+ + Messages are stored temporarily until the conversation becomes ready, then + delivered and deleted regardless of success or failure. + """ + op.create_table( + 'pending_messages', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('conversation_id', sa.String(), nullable=False, index=True), + sa.Column('role', sa.String(20), nullable=False, server_default='user'), + sa.Column('content', sa.JSON, nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + ) + + +def downgrade() -> None: + """Remove pending_messages table.""" + op.drop_table('pending_messages') diff --git a/enterprise/server/utils/saas_pending_message_injector.py b/enterprise/server/utils/saas_pending_message_injector.py new file mode 100644 index 0000000000..fa47152801 --- /dev/null +++ b/enterprise/server/utils/saas_pending_message_injector.py @@ -0,0 +1,172 @@ +"""Enterprise injector for PendingMessageService with SAAS filtering.""" + +from typing import AsyncGenerator +from uuid import UUID + +from fastapi import Request +from sqlalchemy import select +from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas +from storage.user import User + +from openhands.agent_server.models import ImageContent, TextContent +from openhands.app_server.errors import AuthError +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, + PendingMessageServiceInjector, + SQLPendingMessageService, +) +from openhands.app_server.services.injector import InjectorState +from openhands.app_server.user.specifiy_user_context import ADMIN +from openhands.app_server.user.user_context import UserContext + + +class SaasSQLPendingMessageService(SQLPendingMessageService): + """Extended SQLPendingMessageService with user and organization-based filtering. 
+ + This enterprise version ensures that: + - Users can only queue messages for conversations they own + - Organization isolation is enforced for multi-tenant deployments + """ + + def __init__(self, db_session, user_context: UserContext): + super().__init__(db_session=db_session) + self.user_context = user_context + + async def _get_current_user(self) -> User | None: + """Get the current user using the existing db_session. + + Returns: + User object or None if no user_id is available + """ + user_id_str = await self.user_context.get_user_id() + if not user_id_str: + return None + + user_id_uuid = UUID(user_id_str) + result = await self.db_session.execute( + select(User).where(User.id == user_id_uuid) + ) + return result.scalars().first() + + async def _validate_conversation_ownership(self, conversation_id: str) -> None: + """Validate that the current user owns the conversation. + + This ensures multi-tenant isolation by checking: + - The conversation belongs to the current user + - The conversation belongs to the user's current organization + + Args: + conversation_id: The conversation ID to validate (can be task-id or UUID) + + Raises: + AuthError: If user doesn't own the conversation or authentication fails + """ + # For internal operations (e.g., processing pending messages during startup) + # we need a mode that bypasses filtering. The ADMIN context enables this. + if self.user_context == ADMIN: + return + + user_id_str = await self.user_context.get_user_id() + if not user_id_str: + raise AuthError('User authentication required') + + user_id_uuid = UUID(user_id_str) + + # Check conversation ownership via SAAS metadata + query = select(StoredConversationMetadataSaas).where( + StoredConversationMetadataSaas.conversation_id == conversation_id + ) + result = await self.db_session.execute(query) + saas_metadata = result.scalar_one_or_none() + + # If no SAAS metadata exists, the conversation might be a new task-id + # that hasn't been linked to a conversation yet. 
Allow access in this case + # as the message will be validated when the conversation is created. + if saas_metadata is None: + return + + # Verify user ownership + if saas_metadata.user_id != user_id_uuid: + raise AuthError('You do not have access to this conversation') + + # Verify organization ownership if applicable + user = await self._get_current_user() + if user and user.current_org_id is not None: + if saas_metadata.org_id != user.current_org_id: + raise AuthError('Conversation belongs to a different organization') + + async def add_message( + self, + conversation_id: str, + content: list[TextContent | ImageContent], + role: str = 'user', + ) -> PendingMessageResponse: + """Queue a message with ownership validation. + + Args: + conversation_id: The conversation ID to queue the message for + content: Message content + role: Message role (default: 'user') + + Returns: + PendingMessageResponse with the queued message info + + Raises: + AuthError: If user doesn't own the conversation + """ + await self._validate_conversation_ownership(conversation_id) + return await super().add_message(conversation_id, content, role) + + async def get_pending_messages(self, conversation_id: str): + """Get pending messages with ownership validation. + + Args: + conversation_id: The conversation ID to get messages for + + Returns: + List of pending messages + + Raises: + AuthError: If user doesn't own the conversation + """ + await self._validate_conversation_ownership(conversation_id) + return await super().get_pending_messages(conversation_id) + + async def count_pending_messages(self, conversation_id: str) -> int: + """Count pending messages with ownership validation. 
+ + Args: + conversation_id: The conversation ID to count messages for + + Returns: + Number of pending messages + + Raises: + AuthError: If user doesn't own the conversation + """ + await self._validate_conversation_ownership(conversation_id) + return await super().count_pending_messages(conversation_id) + + +class SaasPendingMessageServiceInjector(PendingMessageServiceInjector): + """Enterprise injector for PendingMessageService with SAAS filtering.""" + + async def inject( + self, state: InjectorState, request: Request | None = None + ) -> AsyncGenerator[PendingMessageService, None]: + from openhands.app_server.config import ( + get_db_session, + get_user_context, + ) + + async with ( + get_user_context(state, request) as user_context, + get_db_session(state, request) as db_session, + ): + service = SaasSQLPendingMessageService( + db_session=db_session, user_context=user_context + ) + yield service diff --git a/frontend/__tests__/components/interactive-chat-box.test.tsx b/frontend/__tests__/components/interactive-chat-box.test.tsx index cb164123c1..bafa673731 100644 --- a/frontend/__tests__/components/interactive-chat-box.test.tsx +++ b/frontend/__tests__/components/interactive-chat-box.test.tsx @@ -198,9 +198,9 @@ describe("InteractiveChatBox", () => { expect(onSubmitMock).toHaveBeenCalledWith("Hello, world!", [], []); }); - it("should disable the submit button when agent is loading", async () => { + it("should disable the submit button when awaiting user confirmation", async () => { const user = userEvent.setup(); - mockStores(AgentState.LOADING); + mockStores(AgentState.AWAITING_USER_CONFIRMATION); renderInteractiveChatBox({ onSubmit: onSubmitMock, diff --git a/frontend/__tests__/conversation-local-storage.test.ts b/frontend/__tests__/conversation-local-storage.test.ts index a99e5fc005..33e9e12a7e 100644 --- a/frontend/__tests__/conversation-local-storage.test.ts +++ b/frontend/__tests__/conversation-local-storage.test.ts @@ -229,4 +229,231 @@ 
describe("conversation localStorage utilities", () => { expect(parsed.subConversationTaskId).toBeNull(); }); }); + + describe("draftMessage persistence", () => { + describe("getConversationState", () => { + it("returns default draftMessage as null when no state exists", () => { + // Arrange + const conversationId = "conv-draft-1"; + + // Act + const state = getConversationState(conversationId); + + // Assert + expect(state.draftMessage).toBeNull(); + }); + + it("retrieves draftMessage from localStorage when it exists", () => { + // Arrange + const conversationId = "conv-draft-2"; + const draftText = "This is my saved draft message"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: draftText, + }), + ); + + // Act + const state = getConversationState(conversationId); + + // Assert + expect(state.draftMessage).toBe(draftText); + }); + + it("returns null draftMessage for task conversation IDs (not persisted)", () => { + // Arrange + const taskId = "task-uuid-123"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`; + + // Even if somehow there's data in localStorage for a task ID + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: "Should not be returned", + }), + ); + + // Act + const state = getConversationState(taskId); + + // Assert - should return default state, not the stored value + expect(state.draftMessage).toBeNull(); + }); + }); + + describe("setConversationState", () => { + it("persists draftMessage to localStorage", () => { + // Arrange + const conversationId = "conv-draft-3"; + const draftText = "New draft message to save"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + // Act + setConversationState(conversationId, { + draftMessage: draftText, + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + 
expect(stored).not.toBeNull(); + const parsed = JSON.parse(stored!); + expect(parsed.draftMessage).toBe(draftText); + }); + + it("does not persist draftMessage for task conversation IDs", () => { + // Arrange + const taskId = "task-draft-xyz"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${taskId}`; + + // Act + setConversationState(taskId, { + draftMessage: "Draft for task ID", + }); + + // Assert - nothing should be stored + expect(localStorage.getItem(consolidatedKey)).toBeNull(); + }); + + it("merges draftMessage with existing state without overwriting other fields", () => { + // Arrange + const conversationId = "conv-draft-4"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + selectedTab: "terminal", + rightPanelShown: false, + unpinnedTabs: ["tab-1", "tab-2"], + conversationMode: "plan", + subConversationTaskId: "task-123", + }), + ); + + // Act + setConversationState(conversationId, { + draftMessage: "Updated draft", + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + const parsed = JSON.parse(stored!); + + expect(parsed.draftMessage).toBe("Updated draft"); + expect(parsed.selectedTab).toBe("terminal"); + expect(parsed.rightPanelShown).toBe(false); + expect(parsed.unpinnedTabs).toEqual(["tab-1", "tab-2"]); + expect(parsed.conversationMode).toBe("plan"); + expect(parsed.subConversationTaskId).toBe("task-123"); + }); + + it("clears draftMessage when set to null", () => { + // Arrange + const conversationId = "conv-draft-5"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: "Existing draft", + }), + ); + + // Act + setConversationState(conversationId, { + draftMessage: null, + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + const parsed = JSON.parse(stored!); + 
expect(parsed.draftMessage).toBeNull(); + }); + + it("clears draftMessage when set to empty string (stored as empty string)", () => { + // Arrange + const conversationId = "conv-draft-6"; + const consolidatedKey = `${LOCAL_STORAGE_KEYS.CONVERSATION_STATE}-${conversationId}`; + + localStorage.setItem( + consolidatedKey, + JSON.stringify({ + draftMessage: "Existing draft", + }), + ); + + // Act + setConversationState(conversationId, { + draftMessage: "", + }); + + // Assert + const stored = localStorage.getItem(consolidatedKey); + const parsed = JSON.parse(stored!); + expect(parsed.draftMessage).toBe(""); + }); + }); + + describe("conversation-specific draft isolation", () => { + it("stores drafts separately for different conversations", () => { + // Arrange + const convA = "conv-A"; + const convB = "conv-B"; + const draftA = "Draft for conversation A"; + const draftB = "Draft for conversation B"; + + // Act + setConversationState(convA, { draftMessage: draftA }); + setConversationState(convB, { draftMessage: draftB }); + + // Assert + const stateA = getConversationState(convA); + const stateB = getConversationState(convB); + + expect(stateA.draftMessage).toBe(draftA); + expect(stateB.draftMessage).toBe(draftB); + }); + + it("updating one conversation draft does not affect another", () => { + // Arrange + const convA = "conv-isolated-A"; + const convB = "conv-isolated-B"; + + setConversationState(convA, { draftMessage: "Original draft A" }); + setConversationState(convB, { draftMessage: "Original draft B" }); + + // Act - update only conversation A + setConversationState(convA, { draftMessage: "Updated draft A" }); + + // Assert - conversation B should be unchanged + const stateA = getConversationState(convA); + const stateB = getConversationState(convB); + + expect(stateA.draftMessage).toBe("Updated draft A"); + expect(stateB.draftMessage).toBe("Original draft B"); + }); + + it("clearing one conversation draft does not affect another", () => { + // Arrange + const 
convA = "conv-clear-A"; + const convB = "conv-clear-B"; + + setConversationState(convA, { draftMessage: "Draft A" }); + setConversationState(convB, { draftMessage: "Draft B" }); + + // Act - clear draft for conversation A + setConversationState(convA, { draftMessage: null }); + + // Assert + const stateA = getConversationState(convA); + const stateB = getConversationState(convB); + + expect(stateA.draftMessage).toBeNull(); + expect(stateB.draftMessage).toBe("Draft B"); + }); + }); + }); }); diff --git a/frontend/__tests__/conversation-websocket-handler.test.tsx b/frontend/__tests__/conversation-websocket-handler.test.tsx index 284aaee287..393d6f68f0 100644 --- a/frontend/__tests__/conversation-websocket-handler.test.tsx +++ b/frontend/__tests__/conversation-websocket-handler.test.tsx @@ -1,3 +1,4 @@ +import React from "react"; import { describe, it, @@ -8,7 +9,7 @@ import { afterEach, vi, } from "vitest"; -import { screen, waitFor, render, cleanup } from "@testing-library/react"; +import { screen, waitFor, render, cleanup, act } from "@testing-library/react"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { http, HttpResponse } from "msw"; import { MemoryRouter, Route, Routes } from "react-router"; @@ -682,8 +683,242 @@ describe("Conversation WebSocket Handler", () => { // 7. 
Message Sending Tests describe("Message Sending", () => { - it.todo("should send user actions through WebSocket when connected"); - it.todo("should handle send attempts when disconnected"); + it("should send user actions through WebSocket when connected", async () => { + // Arrange + const conversationId = "test-conversation-send"; + let receivedMessage: unknown = null; + + // Set up MSW to capture sent messages + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + + // Capture messages sent from client + client.addEventListener("message", (event) => { + receivedMessage = JSON.parse(event.data as string); + }); + }), + ); + + // Create ref to store sendMessage function + let sendMessageFn: typeof useConversationWebSocket extends () => infer R + ? R extends { sendMessage: infer S } + ? S + : null + : null = null; + + function TestComponent() { + const context = useConversationWebSocket(); + + React.useEffect(() => { + if (context?.sendMessage) { + sendMessageFn = context.sendMessage; + } + }, [context?.sendMessage]); + + return ( +
<div> + <div data-testid="connection-state">
+ {context?.connectionState || "NOT_AVAILABLE"} + </div>
+ </div>
+ ); + } + + // Act + renderWithWebSocketContext( + , + conversationId, + `http://localhost:3000/api/conversations/${conversationId}`, + ); + + // Wait for connection + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + // Send a message + await waitFor(() => { + expect(sendMessageFn).not.toBeNull(); + }); + + await act(async () => { + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Hello from test" }], + }); + }); + + // Assert - message should have been received by mock server + await waitFor(() => { + expect(receivedMessage).toEqual({ + role: "user", + content: [{ type: "text", text: "Hello from test" }], + }); + }); + }); + + it("should not throw error when sendMessage is called with WebSocket connected", async () => { + // This test verifies that sendMessage doesn't throw an error + // when the WebSocket is connected. + const conversationId = "test-conversation-no-throw"; + let sendError: Error | null = null; + + // Set up MSW to connect and receive messages + mswServer.use( + wsLink.addEventListener("connection", ({ server }) => { + server.connect(); + }), + ); + + // Create ref to store sendMessage function + let sendMessageFn: typeof useConversationWebSocket extends () => infer R + ? R extends { sendMessage: infer S } + ? S + : null + : null = null; + + function TestComponent() { + const context = useConversationWebSocket(); + + React.useEffect(() => { + if (context?.sendMessage) { + sendMessageFn = context.sendMessage; + } + }, [context?.sendMessage]); + + return ( +
<div> + <div data-testid="connection-state">
+ {context?.connectionState || "NOT_AVAILABLE"} + </div>
+ </div>
+ ); + } + + // Act + renderWithWebSocketContext( + , + conversationId, + `http://localhost:3000/api/conversations/${conversationId}`, + ); + + // Wait for connection + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + // Wait for the context to be available + await waitFor(() => { + expect(sendMessageFn).not.toBeNull(); + }); + + // Try to send a message + await act(async () => { + try { + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Test message" }], + }); + } catch (error) { + sendError = error as Error; + } + }); + + // Assert - should NOT throw an error + expect(sendError).toBeNull(); + }); + + it("should send multiple messages through WebSocket in order", async () => { + // Arrange + const conversationId = "test-conversation-multi"; + const receivedMessages: unknown[] = []; + + // Set up MSW to capture sent messages + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + + // Capture messages sent from client + client.addEventListener("message", (event) => { + receivedMessages.push(JSON.parse(event.data as string)); + }); + }), + ); + + // Create ref to store sendMessage function + let sendMessageFn: typeof useConversationWebSocket extends () => infer R + ? R extends { sendMessage: infer S } + ? S + : null + : null = null; + + function TestComponent() { + const context = useConversationWebSocket(); + + React.useEffect(() => { + if (context?.sendMessage) { + sendMessageFn = context.sendMessage; + } + }, [context?.sendMessage]); + + return ( +
<div> + <div data-testid="connection-state">
+ {context?.connectionState || "NOT_AVAILABLE"} + </div>
+ </div>
+ ); + } + + // Act + renderWithWebSocketContext( + , + conversationId, + `http://localhost:3000/api/conversations/${conversationId}`, + ); + + // Wait for connection + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + await waitFor(() => { + expect(sendMessageFn).not.toBeNull(); + }); + + // Send multiple messages + await act(async () => { + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Message 1" }], + }); + await sendMessageFn!({ + role: "user", + content: [{ type: "text", text: "Message 2" }], + }); + }); + + // Assert - both messages should have been received in order + await waitFor(() => { + expect(receivedMessages.length).toBe(2); + }); + + expect(receivedMessages[0]).toEqual({ + role: "user", + content: [{ type: "text", text: "Message 1" }], + }); + expect(receivedMessages[1]).toEqual({ + role: "user", + content: [{ type: "text", text: "Message 2" }], + }); + }); }); // 8. History Loading State Tests diff --git a/frontend/__tests__/hooks/use-draft-persistence.test.tsx b/frontend/__tests__/hooks/use-draft-persistence.test.tsx new file mode 100644 index 0000000000..0734470324 --- /dev/null +++ b/frontend/__tests__/hooks/use-draft-persistence.test.tsx @@ -0,0 +1,594 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { renderHook, act } from "@testing-library/react"; +import { useDraftPersistence } from "#/hooks/chat/use-draft-persistence"; +import * as conversationLocalStorage from "#/utils/conversation-local-storage"; + +// Mock the entire module +vi.mock("#/utils/conversation-local-storage", () => ({ + useConversationLocalStorageState: vi.fn(), + getConversationState: vi.fn(), + setConversationState: vi.fn(), +})); + +// Mock the getTextContent utility +vi.mock("#/components/features/chat/utils/chat-input.utils", () => ({ + getTextContent: vi.fn((el: HTMLDivElement | null) => el?.textContent || ""), +})); + 
+describe("useDraftPersistence", () => { + let mockSetDraftMessage: (message: string | null) => void; + + // Create a mock ref to contentEditable div + const createMockChatInputRef = (initialContent = "") => { + const div = document.createElement("div"); + div.setAttribute("contenteditable", "true"); + div.textContent = initialContent; + return { current: div }; + }; + + beforeEach(() => { + vi.clearAllMocks(); + vi.useFakeTimers(); + localStorage.clear(); + + mockSetDraftMessage = vi.fn<(message: string | null) => void>(); + + // Default mock for useConversationLocalStorageState + vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({ + state: { + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }, + setSelectedTab: vi.fn(), + setRightPanelShown: vi.fn(), + setUnpinnedTabs: vi.fn(), + setConversationMode: vi.fn(), + setDraftMessage: mockSetDraftMessage, + }); + + // Default mock for getConversationState + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + }); + + afterEach(() => { + vi.useRealTimers(); + vi.clearAllMocks(); + }); + + describe("draft restoration on mount", () => { + it("restores draft from localStorage when mounting with existing draft", () => { + // Arrange + const conversationId = "conv-restore-1"; + const savedDraft = "Previously saved draft message"; + const chatInputRef = createMockChatInputRef(); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: savedDraft, + }); + + // Act + renderHook(() => useDraftPersistence(conversationId, chatInputRef)); + + // Assert - 
draft should be restored to the DOM element + expect(chatInputRef.current?.textContent).toBe(savedDraft); + }); + + it("clears input on mount then restores draft if exists", () => { + // Arrange + const conversationId = "conv-restore-2"; + const existingContent = "Stale content from previous conversation"; + const savedDraft = "Saved draft"; + const chatInputRef = createMockChatInputRef(existingContent); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: savedDraft, + }); + + // Act + renderHook(() => useDraftPersistence(conversationId, chatInputRef)); + + // Assert - input cleared then draft restored + expect(chatInputRef.current?.textContent).toBe(savedDraft); + }); + + it("clears input when no draft exists for conversation", () => { + // Arrange + const conversationId = "conv-no-draft"; + const chatInputRef = createMockChatInputRef("Some stale content"); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + // Act + renderHook(() => useDraftPersistence(conversationId, chatInputRef)); + + // Assert - content should be cleared since there's no draft + expect(chatInputRef.current?.textContent).toBe(""); + }); + }); + + describe("debounced saving", () => { + it("saves draft after debounce period", () => { + // Arrange + const conversationId = "conv-debounce-1"; + const chatInputRef = createMockChatInputRef(); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act - simulate user typing + chatInputRef.current!.textContent = "New draft content"; + act(() => { + result.current.saveDraft(); + }); + + // Assert - should not save immediately + 
expect(mockSetDraftMessage).not.toHaveBeenCalled(); + + // Fast forward past debounce period (500ms) + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - should save after debounce + expect(mockSetDraftMessage).toHaveBeenCalledWith("New draft content"); + }); + + it("cancels pending save when new input arrives before debounce", () => { + // Arrange + const conversationId = "conv-debounce-2"; + const chatInputRef = createMockChatInputRef(); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act - first input + chatInputRef.current!.textContent = "First"; + act(() => { + result.current.saveDraft(); + }); + + // Wait 200ms (less than debounce) + act(() => { + vi.advanceTimersByTime(200); + }); + + // Second input before debounce completes + chatInputRef.current!.textContent = "First Second"; + act(() => { + result.current.saveDraft(); + }); + + // Complete the second debounce + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - should only save the final value once + expect(mockSetDraftMessage).toHaveBeenCalledTimes(1); + expect(mockSetDraftMessage).toHaveBeenCalledWith("First Second"); + }); + + it("does not save if content matches existing draft", () => { + // Arrange + const conversationId = "conv-no-change"; + const existingDraft = "Existing draft"; + const chatInputRef = createMockChatInputRef(existingDraft); + + vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({ + state: { + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: existingDraft, + }, + setSelectedTab: vi.fn(), + setRightPanelShown: vi.fn(), + setUnpinnedTabs: vi.fn(), + setConversationMode: vi.fn(), + setDraftMessage: mockSetDraftMessage, + }); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + 
conversationMode: "code", + subConversationTaskId: null, + draftMessage: existingDraft, + }); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act - try to save same content + act(() => { + result.current.saveDraft(); + }); + + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - should not save since content is the same + expect(mockSetDraftMessage).not.toHaveBeenCalled(); + }); + }); + + describe("clearDraft", () => { + it("clears the draft from localStorage", () => { + // Arrange + const conversationId = "conv-clear-1"; + const chatInputRef = createMockChatInputRef("Some content"); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Act + act(() => { + result.current.clearDraft(); + }); + + // Assert + expect(mockSetDraftMessage).toHaveBeenCalledWith(null); + }); + + it("cancels any pending debounced save when clearing", () => { + // Arrange + const conversationId = "conv-clear-2"; + const chatInputRef = createMockChatInputRef(); + + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Start a save + chatInputRef.current!.textContent = "Pending draft"; + act(() => { + result.current.saveDraft(); + }); + + // Clear before debounce completes + act(() => { + vi.advanceTimersByTime(200); + result.current.clearDraft(); + }); + + // Complete the original debounce period + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - only the clear should have been called (the pending save should be cancelled) + expect(mockSetDraftMessage).toHaveBeenCalledTimes(1); + expect(mockSetDraftMessage).toHaveBeenCalledWith(null); + }); + }); + + describe("conversation switching", () => { + it("clears input when switching to a new conversation without a draft", () => { + // Arrange + const chatInputRef = createMockChatInputRef("Draft from conv A"); + + // First conversation has a draft + 
vi.mocked(conversationLocalStorage.getConversationState) + .mockReturnValueOnce({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: "Draft from conv A", + }) + .mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Act - switch to conversation B + rerender({ conversationId: "conv-B" }); + + // Assert - input should be cleared (no draft for conv-B) + expect(chatInputRef.current?.textContent).toBe(""); + }); + + it("restores draft when switching to a conversation with an existing draft", () => { + // Arrange + const chatInputRef = createMockChatInputRef(); + const draftForConvB = "Saved draft for conversation B"; + + vi.mocked(conversationLocalStorage.getConversationState) + .mockReturnValueOnce({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }) + .mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: draftForConvB, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Act - switch to conversation B + rerender({ conversationId: "conv-B" }); + + // Assert - draft for conv-B should be restored + expect(chatInputRef.current?.textContent).toBe(draftForConvB); + }); + + it("cancels pending save when switching conversations", () => { + // Arrange + const chatInputRef = createMockChatInputRef(); + + const { result, rerender } = renderHook( + ({ 
conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Start typing in conv-A + chatInputRef.current!.textContent = "Draft for conv-A"; + act(() => { + result.current.saveDraft(); + }); + + // Switch conversation before debounce completes + act(() => { + vi.advanceTimersByTime(200); + }); + rerender({ conversationId: "conv-B" }); + + // Complete the debounce period + act(() => { + vi.advanceTimersByTime(500); + }); + + // Assert - the save should NOT have happened because conversation changed + expect(mockSetDraftMessage).not.toHaveBeenCalled(); + }); + }); + + describe("task ID to real conversation ID transition", () => { + it("transfers draft from task ID to real conversation ID during transition", () => { + // Arrange + const chatInputRef = createMockChatInputRef("Draft typed during init"); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "task-abc-123" } }, + ); + + // Simulate user typing during task initialization + chatInputRef.current!.textContent = "Draft typed during init"; + + // Act - transition to real conversation ID + rerender({ conversationId: "conv-real-123" }); + + // Assert - draft should be saved to the new real conversation ID + expect(conversationLocalStorage.setConversationState).toHaveBeenCalledWith( + "conv-real-123", + { draftMessage: "Draft typed during init" }, + ); + + // And the draft should remain visible in the input + expect(chatInputRef.current?.textContent).toBe("Draft typed during init"); + }); + + it("does not transfer empty draft during task-to-real transition", () => { + // Arrange + const chatInputRef = 
createMockChatInputRef(""); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "task-abc-123" } }, + ); + + // Act - transition to real conversation ID with empty input + rerender({ conversationId: "conv-real-123" }); + + // Assert - no draft should be saved (input is cleared, checked by hook) + // The setConversationState should not be called with draftMessage + expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled(); + }); + + it("does not transfer draft for non-task ID transitions", () => { + // Arrange + const chatInputRef = createMockChatInputRef("Some draft"); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: null, + }); + + const { rerender } = renderHook( + ({ conversationId }) => + useDraftPersistence(conversationId, chatInputRef), + { initialProps: { conversationId: "conv-A" } }, + ); + + // Act - normal conversation switch (not task-to-real) + rerender({ conversationId: "conv-B" }); + + // Assert - should not use setConversationState directly + // (the normal path uses setDraftMessage from the hook) + expect(conversationLocalStorage.setConversationState).not.toHaveBeenCalled(); + }); + }); + + describe("hasDraft and isRestored state", () => { + it("returns hasDraft true when draft exists in hook state", () => { + // Arrange + const conversationId = "conv-has-draft"; + const chatInputRef = createMockChatInputRef(); + + vi.mocked(conversationLocalStorage.useConversationLocalStorageState).mockReturnValue({ + state: { + selectedTab: 
"editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: "Existing draft", + }, + setSelectedTab: vi.fn(), + setRightPanelShown: vi.fn(), + setUnpinnedTabs: vi.fn(), + setConversationMode: vi.fn(), + setDraftMessage: mockSetDraftMessage, + }); + + // Act + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Assert + expect(result.current.hasDraft).toBe(true); + }); + + it("returns hasDraft false when no draft exists", () => { + // Arrange + const conversationId = "conv-no-draft"; + const chatInputRef = createMockChatInputRef(); + + // Act + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Assert + expect(result.current.hasDraft).toBe(false); + }); + + it("sets isRestored to true after restoration completes", () => { + // Arrange + const conversationId = "conv-restored"; + const chatInputRef = createMockChatInputRef(); + + vi.mocked(conversationLocalStorage.getConversationState).mockReturnValue({ + selectedTab: "editor", + rightPanelShown: true, + unpinnedTabs: [], + conversationMode: "code", + subConversationTaskId: null, + draftMessage: "Draft to restore", + }); + + // Act + const { result } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Assert + expect(result.current.isRestored).toBe(true); + }); + }); + + describe("cleanup on unmount", () => { + it("clears pending timeout on unmount", () => { + // Arrange + const conversationId = "conv-unmount"; + const chatInputRef = createMockChatInputRef(); + + const { result, unmount } = renderHook(() => + useDraftPersistence(conversationId, chatInputRef), + ); + + // Start a save + chatInputRef.current!.textContent = "Draft"; + act(() => { + result.current.saveDraft(); + }); + + // Unmount before debounce completes + unmount(); + + // Complete the debounce period + act(() => { + vi.advanceTimersByTime(500); + }); + + 
// Assert - save should not have been called after unmount + expect(mockSetDraftMessage).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/__tests__/hooks/use-handle-plan-click.test.tsx b/frontend/__tests__/hooks/use-handle-plan-click.test.tsx index 067a208c81..fdaa4c06aa 100644 --- a/frontend/__tests__/hooks/use-handle-plan-click.test.tsx +++ b/frontend/__tests__/hooks/use-handle-plan-click.test.tsx @@ -88,6 +88,7 @@ describe("useHandlePlanClick", () => { unpinnedTabs: [], subConversationTaskId: null, conversationMode: "code", + draftMessage: null, }); }); @@ -117,6 +118,7 @@ describe("useHandlePlanClick", () => { unpinnedTabs: [], subConversationTaskId: storedTaskId, conversationMode: "code", + draftMessage: null, }); renderHook(() => useHandlePlanClick()); @@ -155,6 +157,7 @@ describe("useHandlePlanClick", () => { unpinnedTabs: [], subConversationTaskId: storedTaskId, conversationMode: "code", + draftMessage: null, }); renderHook(() => useHandlePlanClick()); diff --git a/frontend/src/api/pending-message-service/pending-message-service.api.ts b/frontend/src/api/pending-message-service/pending-message-service.api.ts new file mode 100644 index 0000000000..8c7ef73a8a --- /dev/null +++ b/frontend/src/api/pending-message-service/pending-message-service.api.ts @@ -0,0 +1,40 @@ +/** + * Pending Message Service + * + * This service handles server-side message queuing for V1 conversations. + * Messages can be queued when the WebSocket is not connected and will be + * delivered automatically when the conversation becomes ready. + */ + +import { openHands } from "../open-hands-axios"; +import type { + PendingMessageResponse, + QueuePendingMessageRequest, +} from "./pending-message-service.types"; + +class PendingMessageService { + /** + * Queue a message for delivery when conversation becomes ready. + * + * This endpoint allows users to submit messages even when the conversation's + * WebSocket connection is not yet established. 
Messages are stored server-side + * and delivered automatically when the conversation transitions to READY status. + * + * @param conversationId The conversation ID (can be task ID before conversation is ready) + * @param message The message to queue + * @returns PendingMessageResponse with the message ID and queue position + * @throws Error if too many pending messages (limit: 10 per conversation) + */ + static async queueMessage( + conversationId: string, + message: QueuePendingMessageRequest, + ): Promise { + const { data } = await openHands.post( + `/api/v1/conversations/${conversationId}/pending-messages`, + message, + ); + return data; + } +} + +export default PendingMessageService; diff --git a/frontend/src/api/pending-message-service/pending-message-service.types.ts b/frontend/src/api/pending-message-service/pending-message-service.types.ts new file mode 100644 index 0000000000..cf7b8dbf0e --- /dev/null +++ b/frontend/src/api/pending-message-service/pending-message-service.types.ts @@ -0,0 +1,22 @@ +/** + * Types for the pending message service + */ + +import type { V1MessageContent } from "../conversation-service/v1-conversation-service.types"; + +/** + * Response when queueing a pending message + */ +export interface PendingMessageResponse { + id: string; + queued: boolean; + position: number; +} + +/** + * Request to queue a pending message + */ +export interface QueuePendingMessageRequest { + role?: "user"; + content: V1MessageContent[]; +} diff --git a/frontend/src/components/features/chat/chat-interface.tsx b/frontend/src/components/features/chat/chat-interface.tsx index 85a9435678..43218149ae 100644 --- a/frontend/src/components/features/chat/chat-interface.tsx +++ b/frontend/src/components/features/chat/chat-interface.tsx @@ -190,8 +190,14 @@ export function ChatInterface() { const prompt = uploadedFiles.length > 0 ? 
`${content}\n\n${filePrompt}` : content; - send(createChatMessage(prompt, imageUrls, uploadedFiles, timestamp)); - setOptimisticUserMessage(content); + const result = await send( + createChatMessage(prompt, imageUrls, uploadedFiles, timestamp), + ); + // Only show optimistic UI if message was sent immediately via WebSocket + // If queued for later delivery, the message will appear when actually delivered + if (!result.queued) { + setOptimisticUserMessage(content); + } setMessageToSend(""); }; diff --git a/frontend/src/components/features/chat/custom-chat-input.tsx b/frontend/src/components/features/chat/custom-chat-input.tsx index 5fd92fdcd6..26a0f74ca9 100644 --- a/frontend/src/components/features/chat/custom-chat-input.tsx +++ b/frontend/src/components/features/chat/custom-chat-input.tsx @@ -60,6 +60,7 @@ export function CustomChatInput({ messageToSend, checkIsContentEmpty, clearEmptyContentHandler, + saveDraft, } = useChatInputLogic(); const { @@ -158,6 +159,7 @@ export function CustomChatInput({ onInput={() => { handleInput(); updateSlashMenu(); + saveDraft(); }} onPaste={handlePaste} onKeyDown={(e) => { diff --git a/frontend/src/components/features/chat/interactive-chat-box.tsx b/frontend/src/components/features/chat/interactive-chat-box.tsx index a2f1df8348..74818d1d6c 100644 --- a/frontend/src/components/features/chat/interactive-chat-box.tsx +++ b/frontend/src/components/features/chat/interactive-chat-box.tsx @@ -142,8 +142,9 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) { handleSubmit(suggestion); }; + // Allow users to submit messages during LOADING state - they will be + // queued server-side and delivered when the conversation becomes ready const isDisabled = - curAgentState === AgentState.LOADING || curAgentState === AgentState.AWAITING_USER_CONFIRMATION || isTaskPolling(subConversationTaskStatus); diff --git a/frontend/src/contexts/conversation-websocket-context.tsx 
b/frontend/src/contexts/conversation-websocket-context.tsx index 572ab4fd75..86863734b9 100644 --- a/frontend/src/contexts/conversation-websocket-context.tsx +++ b/frontend/src/contexts/conversation-websocket-context.tsx @@ -40,6 +40,7 @@ import type { V1SendMessageRequest, } from "#/api/conversation-service/v1-conversation-service.types"; import EventService from "#/api/event-service/event-service.api"; +import PendingMessageService from "#/api/pending-message-service/pending-message-service.api"; import { useConversationStore } from "#/stores/conversation-store"; import { isBudgetOrCreditError, trackError } from "#/utils/error-handler"; import { useTracking } from "#/hooks/use-tracking"; @@ -47,6 +48,7 @@ import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation- import useMetricsStore from "#/stores/metrics-store"; import { I18nKey } from "#/i18n/declaration"; import { useConversationHistory } from "#/hooks/query/use-conversation-history"; +import { setConversationState } from "#/utils/conversation-local-storage"; // eslint-disable-next-line @typescript-eslint/naming-convention export type V1_WebSocketConnectionState = @@ -55,9 +57,13 @@ export type V1_WebSocketConnectionState = | "CLOSED" | "CLOSING"; +interface SendMessageResult { + queued: boolean; // true if message was queued for later delivery, false if sent immediately +} + interface ConversationWebSocketContextType { connectionState: V1_WebSocketConnectionState; - sendMessage: (message: V1SendMessageRequest) => Promise; + sendMessage: (message: V1SendMessageRequest) => Promise; isLoadingHistory: boolean; } @@ -397,6 +403,10 @@ export function ConversationWebSocketProvider({ // Clear optimistic user message when a user message is confirmed if (isUserMessageEvent(event)) { removeOptimisticUserMessage(); + // Clear draft from localStorage - message was successfully delivered + if (conversationId) { + setConversationState(conversationId, { draftMessage: null }); + } } // Handle cache 
invalidation for ActionEvent @@ -556,6 +566,11 @@ export function ConversationWebSocketProvider({ // Clear optimistic user message when a user message is confirmed if (isUserMessageEvent(event)) { removeOptimisticUserMessage(); + // Clear draft from localStorage - message was successfully delivered + // Use main conversationId since user types in main conversation input + if (conversationId) { + setConversationState(conversationId, { draftMessage: null }); + } } // Handle cache invalidation for ActionEvent @@ -810,21 +825,44 @@ export function ConversationWebSocketProvider({ ); // V1 send message function via WebSocket + // Falls back to REST API queue when WebSocket is not connected const sendMessage = useCallback( - async (message: V1SendMessageRequest) => { + async (message: V1SendMessageRequest): Promise => { const currentMode = useConversationStore.getState().conversationMode; const currentSocket = currentMode === "plan" ? planningAgentSocket : mainSocket; if (!currentSocket || currentSocket.readyState !== WebSocket.OPEN) { - const error = "WebSocket is not connected"; - setErrorMessage(error); - throw new Error(error); + // WebSocket not connected - queue message via REST API + // Message will be delivered automatically when conversation becomes ready + if (!conversationId) { + const error = new Error("No conversation ID available"); + setErrorMessage(error.message); + throw error; + } + + try { + await PendingMessageService.queueMessage(conversationId, { + role: "user", + content: message.content, + }); + // Message queued successfully - it will be delivered when ready + // Return queued: true so caller knows not to show optimistic UI + return { queued: true }; + } catch (error) { + const errorMessage = + error instanceof Error + ? 
error.message + : "Failed to queue message for delivery"; + setErrorMessage(errorMessage); + throw error; + } } try { // Send message through WebSocket as JSON currentSocket.send(JSON.stringify(message)); + return { queued: false }; } catch (error) { const errorMessage = error instanceof Error ? error.message : "Failed to send message"; @@ -832,7 +870,7 @@ export function ConversationWebSocketProvider({ throw error; } }, - [mainSocket, planningAgentSocket, setErrorMessage], + [mainSocket, planningAgentSocket, setErrorMessage, conversationId], ); // Track main socket state changes diff --git a/frontend/src/hooks/chat/use-chat-input-logic.ts b/frontend/src/hooks/chat/use-chat-input-logic.ts index 21dc682fc9..47a6fafacb 100644 --- a/frontend/src/hooks/chat/use-chat-input-logic.ts +++ b/frontend/src/hooks/chat/use-chat-input-logic.ts @@ -5,12 +5,15 @@ import { getTextContent, } from "#/components/features/chat/utils/chat-input.utils"; import { useConversationStore } from "#/stores/conversation-store"; +import { useConversationId } from "#/hooks/use-conversation-id"; +import { useDraftPersistence } from "./use-draft-persistence"; /** * Hook for managing chat input content logic */ export const useChatInputLogic = () => { const chatInputRef = useRef(null); + const { conversationId } = useConversationId(); const { messageToSend, @@ -19,6 +22,12 @@ export const useChatInputLogic = () => { setIsRightPanelShown, } = useConversationStore(); + // Draft persistence - saves to localStorage, restores on mount + const { saveDraft, clearDraft } = useDraftPersistence( + conversationId, + chatInputRef, + ); + // Save current input value when drawer state changes useEffect(() => { if (chatInputRef.current) { @@ -51,5 +60,7 @@ export const useChatInputLogic = () => { checkIsContentEmpty, clearEmptyContentHandler, getCurrentMessage, + saveDraft, + clearDraft, }; }; diff --git a/frontend/src/hooks/chat/use-draft-persistence.ts b/frontend/src/hooks/chat/use-draft-persistence.ts new file 
mode 100644 index 0000000000..fd958030b1 --- /dev/null +++ b/frontend/src/hooks/chat/use-draft-persistence.ts @@ -0,0 +1,179 @@ +import { useEffect, useRef, useCallback, useState } from "react"; +import { + useConversationLocalStorageState, + getConversationState, + setConversationState, +} from "#/utils/conversation-local-storage"; +import { getTextContent } from "#/components/features/chat/utils/chat-input.utils"; + +/** + * Check if a conversation ID is a temporary task ID. + * Task IDs have the format "task-{uuid}" and are used during V1 conversation initialization. + */ +const isTaskId = (id: string): boolean => id.startsWith("task-"); + +const DRAFT_SAVE_DEBOUNCE_MS = 500; + +/** + * Hook for persisting draft messages to localStorage. + * Handles debounced saving on input, restoration on mount, and clearing on confirmed delivery. + */ +export const useDraftPersistence = ( + conversationId: string, + chatInputRef: React.RefObject, +) => { + const { state, setDraftMessage } = + useConversationLocalStorageState(conversationId); + const saveTimeoutRef = useRef | null>(null); + const hasRestoredRef = useRef(false); + const [isRestored, setIsRestored] = useState(false); + + // Track current conversationId to prevent saving draft to wrong conversation + const currentConversationIdRef = useRef(conversationId); + // Track if this is the first mount to handle initial cleanup + const isFirstMountRef = useRef(true); + + // IMPORTANT: This effect must run FIRST when conversation changes. + // It handles three concerns: + // 1. Cleanup: Cancel pending saves from previous conversation + // 2. Task-to-real transition: Preserve draft typed during initialization + // 3. DOM reset: Clear stale content before restoration effect runs + useEffect(() => { + const previousConversationId = currentConversationIdRef.current; + const isInitialMount = isFirstMountRef.current; + currentConversationIdRef.current = conversationId; + isFirstMountRef.current = false; + + // --- 1. 
Cancel pending saves from previous conversation --- + // Prevents draft from being saved to wrong conversation if user switched quickly + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + saveTimeoutRef.current = null; + } + + const element = chatInputRef.current; + + // --- 2. Handle task-to-real ID transition (preserve draft during initialization) --- + // When a new V1 conversation initializes, it starts with a temporary "task-xxx" ID + // that transitions to a real conversation ID once ready. Task IDs don't persist + // to localStorage, so any draft typed during this phase would be lost. + // We detect this transition and transfer the draft to the new real ID. + if (!isInitialMount && previousConversationId !== conversationId) { + const wasTaskId = isTaskId(previousConversationId); + const isNowRealId = !isTaskId(conversationId); + + if (wasTaskId && isNowRealId && element) { + const currentText = getTextContent(element).trim(); + if (currentText) { + // Transfer draft to the new (real) conversation ID + setConversationState(conversationId, { draftMessage: currentText }); + // Keep draft visible in DOM and mark as restored to prevent overwrite + hasRestoredRef.current = true; + setIsRestored(true); + return; // Skip normal cleanup - draft is already in correct state + } + } + } + + // --- 3. 
Clear stale DOM content (will be restored by next effect if draft exists) --- + // This prevents stale drafts from appearing in new conversations due to: + // - Browser form restoration on back/forward navigation + // - React DOM recycling between conversation switches + // The restoration effect will then populate with the correct saved draft + if (element) { + element.textContent = ""; + } + + // Reset restoration flag so the restoration effect will run for new conversation + hasRestoredRef.current = false; + setIsRestored(false); + }, [conversationId, chatInputRef]); + + // Restore draft from localStorage - reads directly to avoid state sync timing issues + useEffect(() => { + if (hasRestoredRef.current) { + return; + } + + const element = chatInputRef.current; + if (!element) { + return; + } + + // Read directly from localStorage to avoid stale state from useConversationLocalStorageState + // The hook's state may not have synced yet after conversationId change + const { draftMessage } = getConversationState(conversationId); + + // Only restore if there's a saved draft and the input is empty + if (draftMessage && getTextContent(element).trim() === "") { + element.textContent = draftMessage; + // Move cursor to end + const selection = window.getSelection(); + const range = document.createRange(); + range.selectNodeContents(element); + range.collapse(false); + selection?.removeAllRanges(); + selection?.addRange(range); + } + + hasRestoredRef.current = true; + setIsRestored(true); + }, [chatInputRef, conversationId]); + + // Debounced save function - called from onInput handler + const saveDraft = useCallback(() => { + // Clear any pending save + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + } + + // Capture the conversationId at the time of input + const capturedConversationId = conversationId; + + saveTimeoutRef.current = setTimeout(() => { + // Verify we're still on the same conversation before saving + // This prevents saving draft to 
wrong conversation if user switched quickly + if (capturedConversationId !== currentConversationIdRef.current) { + return; + } + + const element = chatInputRef.current; + if (!element) { + return; + } + + const text = getTextContent(element).trim(); + // Only save if content has changed + if (text !== (state.draftMessage || "")) { + setDraftMessage(text || null); + } + }, DRAFT_SAVE_DEBOUNCE_MS); + }, [chatInputRef, state.draftMessage, setDraftMessage, conversationId]); + + // Clear draft - called after message delivery is confirmed + const clearDraft = useCallback(() => { + // Cancel any pending save + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + saveTimeoutRef.current = null; + } + setDraftMessage(null); + }, [setDraftMessage]); + + // Cleanup timeout on unmount + useEffect( + () => () => { + if (saveTimeoutRef.current) { + clearTimeout(saveTimeoutRef.current); + } + }, + [], + ); + + return { + saveDraft, + clearDraft, + isRestored, + hasDraft: !!state.draftMessage, + }; +}; diff --git a/frontend/src/hooks/use-send-message.ts b/frontend/src/hooks/use-send-message.ts index 4da5eafc2e..3f641521e6 100644 --- a/frontend/src/hooks/use-send-message.ts +++ b/frontend/src/hooks/use-send-message.ts @@ -5,6 +5,10 @@ import { useConversationWebSocket } from "#/contexts/conversation-websocket-cont import { useConversationId } from "#/hooks/use-conversation-id"; import { V1MessageContent } from "#/api/conversation-service/v1-conversation-service.types"; +interface SendResult { + queued: boolean; // true if message was queued for later delivery +} + /** * Unified hook for sending messages that works with both V0 and V1 conversations * - For V0 conversations: Uses Socket.IO WebSocket via useWsClient @@ -26,7 +30,7 @@ export function useSendMessage() { conversation?.conversation_version === "V1"; const send = useCallback( - async (event: Record) => { + async (event: Record): Promise => { if (isV1Conversation && v1Context) { // V1: Convert V0 event 
format to V1 message format const { action, args } = event as { @@ -57,19 +61,20 @@ export function useSendMessage() { } // Send via V1 WebSocket context (uses correct host/port) - await v1Context.sendMessage({ + const result = await v1Context.sendMessage({ role: "user", content, }); - } else { - // For non-message events, fall back to V0 send - // (e.g., agent state changes, other control events) - v0Send(event); + return result; } - } else { - // V0: Use Socket.IO + // For non-message events, fall back to V0 send + // (e.g., agent state changes, other control events) v0Send(event); + return { queued: false }; } + // V0: Use Socket.IO + v0Send(event); + return { queued: false }; }, [isV1Conversation, v1Context, v0Send, conversationId], ); diff --git a/frontend/src/utils/conversation-local-storage.ts b/frontend/src/utils/conversation-local-storage.ts index de16da9f55..4beb800b88 100644 --- a/frontend/src/utils/conversation-local-storage.ts +++ b/frontend/src/utils/conversation-local-storage.ts @@ -23,6 +23,7 @@ export interface ConversationState { unpinnedTabs: string[]; conversationMode: ConversationMode; subConversationTaskId: string | null; + draftMessage: string | null; } const DEFAULT_CONVERSATION_STATE: ConversationState = { @@ -31,6 +32,7 @@ const DEFAULT_CONVERSATION_STATE: ConversationState = { unpinnedTabs: [], conversationMode: "code", subConversationTaskId: null, + draftMessage: null, }; /** @@ -121,6 +123,7 @@ export function useConversationLocalStorageState(conversationId: string): { setRightPanelShown: (shown: boolean) => void; setUnpinnedTabs: (tabs: string[]) => void; setConversationMode: (mode: ConversationMode) => void; + setDraftMessage: (message: string | null) => void; } { const [state, setState] = useState(() => getConversationState(conversationId), @@ -178,5 +181,6 @@ export function useConversationLocalStorageState(conversationId: string): { setRightPanelShown: (shown) => updateState({ rightPanelShown: shown }), setUnpinnedTabs: (tabs) => 
updateState({ unpinnedTabs: tabs }), setConversationMode: (mode) => updateState({ conversationMode: mode }), + setDraftMessage: (message) => updateState({ draftMessage: message }), }; } diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 7bb59fedbf..fe07f205c1 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -59,6 +59,9 @@ from openhands.app_server.event_callback.event_callback_service import ( from openhands.app_server.event_callback.set_title_callback_processor import ( SetTitleCallbackProcessor, ) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, +) from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService from openhands.app_server.sandbox.sandbox_models import ( AGENT_SERVER, @@ -127,6 +130,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): sandbox_service: SandboxService sandbox_spec_service: SandboxSpecService jwt_service: JwtService + pending_message_service: PendingMessageService sandbox_startup_timeout: int sandbox_startup_poll_frequency: int max_num_conversations_per_sandbox: int @@ -373,6 +377,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase): task.app_conversation_id = info.id yield task + # Process any pending messages queued while waiting for conversation + if sandbox.session_api_key: + await self._process_pending_messages( + task_id=task.id, + conversation_id=info.id, + agent_server_url=agent_server_url, + session_api_key=sandbox.session_api_key, + ) + except Exception as exc: _logger.exception('Error starting conversation', stack_info=True) task.status = AppConversationStartTaskStatus.ERROR @@ -1424,6 +1437,89 @@ class 
LiveStatusAppConversationService(AppConversationServiceBase): plugins=plugins, ) + async def _process_pending_messages( + self, + task_id: UUID, + conversation_id: UUID, + agent_server_url: str, + session_api_key: str, + ) -> None: + """Process pending messages queued before conversation was ready. + + Messages are delivered sequentially to the agent server. After processing, + all messages are deleted from the database regardless of success or failure. + + Args: + task_id: The start task ID (may have been used as conversation_id initially) + conversation_id: The real conversation ID + agent_server_url: URL of the agent server + session_api_key: API key for authenticating with agent server + """ + # Convert UUIDs to strings for the pending message service + # The frontend uses task-{uuid.hex} format (no hyphens), matching OpenHandsUUID serialization + task_id_str = f'task-{task_id.hex}' + # conversation_id uses standard format (with hyphens) for agent server API compatibility + conversation_id_str = str(conversation_id) + + _logger.info(f'task_id={task_id_str} conversation_id={conversation_id_str}') + + # First, update any messages that were queued with the task_id + updated_count = await self.pending_message_service.update_conversation_id( + old_conversation_id=task_id_str, + new_conversation_id=conversation_id_str, + ) + _logger.info(f'updated_count={updated_count} ') + if updated_count > 0: + _logger.info( + f'Updated {updated_count} pending messages from task_id={task_id_str} ' + f'to conversation_id={conversation_id_str}' + ) + + # Get all pending messages for this conversation + pending_messages = await self.pending_message_service.get_pending_messages( + conversation_id_str + ) + + if not pending_messages: + return + + _logger.info( + f'Processing {len(pending_messages)} pending messages for ' + f'conversation {conversation_id_str}' + ) + + # Process messages sequentially to preserve order + for msg in pending_messages: + try: + # Serialize content objects
to JSON-compatible dicts + content_json = [item.model_dump() for item in msg.content] + # Use the events endpoint which handles message sending + response = await self.httpx_client.post( + f'{agent_server_url}/api/conversations/{conversation_id_str}/events', + json={ + 'role': msg.role, + 'content': content_json, + 'run': True, + }, + headers={'X-Session-API-Key': session_api_key}, + timeout=30.0, + ) + response.raise_for_status() + _logger.debug(f'Delivered pending message {msg.id}') + except Exception as e: + _logger.warning(f'Failed to deliver pending message {msg.id}: {e}') + + # Delete all pending messages after processing (regardless of success/failure) + deleted_count = ( + await self.pending_message_service.delete_messages_for_conversation( + conversation_id_str + ) + ) + _logger.info( + f'Finished processing pending messages for conversation {conversation_id_str}. ' + f'Deleted {deleted_count} messages.' + ) + async def update_agent_server_conversation_title( self, conversation_id: str, @@ -1796,6 +1892,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): get_global_config, get_httpx_client, get_jwt_service, + get_pending_message_service, get_sandbox_service, get_sandbox_spec_service, get_user_context, @@ -1815,6 +1912,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): get_event_service(state, request) as event_service, get_jwt_service(state, request) as jwt_service, get_httpx_client(state, request) as httpx_client, + get_pending_message_service(state, request) as pending_message_service, ): access_token_hard_timeout = None if self.access_token_hard_timeout: @@ -1859,6 +1957,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector): event_callback_service=event_callback_service, event_service=event_service, jwt_service=jwt_service, + pending_message_service=pending_message_service, sandbox_startup_timeout=self.sandbox_startup_timeout, 
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency, max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox, diff --git a/openhands/app_server/app_lifespan/alembic/versions/007.py b/openhands/app_server/app_lifespan/alembic/versions/007.py new file mode 100644 index 0000000000..ef0b34b2eb --- /dev/null +++ b/openhands/app_server/app_lifespan/alembic/versions/007.py @@ -0,0 +1,39 @@ +"""Add pending_messages table for server-side message queuing + +Revision ID: 007 +Revises: 006 +Create Date: 2025-03-15 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '007' +down_revision: Union[str, None] = '006' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create pending_messages table for storing messages before conversation is ready. + + Messages are stored temporarily until the conversation becomes ready, then + delivered and deleted regardless of success or failure. 
+ """ + op.create_table( + 'pending_messages', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('conversation_id', sa.String(), nullable=False, index=True), + sa.Column('role', sa.String(20), nullable=False, server_default='user'), + sa.Column('content', sa.JSON, nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + ) + + +def downgrade() -> None: + """Remove pending_messages table.""" + op.drop_table('pending_messages') diff --git a/openhands/app_server/config.py b/openhands/app_server/config.py index 8c0ded6d2e..4b7f78e389 100644 --- a/openhands/app_server/config.py +++ b/openhands/app_server/config.py @@ -33,6 +33,10 @@ from openhands.app_server.event_callback.event_callback_service import ( EventCallbackService, EventCallbackServiceInjector, ) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, + PendingMessageServiceInjector, +) from openhands.app_server.sandbox.sandbox_service import ( SandboxService, SandboxServiceInjector, @@ -114,6 +118,7 @@ class AppServerConfig(OpenHandsModel): app_conversation_info: AppConversationInfoServiceInjector | None = None app_conversation_start_task: AppConversationStartTaskServiceInjector | None = None app_conversation: AppConversationServiceInjector | None = None + pending_message: PendingMessageServiceInjector | None = None user: UserContextInjector | None = None jwt: JwtServiceInjector | None = None httpx: HttpxClientInjector = Field(default_factory=HttpxClientInjector) @@ -280,6 +285,13 @@ def config_from_env() -> AppServerConfig: tavily_api_key=tavily_api_key ) + if config.pending_message is None: + from openhands.app_server.pending_messages.pending_message_service import ( + SQLPendingMessageServiceInjector, + ) + + config.pending_message = SQLPendingMessageServiceInjector() + if config.user is None: config.user = AuthUserContextInjector() @@ -358,6 +370,14 @@ def get_app_conversation_service( return injector.context(state, 
request) +def get_pending_message_service( + state: InjectorState, request: Request | None = None +) -> AsyncContextManager[PendingMessageService]: + injector = get_global_config().pending_message + assert injector is not None + return injector.context(state, request) + + def get_user_context( state: InjectorState, request: Request | None = None ) -> AsyncContextManager[UserContext]: @@ -433,6 +453,12 @@ def depends_app_conversation_service(): return Depends(injector.depends) +def depends_pending_message_service(): + injector = get_global_config().pending_message + assert injector is not None + return Depends(injector.depends) + + def depends_user_context(): injector = get_global_config().user assert injector is not None diff --git a/openhands/app_server/pending_messages/__init__.py b/openhands/app_server/pending_messages/__init__.py new file mode 100644 index 0000000000..5aa37fc675 --- /dev/null +++ b/openhands/app_server/pending_messages/__init__.py @@ -0,0 +1,21 @@ +"""Pending messages module for server-side message queuing.""" + +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessage, + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, + PendingMessageServiceInjector, + SQLPendingMessageService, + SQLPendingMessageServiceInjector, +) + +__all__ = [ + 'PendingMessage', + 'PendingMessageResponse', + 'PendingMessageService', + 'PendingMessageServiceInjector', + 'SQLPendingMessageService', + 'SQLPendingMessageServiceInjector', +] diff --git a/openhands/app_server/pending_messages/pending_message_models.py b/openhands/app_server/pending_messages/pending_message_models.py new file mode 100644 index 0000000000..9e0062b185 --- /dev/null +++ b/openhands/app_server/pending_messages/pending_message_models.py @@ -0,0 +1,32 @@ +"""Models for pending message queue functionality.""" + +from datetime import datetime +from uuid import uuid4 + +from pydantic import 
BaseModel, Field + +from openhands.agent_server.models import ImageContent, TextContent +from openhands.agent_server.utils import utc_now + + +class PendingMessage(BaseModel): + """A message queued for delivery when conversation becomes ready. + + Pending messages are stored in the database and delivered to the agent_server + when the conversation transitions to READY status. Messages are deleted after + processing, regardless of success or failure. + """ + + id: str = Field(default_factory=lambda: str(uuid4())) + conversation_id: str # Can be task-{uuid} or real conversation UUID + role: str = 'user' + content: list[TextContent | ImageContent] + created_at: datetime = Field(default_factory=utc_now) + + +class PendingMessageResponse(BaseModel): + """Response when queueing a pending message.""" + + id: str + queued: bool + position: int = Field(description='Position in the queue (1-based)') diff --git a/openhands/app_server/pending_messages/pending_message_router.py b/openhands/app_server/pending_messages/pending_message_router.py new file mode 100644 index 0000000000..7c78e2d6eb --- /dev/null +++ b/openhands/app_server/pending_messages/pending_message_router.py @@ -0,0 +1,104 @@ +"""REST API router for pending messages.""" + +import logging + +from fastapi import APIRouter, HTTPException, Request, status +from pydantic import TypeAdapter, ValidationError + +from openhands.agent_server.models import ImageContent, TextContent +from openhands.app_server.config import depends_pending_message_service +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + PendingMessageService, +) +from openhands.server.dependencies import get_dependencies + +logger = logging.getLogger(__name__) + +# Type adapter for validating content from request +_content_type_adapter = TypeAdapter(list[TextContent | ImageContent]) + +# Create router with authentication 
# Maximum number of undelivered messages that may sit in one conversation's queue.
_MAX_PENDING_MESSAGES = 10

# Upper bound for ``role``. The storage model declares the column as String(20);
# backends that enforce the length would otherwise fail at commit time.
_MAX_ROLE_LENGTH = 20


@router.post(
    '', response_model=PendingMessageResponse, status_code=status.HTTP_201_CREATED
)
async def queue_pending_message(
    conversation_id: str,
    request: Request,
    pending_service: PendingMessageService = pending_message_service_dependency,
) -> PendingMessageResponse:
    """Queue a message for delivery when conversation becomes ready.

    This endpoint allows users to submit messages even when the conversation's
    WebSocket connection is not yet established. Messages are stored server-side
    and delivered automatically when the conversation transitions to READY status.

    Args:
        conversation_id: The conversation ID (can be task ID before conversation is ready)
        request: The FastAPI request containing message content
        pending_service: Service used to count and store pending messages

    Returns:
        PendingMessageResponse with the message ID and queue position

    Raises:
        HTTPException 400: If the body is not a JSON object, ``content`` is not a
            non-empty list of valid content items, or ``role`` is not a string of
            at most 20 characters
        HTTPException 429: If too many pending messages are queued (limit: 10)
    """
    try:
        body = await request.json()
    except Exception as exc:
        # Any failure to read/parse the body is the client's problem; keep the
        # original error chained for server-side debugging.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Invalid request body',
        ) from exc

    # Syntactically valid JSON is not necessarily an object: a list, string,
    # number or null would blow up on ``.get`` below and surface as a 500.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Invalid request body',
        )

    raw_content = body.get('content')
    role = body.get('role', 'user')

    if not raw_content or not isinstance(raw_content, list):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='content must be a non-empty list',
        )

    # ``role`` comes straight from the client; reject values the model/storage
    # layer cannot hold instead of letting them fail later as a server error.
    if not isinstance(role, str) or len(role) > _MAX_ROLE_LENGTH:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f'role must be a string of at most {_MAX_ROLE_LENGTH} characters',
        )

    # Validate and parse content into typed objects
    try:
        content = _content_type_adapter.validate_python(raw_content)
    except ValidationError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f'Invalid content format: {e}',
        ) from e

    # Rate limit per conversation.
    # NOTE(review): count-then-insert is not atomic, so concurrent requests can
    # briefly exceed the limit; treat it as a soft cap.
    pending_count = await pending_service.count_pending_messages(conversation_id)
    if pending_count >= _MAX_PENDING_MESSAGES:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=(
                'Too many pending messages. '
                f'Maximum {_MAX_PENDING_MESSAGES} messages per conversation.'
            ),
        )

    response = await pending_service.add_message(
        conversation_id=conversation_id,
        content=content,
        role=role,
    )

    # Lazy %-formatting: arguments are only rendered if INFO is enabled.
    logger.info(
        'Queued pending message %s for conversation %s (position: %s)',
        response.id,
        conversation_id,
        response.position,
    )

    return response
Column(JSON, nullable=False) + created_at = Column(UtcDateTime, server_default=func.now(), index=True) + + +class PendingMessageService(ABC): + """Abstract service for managing pending messages.""" + + @abstractmethod + async def add_message( + self, + conversation_id: str, + content: list[TextContent | ImageContent], + role: str = 'user', + ) -> PendingMessageResponse: + """Queue a message for delivery when conversation becomes ready.""" + + @abstractmethod + async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]: + """Get all pending messages for a conversation, ordered by created_at.""" + + @abstractmethod + async def count_pending_messages(self, conversation_id: str) -> int: + """Count pending messages for a conversation.""" + + @abstractmethod + async def delete_messages_for_conversation(self, conversation_id: str) -> int: + """Delete all pending messages for a conversation, returning count deleted.""" + + @abstractmethod + async def update_conversation_id( + self, old_conversation_id: str, new_conversation_id: str + ) -> int: + """Update conversation_id when task-id transitions to real conversation-id. + + Returns the number of messages updated. 
@dataclass
class SQLPendingMessageService(PendingMessageService):
    """SQL implementation of PendingMessageService.

    All operations run on the injected ``db_session``. Methods that modify rows
    commit on that session themselves before returning.
    """

    db_session: AsyncSession

    async def _load_rows(self, conversation_id: str) -> list[StoredPendingMessage]:
        """Return the stored rows for a conversation, oldest first."""
        stmt = (
            select(StoredPendingMessage)
            .where(StoredPendingMessage.conversation_id == conversation_id)
            .order_by(StoredPendingMessage.created_at.asc())
        )
        result = await self.db_session.execute(stmt)
        return list(result.scalars().all())

    async def add_message(
        self,
        conversation_id: str,
        content: list[TextContent | ImageContent],
        role: str = 'user',
    ) -> PendingMessageResponse:
        """Queue a message for delivery when conversation becomes ready.

        Args:
            conversation_id: Queue key (task id or real conversation id).
            content: Typed content items making up the message.
            role: Role stored with the message.

        Returns:
            PendingMessageResponse carrying the new message id and its 1-based
            position in the conversation's queue.
        """
        # Building the model first generates the id and created_at defaults.
        pending_message = PendingMessage(
            conversation_id=conversation_id,
            role=role,
            content=content,
        )

        # Reuse the shared counter instead of repeating the COUNT query here.
        # NOTE(review): count-then-insert is not atomic; two concurrent writers
        # for the same conversation can be handed the same position.
        position = await self.count_pending_messages(conversation_id) + 1

        # The column is JSON, so dump in JSON mode: that guarantees plain
        # JSON-compatible values even if a content model carries richer types.
        content_json = [item.model_dump(mode='json') for item in content]

        self.db_session.add(
            StoredPendingMessage(
                id=pending_message.id,
                conversation_id=conversation_id,
                role=role,
                content=content_json,
                created_at=pending_message.created_at,
            )
        )
        await self.db_session.commit()

        return PendingMessageResponse(
            id=pending_message.id,
            queued=True,
            position=position,
        )

    async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
        """Get all pending messages for a conversation, ordered by created_at."""
        return [
            PendingMessage(
                id=row.id,
                conversation_id=row.conversation_id,
                role=row.role,
                # Stored JSON is re-validated into typed content objects.
                content=_content_type_adapter.validate_python(row.content),
                created_at=row.created_at,
            )
            for row in await self._load_rows(conversation_id)
        ]

    async def count_pending_messages(self, conversation_id: str) -> int:
        """Count pending messages for a conversation."""
        count_stmt = select(func.count()).where(
            StoredPendingMessage.conversation_id == conversation_id
        )
        result = await self.db_session.execute(count_stmt)
        return result.scalar() or 0

    async def delete_messages_for_conversation(self, conversation_id: str) -> int:
        """Delete all pending messages for a conversation, returning count deleted."""
        rows = await self._load_rows(conversation_id)
        for row in rows:
            await self.db_session.delete(row)

        # Nothing matched means nothing to flush; skip the commit round-trip.
        if rows:
            await self.db_session.commit()

        return len(rows)

    async def update_conversation_id(
        self, old_conversation_id: str, new_conversation_id: str
    ) -> int:
        """Update conversation_id when task-id transitions to real conversation-id.

        Returns the number of messages updated.
        """
        rows = await self._load_rows(old_conversation_id)
        for row in rows:
            row.conversation_id = new_conversation_id

        if rows:
            await self.db_session.commit()

        return len(rows)
b/openhands/app_server/v1_router.py index 2a21c06abd..81823b481c 100644 --- a/openhands/app_server/v1_router.py +++ b/openhands/app_server/v1_router.py @@ -5,6 +5,9 @@ from openhands.app_server.event import event_router from openhands.app_server.event_callback import ( webhook_router, ) +from openhands.app_server.pending_messages.pending_message_router import ( + router as pending_message_router, +) from openhands.app_server.sandbox import sandbox_router, sandbox_spec_router from openhands.app_server.user import user_router from openhands.app_server.web_client import web_client_router @@ -13,6 +16,7 @@ from openhands.app_server.web_client import web_client_router router = APIRouter(prefix='/api/v1') router.include_router(event_router.router) router.include_router(app_conversation_router.router) +router.include_router(pending_message_router) router.include_router(sandbox_router.router) router.include_router(sandbox_spec_router.router) router.include_router(user_router.router) diff --git a/tests/unit/app_server/test_live_status_app_conversation_service.py b/tests/unit/app_server/test_live_status_app_conversation_service.py index cf32cfaf05..fcb251797f 100644 --- a/tests/unit/app_server/test_live_status_app_conversation_service.py +++ b/tests/unit/app_server/test_live_status_app_conversation_service.py @@ -80,6 +80,7 @@ class TestLiveStatusAppConversationService: self.mock_event_callback_service = Mock() self.mock_event_service = Mock() self.mock_httpx_client = Mock() + self.mock_pending_message_service = Mock() # Create service instance self.service = LiveStatusAppConversationService( @@ -92,6 +93,7 @@ class TestLiveStatusAppConversationService: sandbox_service=self.mock_sandbox_service, sandbox_spec_service=self.mock_sandbox_spec_service, jwt_service=self.mock_jwt_service, + pending_message_service=self.mock_pending_message_service, sandbox_startup_timeout=30, sandbox_startup_poll_frequency=1, max_num_conversations_per_sandbox=20, @@ -2329,6 +2331,7 @@ class 
TestPluginHandling: self.mock_event_callback_service = Mock() self.mock_event_service = Mock() self.mock_httpx_client = Mock() + self.mock_pending_message_service = Mock() # Create service instance self.service = LiveStatusAppConversationService( @@ -2341,6 +2344,7 @@ class TestPluginHandling: sandbox_service=self.mock_sandbox_service, sandbox_spec_service=self.mock_sandbox_spec_service, jwt_service=self.mock_jwt_service, + pending_message_service=self.mock_pending_message_service, sandbox_startup_timeout=30, sandbox_startup_poll_frequency=1, max_num_conversations_per_sandbox=20, diff --git a/tests/unit/app_server/test_pending_message_router.py b/tests/unit/app_server/test_pending_message_router.py new file mode 100644 index 0000000000..92dbe2c4a4 --- /dev/null +++ b/tests/unit/app_server/test_pending_message_router.py @@ -0,0 +1,227 @@ +"""Unit tests for the pending_message_router endpoints. + +This module tests the queue_pending_message endpoint, +focusing on request validation and rate limiting. 
+""" + +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest +from fastapi import HTTPException, status + +from openhands.agent_server.models import TextContent +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_router import ( + queue_pending_message, +) + + +def _make_mock_service( + add_message_return=None, + count_pending_messages_return=0, +): + """Create a mock PendingMessageService for testing.""" + service = MagicMock() + service.add_message = AsyncMock(return_value=add_message_return) + service.count_pending_messages = AsyncMock( + return_value=count_pending_messages_return + ) + return service + + +def _make_mock_request(body: dict): + """Create a mock FastAPI Request with given JSON body.""" + request = MagicMock() + request.json = AsyncMock(return_value=body) + return request + + +@pytest.mark.asyncio +class TestQueuePendingMessage: + """Test suite for queue_pending_message endpoint.""" + + async def test_queues_message_successfully(self): + """Test that a valid message is queued successfully.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Hello, world!'}] + expected_response = PendingMessageResponse( + id=str(uuid4()), + queued=True, + position=1, + ) + mock_service = _make_mock_service( + add_message_return=expected_response, + count_pending_messages_return=0, + ) + mock_request = _make_mock_request({'content': raw_content, 'role': 'user'}) + + # Act + result = await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + # Assert + assert result == expected_response + mock_service.add_message.assert_called_once() + call_kwargs = mock_service.add_message.call_args.kwargs + assert call_kwargs['conversation_id'] == conversation_id + assert call_kwargs['role'] == 'user' + # Content should be 
parsed into typed objects + assert len(call_kwargs['content']) == 1 + assert isinstance(call_kwargs['content'][0], TextContent) + assert call_kwargs['content'][0].text == 'Hello, world!' + + async def test_uses_default_role_when_not_provided(self): + """Test that 'user' role is used by default.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Test message'}] + expected_response = PendingMessageResponse( + id=str(uuid4()), + queued=True, + position=1, + ) + mock_service = _make_mock_service( + add_message_return=expected_response, + count_pending_messages_return=0, + ) + mock_request = _make_mock_request({'content': raw_content}) + + # Act + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + # Assert + mock_service.add_message.assert_called_once() + call_kwargs = mock_service.add_message.call_args.kwargs + assert call_kwargs['conversation_id'] == conversation_id + assert call_kwargs['role'] == 'user' + assert isinstance(call_kwargs['content'][0], TextContent) + + async def test_returns_400_for_invalid_json_body(self): + """Test that invalid JSON body returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = MagicMock() + mock_request.json = AsyncMock(side_effect=Exception('Invalid JSON')) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'Invalid request body' in exc_info.value.detail + + async def test_returns_400_when_content_is_missing(self): + """Test that missing content returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = _make_mock_request({'role': 
'user'}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'content must be a non-empty list' in exc_info.value.detail + + async def test_returns_400_when_content_is_not_a_list(self): + """Test that non-list content returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = _make_mock_request({'content': 'not a list'}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'content must be a non-empty list' in exc_info.value.detail + + async def test_returns_400_when_content_is_empty_list(self): + """Test that empty list content returns 400 Bad Request.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + mock_service = _make_mock_service() + mock_request = _make_mock_request({'content': []}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'content must be a non-empty list' in exc_info.value.detail + + async def test_returns_429_when_rate_limit_exceeded(self): + """Test that exceeding rate limit returns 429 Too Many Requests.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Test message'}] + mock_service = _make_mock_service(count_pending_messages_return=10) + mock_request = _make_mock_request({'content': raw_content}) + + # Act & Assert + with pytest.raises(HTTPException) as 
exc_info: + await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS + assert 'Maximum 10 messages' in exc_info.value.detail + + async def test_allows_up_to_10_messages(self): + """Test that 9 existing messages still allows adding one more.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + raw_content = [{'type': 'text', 'text': 'Test message'}] + expected_response = PendingMessageResponse( + id=str(uuid4()), + queued=True, + position=10, + ) + mock_service = _make_mock_service( + add_message_return=expected_response, + count_pending_messages_return=9, + ) + mock_request = _make_mock_request({'content': raw_content}) + + # Act + result = await queue_pending_message( + conversation_id=conversation_id, + request=mock_request, + pending_service=mock_service, + ) + + # Assert + assert result == expected_response + mock_service.add_message.assert_called_once() diff --git a/tests/unit/app_server/test_pending_message_service.py b/tests/unit/app_server/test_pending_message_service.py new file mode 100644 index 0000000000..869aae05d0 --- /dev/null +++ b/tests/unit/app_server/test_pending_message_service.py @@ -0,0 +1,309 @@ +"""Tests for SQLPendingMessageService. + +This module tests the SQL implementation of PendingMessageService, +covering message queuing, retrieval, counting, deletion, and +conversation_id updates using SQLite as a mock database. 
+""" + +from typing import AsyncGenerator +from uuid import uuid4 + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from openhands.agent_server.models import TextContent +from openhands.app_server.pending_messages.pending_message_models import ( + PendingMessageResponse, +) +from openhands.app_server.pending_messages.pending_message_service import ( + SQLPendingMessageService, +) +from openhands.app_server.utils.sql_utils import Base + + +@pytest.fixture +async def async_engine(): + """Create an async SQLite engine for testing.""" + engine = create_async_engine( + 'sqlite+aiosqlite:///:memory:', + poolclass=StaticPool, + connect_args={'check_same_thread': False}, + echo=False, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + await engine.dispose() + + +@pytest.fixture +async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]: + """Create an async session for testing.""" + async_session_maker = async_sessionmaker( + async_engine, class_=AsyncSession, expire_on_commit=False + ) + + async with async_session_maker() as db_session: + yield db_session + + +@pytest.fixture +def service(async_session) -> SQLPendingMessageService: + """Create a SQLPendingMessageService instance for testing.""" + return SQLPendingMessageService(db_session=async_session) + + +@pytest.fixture +def sample_content() -> list[TextContent]: + """Create sample message content for testing.""" + return [TextContent(text='Hello, this is a test message')] + + +class TestSQLPendingMessageService: + """Test suite for SQLPendingMessageService.""" + + @pytest.mark.asyncio + async def test_add_message_creates_message_with_correct_data( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that add_message creates a message with the expected fields.""" + # Arrange + conversation_id = 
f'task-{uuid4().hex}' + + # Act + response = await service.add_message( + conversation_id=conversation_id, + content=sample_content, + role='user', + ) + + # Assert + assert isinstance(response, PendingMessageResponse) + assert response.queued is True + assert response.id is not None + + # Verify the message was stored correctly + messages = await service.get_pending_messages(conversation_id) + assert len(messages) == 1 + assert messages[0].conversation_id == conversation_id + assert len(messages[0].content) == 1 + assert isinstance(messages[0].content[0], TextContent) + assert messages[0].content[0].text == sample_content[0].text + assert messages[0].role == 'user' + assert messages[0].created_at is not None + + @pytest.mark.asyncio + async def test_add_message_returns_correct_queue_position( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that queue position increments correctly for each message.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + + # Act - Add three messages + response1 = await service.add_message(conversation_id, sample_content) + response2 = await service.add_message(conversation_id, sample_content) + response3 = await service.add_message(conversation_id, sample_content) + + # Assert + assert response1.position == 1 + assert response2.position == 2 + assert response3.position == 3 + + @pytest.mark.asyncio + async def test_get_pending_messages_returns_messages_ordered_by_created_at( + self, + service: SQLPendingMessageService, + ): + """Test that messages are returned in chronological order.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + contents = [ + [TextContent(text='First message')], + [TextContent(text='Second message')], + [TextContent(text='Third message')], + ] + + for content in contents: + await service.add_message(conversation_id, content) + + # Act + messages = await service.get_pending_messages(conversation_id) + + # Assert + assert len(messages) == 3 + assert 
messages[0].content[0].text == 'First message' + assert messages[1].content[0].text == 'Second message' + assert messages[2].content[0].text == 'Third message' + + @pytest.mark.asyncio + async def test_get_pending_messages_returns_empty_list_when_none_exist( + self, + service: SQLPendingMessageService, + ): + """Test that an empty list is returned for a conversation with no messages.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + + # Act + messages = await service.get_pending_messages(conversation_id) + + # Assert + assert messages == [] + + @pytest.mark.asyncio + async def test_count_pending_messages_returns_correct_count( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that count_pending_messages returns the correct number.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + other_conversation_id = f'task-{uuid4().hex}' + + # Add 3 messages to first conversation + for _ in range(3): + await service.add_message(conversation_id, sample_content) + + # Add 2 messages to second conversation + for _ in range(2): + await service.add_message(other_conversation_id, sample_content) + + # Act + count1 = await service.count_pending_messages(conversation_id) + count2 = await service.count_pending_messages(other_conversation_id) + count_empty = await service.count_pending_messages('nonexistent') + + # Assert + assert count1 == 3 + assert count2 == 2 + assert count_empty == 0 + + @pytest.mark.asyncio + async def test_delete_messages_for_conversation_removes_all_messages( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that delete_messages_for_conversation removes all messages and returns count.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + other_conversation_id = f'task-{uuid4().hex}' + + # Add messages to both conversations + for _ in range(3): + await service.add_message(conversation_id, sample_content) + await 
service.add_message(other_conversation_id, sample_content) + + # Act + deleted_count = await service.delete_messages_for_conversation(conversation_id) + + # Assert + assert deleted_count == 3 + assert await service.count_pending_messages(conversation_id) == 0 + # Other conversation should be unaffected + assert await service.count_pending_messages(other_conversation_id) == 1 + + @pytest.mark.asyncio + async def test_delete_messages_for_conversation_returns_zero_when_none_exist( + self, + service: SQLPendingMessageService, + ): + """Test that deleting from nonexistent conversation returns 0.""" + # Arrange + conversation_id = f'task-{uuid4().hex}' + + # Act + deleted_count = await service.delete_messages_for_conversation(conversation_id) + + # Assert + assert deleted_count == 0 + + @pytest.mark.asyncio + async def test_update_conversation_id_updates_all_matching_messages( + self, + service: SQLPendingMessageService, + sample_content: list[TextContent], + ): + """Test that update_conversation_id updates all messages with the old ID.""" + # Arrange + old_conversation_id = f'task-{uuid4().hex}' + new_conversation_id = str(uuid4()) + unrelated_conversation_id = f'task-{uuid4().hex}' + + # Add messages to old conversation + for _ in range(3): + await service.add_message(old_conversation_id, sample_content) + + # Add message to unrelated conversation + await service.add_message(unrelated_conversation_id, sample_content) + + # Act + updated_count = await service.update_conversation_id( + old_conversation_id, new_conversation_id + ) + + # Assert + assert updated_count == 3 + + # Verify old conversation has no messages + assert await service.count_pending_messages(old_conversation_id) == 0 + + # Verify new conversation has all messages + messages = await service.get_pending_messages(new_conversation_id) + assert len(messages) == 3 + for msg in messages: + assert msg.conversation_id == new_conversation_id + + # Verify unrelated conversation is unchanged + assert await 
service.count_pending_messages(unrelated_conversation_id) == 1 + + @pytest.mark.asyncio + async def test_update_conversation_id_returns_zero_when_no_match( + self, + service: SQLPendingMessageService, + ): + """Test that updating nonexistent conversation_id returns 0.""" + # Arrange + old_conversation_id = f'task-{uuid4().hex}' + new_conversation_id = str(uuid4()) + + # Act + updated_count = await service.update_conversation_id( + old_conversation_id, new_conversation_id + ) + + # Assert + assert updated_count == 0 + + @pytest.mark.asyncio + async def test_messages_are_isolated_between_conversations( + self, + service: SQLPendingMessageService, + ): + """Test that operations on one conversation don't affect others.""" + # Arrange + conv1 = f'task-{uuid4().hex}' + conv2 = f'task-{uuid4().hex}' + + await service.add_message(conv1, [TextContent(text='Conv1 msg')]) + await service.add_message(conv2, [TextContent(text='Conv2 msg')]) + + # Act + messages1 = await service.get_pending_messages(conv1) + messages2 = await service.get_pending_messages(conv2) + + # Assert + assert len(messages1) == 1 + assert len(messages2) == 1 + assert messages1[0].content[0].text == 'Conv1 msg' + assert messages2[0].content[0].text == 'Conv2 msg' diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index 7fa64ab12a..99dbdfaacc 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -2187,6 +2187,7 @@ async def test_delete_v1_conversation_with_sub_conversations(): sandbox_service=mock_sandbox_service, sandbox_spec_service=MagicMock(), jwt_service=MagicMock(), + pending_message_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, max_num_conversations_per_sandbox=20, @@ -2311,6 +2312,7 @@ async def test_delete_v1_conversation_with_no_sub_conversations(): sandbox_service=mock_sandbox_service, sandbox_spec_service=MagicMock(), 
jwt_service=MagicMock(), + pending_message_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, max_num_conversations_per_sandbox=20, @@ -2465,6 +2467,7 @@ async def test_delete_v1_conversation_sub_conversation_deletion_error(): sandbox_service=mock_sandbox_service, sandbox_spec_service=MagicMock(), jwt_service=MagicMock(), + pending_message_service=MagicMock(), sandbox_startup_timeout=120, sandbox_startup_poll_frequency=2, max_num_conversations_per_sandbox=20, From a0e777503ee846f3d7c601c9076ce9f80ca4582d Mon Sep 17 00:00:00 2001 From: HeyItsChloe <54480367+HeyItsChloe@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:22:23 -0700 Subject: [PATCH 4/5] fix(frontend): prevent auto sandbox resume behavior (#13133) Co-authored-by: openhands --- .../hooks/use-sandbox-recovery.test.tsx | 577 ++++++++++++++++++ .../hooks/use-visibility-recovery.test.ts | 286 +++++++++ .../conversation-websocket-context.tsx | 22 +- .../contexts/websocket-provider-wrapper.tsx | 35 +- frontend/src/hooks/use-sandbox-recovery.ts | 138 +++++ frontend/src/hooks/use-visibility-change.ts | 64 ++ frontend/src/hooks/use-websocket-recovery.ts | 110 ---- frontend/src/routes/conversation.tsx | 63 +- 8 files changed, 1091 insertions(+), 204 deletions(-) create mode 100644 frontend/__tests__/hooks/use-sandbox-recovery.test.tsx create mode 100644 frontend/__tests__/hooks/use-visibility-recovery.test.ts create mode 100644 frontend/src/hooks/use-sandbox-recovery.ts create mode 100644 frontend/src/hooks/use-visibility-change.ts delete mode 100644 frontend/src/hooks/use-websocket-recovery.ts diff --git a/frontend/__tests__/hooks/use-sandbox-recovery.test.tsx b/frontend/__tests__/hooks/use-sandbox-recovery.test.tsx new file mode 100644 index 0000000000..638fe21788 --- /dev/null +++ b/frontend/__tests__/hooks/use-sandbox-recovery.test.tsx @@ -0,0 +1,577 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { renderHook, act, waitFor } from 
"@testing-library/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import React from "react"; +import { useSandboxRecovery } from "#/hooks/use-sandbox-recovery"; +import { useUnifiedResumeConversationSandbox } from "#/hooks/mutation/use-unified-start-conversation"; +import * as customToastHandlers from "#/utils/custom-toast-handlers"; + +vi.mock("react-i18next", () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})); + +vi.mock("#/hooks/use-user-providers", () => ({ + useUserProviders: () => ({ + providers: [{ provider: "github", token: "test-token" }], + }), +})); + +vi.mock("#/utils/custom-toast-handlers"); +vi.mock("#/hooks/mutation/use-unified-start-conversation"); + +describe("useSandboxRecovery", () => { + let mockMutate: ReturnType; + + const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, + }); + + return ({ children }: { children: React.ReactNode }) => ( + {children} + ); + }; + + beforeEach(() => { + vi.clearAllMocks(); + + mockMutate = vi.fn(); + + vi.mocked(useUnifiedResumeConversationSandbox).mockReturnValue({ + mutate: mockMutate, + mutateAsync: vi.fn(), + isPending: false, + isSuccess: false, + isError: false, + isIdle: true, + data: undefined, + error: null, + reset: vi.fn(), + status: "idle", + variables: undefined, + failureCount: 0, + failureReason: null, + submittedAt: 0, + context: undefined, + } as unknown as ReturnType); + + // Reset document.visibilityState + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("initial load recovery", () => { + it("should call resumeSandbox on initial load when conversation is STOPPED", () => { + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + }), + { wrapper: createWrapper() }, + ); 
+ + expect(mockMutate).toHaveBeenCalledTimes(1); + expect(mockMutate).toHaveBeenCalledWith( + { + conversationId: "conv-123", + providers: [{ provider: "github", token: "test-token" }], + }, + expect.objectContaining({ + onSuccess: expect.any(Function), + onError: expect.any(Function), + }), + ); + }); + + it("should NOT call resumeSandbox on initial load when conversation is RUNNING", () => { + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "RUNNING", + }), + { wrapper: createWrapper() }, + ); + + expect(mockMutate).not.toHaveBeenCalled(); + }); + + it("should NOT call resumeSandbox when conversationId is undefined", () => { + renderHook( + () => + useSandboxRecovery({ + conversationId: undefined, + conversationStatus: "STOPPED", + }), + { wrapper: createWrapper() }, + ); + + expect(mockMutate).not.toHaveBeenCalled(); + }); + + it("should NOT call resumeSandbox when conversationStatus is undefined", () => { + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: undefined, + }), + { wrapper: createWrapper() }, + ); + + expect(mockMutate).not.toHaveBeenCalled(); + }); + + it("should only call resumeSandbox once per conversation on initial load", () => { + const { rerender } = renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + }), + { wrapper: createWrapper() }, + ); + + expect(mockMutate).toHaveBeenCalledTimes(1); + + // Rerender with same props - should not trigger again + rerender(); + + expect(mockMutate).toHaveBeenCalledTimes(1); + }); + + it("should call resumeSandbox for a new conversation after navigating", async () => { + const { rerender } = renderHook( + ({ conversationId }) => + useSandboxRecovery({ + conversationId, + conversationStatus: "STOPPED", + }), + { + wrapper: createWrapper(), + initialProps: { conversationId: "conv-123" }, + }, + ); + + expect(mockMutate).toHaveBeenCalledTimes(1); + 
expect(mockMutate).toHaveBeenLastCalledWith( + expect.objectContaining({ conversationId: "conv-123" }), + expect.any(Object), + ); + + // Navigate to a different conversation + rerender({ conversationId: "conv-456" }); + + await waitFor(() => { + expect(mockMutate).toHaveBeenCalledTimes(2); + }); + + expect(mockMutate).toHaveBeenLastCalledWith( + expect.objectContaining({ conversationId: "conv-456" }), + expect.any(Object), + ); + }); + }); + + describe("tab focus recovery", () => { + it("should call resumeSandbox when tab becomes visible and refetch returns STOPPED", async () => { + // Start with tab hidden + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, + }); + + const mockRefetch = vi.fn().mockResolvedValue({ + data: { status: "STOPPED" }, + }); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "RUNNING", // Cached status is RUNNING + refetchConversation: mockRefetch, + }), + { wrapper: createWrapper() }, + ); + + // No initial recovery for RUNNING + expect(mockMutate).not.toHaveBeenCalled(); + + // Simulate tab becoming visible + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + + await act(async () => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + // Refetch should be called to get fresh status + expect(mockRefetch).toHaveBeenCalledTimes(1); + // Recovery should trigger because fresh status is STOPPED + expect(mockMutate).toHaveBeenCalledTimes(1); + }); + + it("should NOT call resumeSandbox when tab becomes visible and refetch returns RUNNING", async () => { + const mockRefetch = vi.fn().mockResolvedValue({ + data: { status: "RUNNING" }, + }); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "RUNNING", + refetchConversation: mockRefetch, + }), + { wrapper: createWrapper() }, + ); + + // No initial recovery for RUNNING + 
expect(mockMutate).not.toHaveBeenCalled(); + + // Simulate tab becoming visible + await act(async () => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + // Refetch was called but status is still RUNNING + expect(mockRefetch).toHaveBeenCalledTimes(1); + expect(mockMutate).not.toHaveBeenCalled(); + }); + + it("should NOT call resumeSandbox when tab becomes visible but refetchConversation is not provided", async () => { + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + // No refetchConversation provided + }), + { wrapper: createWrapper() }, + ); + + // Initial load triggers recovery + expect(mockMutate).toHaveBeenCalledTimes(1); + mockMutate.mockClear(); + + // Simulate tab becoming visible + await act(async () => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + // No recovery on tab focus without refetchConversation + expect(mockMutate).not.toHaveBeenCalled(); + }); + + it("should NOT call resumeSandbox when tab becomes hidden", async () => { + const mockRefetch = vi.fn().mockResolvedValue({ + data: { status: "STOPPED" }, + }); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + refetchConversation: mockRefetch, + }), + { wrapper: createWrapper() }, + ); + + // Initial load triggers recovery + expect(mockMutate).toHaveBeenCalledTimes(1); + mockMutate.mockClear(); + + // Simulate tab becoming hidden + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, + }); + + await act(async () => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + // Refetch should NOT be called when tab is hidden + expect(mockRefetch).not.toHaveBeenCalled(); + expect(mockMutate).not.toHaveBeenCalled(); + }); + + it("should clean up visibility event listener on unmount", () => { + const addEventListenerSpy = vi.spyOn(document, "addEventListener"); + const removeEventListenerSpy = 
vi.spyOn(document, "removeEventListener"); + + const { unmount } = renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + }), + { wrapper: createWrapper() }, + ); + + expect(addEventListenerSpy).toHaveBeenCalledWith( + "visibilitychange", + expect.any(Function), + ); + + unmount(); + + expect(removeEventListenerSpy).toHaveBeenCalledWith( + "visibilitychange", + expect.any(Function), + ); + }); + + it("should NOT call resumeSandbox when tab becomes visible while isPending is true", async () => { + vi.mocked(useUnifiedResumeConversationSandbox).mockReturnValue({ + mutate: mockMutate, + mutateAsync: vi.fn(), + isPending: true, + isSuccess: false, + isError: false, + isIdle: false, + data: undefined, + error: null, + reset: vi.fn(), + status: "pending", + variables: undefined, + failureCount: 0, + failureReason: null, + submittedAt: 0, + context: undefined, + } as unknown as ReturnType); + + const mockRefetch = vi.fn().mockResolvedValue({ + data: { status: "STOPPED" }, + }); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "RUNNING", + refetchConversation: mockRefetch, + }), + { wrapper: createWrapper() }, + ); + + // Simulate tab becoming visible + await act(async () => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + // Refetch will be called when isPending is true + expect(mockRefetch).toHaveBeenCalledTimes(1); + // resumeSandbox should NOT be called + expect(mockMutate).not.toHaveBeenCalled(); + }); + + it("should handle refetch errors gracefully without crashing", async () => { + const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + + const mockRefetch = vi.fn().mockRejectedValue(new Error("Network error")); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "RUNNING", + refetchConversation: mockRefetch, + }), + { wrapper: createWrapper() }, + ); + + // Simulate tab 
becoming visible + await act(async () => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + // Refetch was called + expect(mockRefetch).toHaveBeenCalledTimes(1); + // Error was logged + expect(consoleErrorSpy).toHaveBeenCalledWith( + "Failed to refetch conversation on visibility change:", + expect.any(Error), + ); + // No recovery attempt was made (due to error) + expect(mockMutate).not.toHaveBeenCalled(); + + consoleErrorSpy.mockRestore(); + }); + }); + + describe("recovery callbacks", () => { + it("should return isResuming=false when no recovery is in progress", () => { + const { result } = renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "RUNNING", + }), + { wrapper: createWrapper() }, + ); + + expect(result.current.isResuming).toBe(false); + }); + + it("should return isResuming=true when mutation is pending", () => { + vi.mocked(useUnifiedResumeConversationSandbox).mockReturnValue({ + mutate: mockMutate, + mutateAsync: vi.fn(), + isPending: true, + isSuccess: false, + isError: false, + isIdle: false, + data: undefined, + error: null, + reset: vi.fn(), + status: "pending", + variables: undefined, + failureCount: 0, + failureReason: null, + submittedAt: 0, + context: undefined, + } as unknown as ReturnType); + + const { result } = renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + }), + { wrapper: createWrapper() }, + ); + + expect(result.current.isResuming).toBe(true); + }); + + it("should call onSuccess callback when recovery succeeds", () => { + const onSuccess = vi.fn(); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + onSuccess, + }), + { wrapper: createWrapper() }, + ); + + // Get the onSuccess callback passed to mutate + const mutateCall = mockMutate.mock.calls[0]; + const options = mutateCall[1]; + + // Simulate successful mutation + act(() => { + options.onSuccess(); + }); 
+ + expect(onSuccess).toHaveBeenCalledTimes(1); + }); + + it("should call onError callback and display toast when recovery fails", () => { + const onError = vi.fn(); + const testError = new Error("Resume failed"); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + onError, + }), + { wrapper: createWrapper() }, + ); + + // Get the onError callback passed to mutate + const mutateCall = mockMutate.mock.calls[0]; + const options = mutateCall[1]; + + // Simulate failed mutation + act(() => { + options.onError(testError); + }); + + expect(onError).toHaveBeenCalledTimes(1); + expect(onError).toHaveBeenCalledWith(testError); + expect(vi.mocked(customToastHandlers.displayErrorToast)).toHaveBeenCalled(); + }); + + it("should NOT call resumeSandbox when isPending is true", () => { + vi.mocked(useUnifiedResumeConversationSandbox).mockReturnValue({ + mutate: mockMutate, + mutateAsync: vi.fn(), + isPending: true, + isSuccess: false, + isError: false, + isIdle: false, + data: undefined, + error: null, + reset: vi.fn(), + status: "pending", + variables: undefined, + failureCount: 0, + failureReason: null, + submittedAt: 0, + context: undefined, + } as unknown as ReturnType); + + renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "STOPPED", + }), + { wrapper: createWrapper() }, + ); + + // Should not call mutate because isPending is true + expect(mockMutate).not.toHaveBeenCalled(); + }); + }); + + describe("WebSocket disconnect (negative test)", () => { + it("should NOT have any mechanism to auto-resume on WebSocket disconnect", () => { + // This test documents the intended behavior: the hook does NOT + // listen for WebSocket disconnects. Recovery only happens on: + // 1. Initial page load (STOPPED status) + // 2. Tab focus (visibilitychange event) + // + // There is intentionally NO onDisconnect handler or WebSocket listener. 
+ + const { result } = renderHook( + () => + useSandboxRecovery({ + conversationId: "conv-123", + conversationStatus: "RUNNING", + }), + { wrapper: createWrapper() }, + ); + + // The hook should only expose isResuming - no disconnect-related functionality + expect(result.current).toEqual({ + isResuming: expect.any(Boolean), + }); + + // No calls should have been made for RUNNING status + expect(mockMutate).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/__tests__/hooks/use-visibility-recovery.test.ts b/frontend/__tests__/hooks/use-visibility-recovery.test.ts new file mode 100644 index 0000000000..301d910fa2 --- /dev/null +++ b/frontend/__tests__/hooks/use-visibility-recovery.test.ts @@ -0,0 +1,286 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { renderHook, act } from "@testing-library/react"; +import { useVisibilityChange } from "#/hooks/use-visibility-change"; + +describe("useVisibilityChange", () => { + beforeEach(() => { + // Reset document.visibilityState to visible + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + configurable: true, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("initial state", () => { + it("should return isVisible=true when document is visible", () => { + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + + const { result } = renderHook(() => useVisibilityChange()); + + expect(result.current.isVisible).toBe(true); + }); + + it("should return isVisible=false when document is hidden", () => { + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, + }); + + const { result } = renderHook(() => useVisibilityChange()); + + expect(result.current.isVisible).toBe(false); + }); + }); + + describe("visibility change events", () => { + it("should update isVisible when visibility changes to hidden", () => { + const { result } = renderHook(() 
=> useVisibilityChange()); + + expect(result.current.isVisible).toBe(true); + + // Simulate tab becoming hidden + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(result.current.isVisible).toBe(false); + }); + + it("should update isVisible when visibility changes to visible", () => { + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, + }); + + const { result } = renderHook(() => useVisibilityChange()); + + expect(result.current.isVisible).toBe(false); + + // Simulate tab becoming visible + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(result.current.isVisible).toBe(true); + }); + }); + + describe("callbacks", () => { + it("should call onVisibilityChange with the new state", () => { + const onVisibilityChange = vi.fn(); + + renderHook(() => useVisibilityChange({ onVisibilityChange })); + + // Simulate tab becoming hidden + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisibilityChange).toHaveBeenCalledWith("hidden"); + + // Simulate tab becoming visible + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisibilityChange).toHaveBeenCalledWith("visible"); + }); + + it("should call onVisible only when tab becomes visible", () => { + const onVisible = vi.fn(); + const onHidden = vi.fn(); + + renderHook(() => useVisibilityChange({ onVisible, onHidden })); + + // Simulate tab becoming hidden + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, 
+ }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisible).not.toHaveBeenCalled(); + expect(onHidden).toHaveBeenCalledTimes(1); + + // Simulate tab becoming visible + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisible).toHaveBeenCalledTimes(1); + expect(onHidden).toHaveBeenCalledTimes(1); + }); + + it("should call onHidden only when tab becomes hidden", () => { + const onHidden = vi.fn(); + + renderHook(() => useVisibilityChange({ onHidden })); + + // Simulate tab becoming hidden + Object.defineProperty(document, "visibilityState", { + value: "hidden", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onHidden).toHaveBeenCalledTimes(1); + + // Simulate tab becoming visible (should not call onHidden) + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onHidden).toHaveBeenCalledTimes(1); + }); + }); + + describe("enabled option", () => { + it("should not listen for events when enabled=false", () => { + const onVisible = vi.fn(); + + renderHook(() => useVisibilityChange({ onVisible, enabled: false })); + + // Simulate tab becoming visible + Object.defineProperty(document, "visibilityState", { + value: "visible", + writable: true, + }); + + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisible).not.toHaveBeenCalled(); + }); + + it("should start listening when enabled changes from false to true", () => { + const onVisible = vi.fn(); + + const { rerender } = renderHook( + ({ enabled }) => useVisibilityChange({ onVisible, enabled }), + { initialProps: { enabled: false } }, + ); + + // Simulate event while disabled + act(() => { + 
document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisible).not.toHaveBeenCalled(); + + // Enable the hook + rerender({ enabled: true }); + + // Now events should be captured + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisible).toHaveBeenCalledTimes(1); + }); + }); + + describe("cleanup", () => { + it("should remove event listener on unmount", () => { + const addEventListenerSpy = vi.spyOn(document, "addEventListener"); + const removeEventListenerSpy = vi.spyOn(document, "removeEventListener"); + + const { unmount } = renderHook(() => useVisibilityChange()); + + expect(addEventListenerSpy).toHaveBeenCalledWith( + "visibilitychange", + expect.any(Function), + ); + + unmount(); + + expect(removeEventListenerSpy).toHaveBeenCalledWith( + "visibilitychange", + expect.any(Function), + ); + }); + + it("should remove event listener when enabled changes to false", () => { + const removeEventListenerSpy = vi.spyOn(document, "removeEventListener"); + + const { rerender } = renderHook( + ({ enabled }) => useVisibilityChange({ enabled }), + { initialProps: { enabled: true } }, + ); + + rerender({ enabled: false }); + + expect(removeEventListenerSpy).toHaveBeenCalledWith( + "visibilitychange", + expect.any(Function), + ); + }); + }); + + describe("callback stability", () => { + it("should handle callback updates without missing events", () => { + const onVisible1 = vi.fn(); + const onVisible2 = vi.fn(); + + const { rerender } = renderHook( + ({ onVisible }) => useVisibilityChange({ onVisible }), + { initialProps: { onVisible: onVisible1 } }, + ); + + // Update callback + rerender({ onVisible: onVisible2 }); + + // Simulate visibility change + act(() => { + document.dispatchEvent(new Event("visibilitychange")); + }); + + expect(onVisible1).not.toHaveBeenCalled(); + expect(onVisible2).toHaveBeenCalledTimes(1); + }); + }); +}); diff --git a/frontend/src/contexts/conversation-websocket-context.tsx 
b/frontend/src/contexts/conversation-websocket-context.tsx index 86863734b9..33c7169a77 100644 --- a/frontend/src/contexts/conversation-websocket-context.tsx +++ b/frontend/src/contexts/conversation-websocket-context.tsx @@ -78,7 +78,6 @@ export function ConversationWebSocketProvider({ sessionApiKey, subConversations, subConversationIds, - onDisconnect, }: { children: React.ReactNode; conversationId?: string; @@ -86,7 +85,6 @@ export function ConversationWebSocketProvider({ sessionApiKey?: string | null; subConversations?: V1AppConversation[]; subConversationIds?: string[]; - onDisconnect?: () => void; }) { // Separate connection state tracking for each WebSocket const [mainConnectionState, setMainConnectionState] = @@ -714,13 +712,10 @@ export function ConversationWebSocketProvider({ } } }, - onClose: (event: CloseEvent) => { + onClose: () => { setMainConnectionState("CLOSED"); - // Trigger silent recovery on unexpected disconnect - // Do NOT show error message - recovery happens automatically - if (event.code !== 1000 && hasConnectedRefMain.current) { - onDisconnect?.(); - } + // Recovery is handled by useSandboxRecovery on tab focus/page refresh + // No error message needed - silent recovery provides better UX }, onError: () => { setMainConnectionState("CLOSED"); @@ -738,7 +733,6 @@ export function ConversationWebSocketProvider({ sessionApiKey, conversationId, conversationUrl, - onDisconnect, ]); // Separate WebSocket options for planning agent connection @@ -785,13 +779,10 @@ export function ConversationWebSocketProvider({ } } }, - onClose: (event: CloseEvent) => { + onClose: () => { setPlanningConnectionState("CLOSED"); - // Trigger silent recovery on unexpected disconnect - // Do NOT show error message - recovery happens automatically - if (event.code !== 1000 && hasConnectedRefPlanning.current) { - onDisconnect?.(); - } + // Recovery is handled by useSandboxRecovery on tab focus/page refresh + // No error message needed - silent recovery provides better UX 
}, onError: () => { setPlanningConnectionState("CLOSED"); @@ -808,7 +799,6 @@ export function ConversationWebSocketProvider({ removeErrorMessage, sessionApiKey, subConversations, - onDisconnect, ]); // Only attempt WebSocket connection when we have a valid URL diff --git a/frontend/src/contexts/websocket-provider-wrapper.tsx b/frontend/src/contexts/websocket-provider-wrapper.tsx index 3aa21e4113..e484bd0571 100644 --- a/frontend/src/contexts/websocket-provider-wrapper.tsx +++ b/frontend/src/contexts/websocket-provider-wrapper.tsx @@ -3,7 +3,8 @@ import { WsClientProvider } from "#/context/ws-client-provider"; import { ConversationWebSocketProvider } from "#/contexts/conversation-websocket-context"; import { useActiveConversation } from "#/hooks/query/use-active-conversation"; import { useSubConversations } from "#/hooks/query/use-sub-conversations"; -import { useWebSocketRecovery } from "#/hooks/use-websocket-recovery"; +import { useSandboxRecovery } from "#/hooks/use-sandbox-recovery"; +import { isTaskConversationId } from "#/utils/conversation-local-storage"; interface WebSocketProviderWrapperProps { children: React.ReactNode; @@ -18,18 +19,6 @@ interface WebSocketProviderWrapperProps { * @param version - 0 for old WsClientProvider, 1 for new ConversationWebSocketProvider * @param conversationId - The conversation ID to pass to the provider * @param children - The child components to wrap - * - * @example - * // Use the old v0 provider - * - * - * - * - * @example - * // Use the new v1 provider - * - * - * */ export function WebSocketProviderWrapper({ children, @@ -37,7 +26,11 @@ export function WebSocketProviderWrapper({ version, }: WebSocketProviderWrapperProps) { // Get conversation data for V1 provider - const { data: conversation } = useActiveConversation(); + const { + data: conversation, + refetch: refetchConversation, + isFetched, + } = useActiveConversation(); // Get sub-conversation data for V1 provider const { data: subConversations } = 
useSubConversations( conversation?.sub_conversation_ids ?? [], @@ -48,9 +41,15 @@ export function WebSocketProviderWrapper({ (subConversation) => subConversation !== null, ); - // Silent recovery for V1 WebSocket disconnections - const { reconnectKey, handleDisconnect } = - useWebSocketRecovery(conversationId); + const isConversationReady = + !isTaskConversationId(conversationId) && isFetched && !!conversation; + // Recovery for V1 conversations - handles page refresh and tab focus + // Does NOT resume on WebSocket disconnect (server pauses after 20 min inactivity) + useSandboxRecovery({ + conversationId, + conversationStatus: conversation?.status, + refetchConversation: isConversationReady ? refetchConversation : undefined, + }); if (version === 0) { return ( @@ -63,13 +62,11 @@ export function WebSocketProviderWrapper({ if (version === 1) { return ( {children} diff --git a/frontend/src/hooks/use-sandbox-recovery.ts b/frontend/src/hooks/use-sandbox-recovery.ts new file mode 100644 index 0000000000..78804f6706 --- /dev/null +++ b/frontend/src/hooks/use-sandbox-recovery.ts @@ -0,0 +1,138 @@ +import React from "react"; +import { useTranslation } from "react-i18next"; +import { useUnifiedResumeConversationSandbox } from "./mutation/use-unified-start-conversation"; +import { useUserProviders } from "./use-user-providers"; +import { useVisibilityChange } from "./use-visibility-change"; +import { displayErrorToast } from "#/utils/custom-toast-handlers"; +import { I18nKey } from "#/i18n/declaration"; +import type { ConversationStatus } from "#/types/conversation-status"; +import type { Conversation } from "#/api/open-hands.types"; + +interface UseSandboxRecoveryOptions { + conversationId: string | undefined; + conversationStatus: ConversationStatus | undefined; + /** Function to refetch the conversation data - used to get fresh status on tab focus */ + refetchConversation?: () => Promise<{ + data: Conversation | null | undefined; + }>; + onSuccess?: () => void; + 
onError?: (error: Error) => void; +} + +/** + * Hook that handles sandbox recovery based on user intent. + * + * Recovery triggers: + * - Page refresh: Resumes the sandbox on initial load if it was paused/stopped + * - Tab gains focus: Resumes the sandbox if it was paused/stopped + * + * What does NOT trigger recovery: + * - WebSocket disconnect: Does NOT automatically resume the sandbox + * (The server pauses sandboxes after 20 minutes of inactivity, + * and sandboxes should only be resumed when the user explicitly shows intent) + * + * @param options.conversationId - The conversation ID to recover + * @param options.conversationStatus - The current conversation status + * @param options.refetchConversation - Function to refetch conversation data on tab focus + * @param options.onSuccess - Callback when recovery succeeds + * @param options.onError - Callback when recovery fails + * @returns isResuming - Whether a recovery is in progress + */ +export function useSandboxRecovery({ + conversationId, + conversationStatus, + refetchConversation, + onSuccess, + onError, +}: UseSandboxRecoveryOptions) { + const { t } = useTranslation(); + const { providers } = useUserProviders(); + const { mutate: resumeSandbox, isPending: isResuming } = + useUnifiedResumeConversationSandbox(); + + // Track which conversation ID we've already processed for initial load recovery + const processedConversationIdRef = React.useRef(null); + + const attemptRecovery = React.useCallback( + (statusOverride?: ConversationStatus) => { + const status = statusOverride ?? conversationStatus; + /** + * Only recover if sandbox is paused (status === STOPPED) and not already resuming + * + * Note: ConversationStatus uses different terminology than SandboxStatus: + * - SandboxStatus.PAUSED → ConversationStatus.STOPPED : the runtime is not running but may be restarted + * - SandboxStatus.MISSING → ConversationStatus.ARCHIVED : the runtime is not running and will not restart due to deleted files. 
+ */ + if (!conversationId || status !== "STOPPED" || isResuming) { + return; + } + + resumeSandbox( + { conversationId, providers }, + { + onSuccess: () => { + onSuccess?.(); + }, + onError: (error) => { + displayErrorToast( + t(I18nKey.CONVERSATION$FAILED_TO_START_WITH_ERROR, { + error: error.message, + }), + ); + onError?.(error); + }, + }, + ); + }, + [ + conversationId, + conversationStatus, + isResuming, + providers, + resumeSandbox, + onSuccess, + onError, + t, + ], + ); + + // Handle page refresh (initial load) and conversation navigation + React.useEffect(() => { + if (!conversationId || !conversationStatus) return; + + // Only attempt recovery once per conversation (handles both initial load and navigation) + if (processedConversationIdRef.current === conversationId) return; + + processedConversationIdRef.current = conversationId; + + if (conversationStatus === "STOPPED") { + attemptRecovery(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [conversationId, conversationStatus]); + + const handleVisible = React.useCallback(async () => { + // Skip if no conversation or refetch function + if (!conversationId || !refetchConversation) return; + + try { + // Refetch to get fresh status - cached status may be stale if sandbox was paused while tab was inactive + const { data } = await refetchConversation(); + attemptRecovery(data?.status); + } catch (error) { + // eslint-disable-next-line no-console + console.error( + "Failed to refetch conversation on visibility change:", + error, + ); + } + }, [conversationId, refetchConversation, isResuming, attemptRecovery]); + + // Handle tab focus (visibility change) - refetch conversation status and resume if needed + useVisibilityChange({ + enabled: !!conversationId, + onVisible: handleVisible, + }); + + return { isResuming }; +} diff --git a/frontend/src/hooks/use-visibility-change.ts b/frontend/src/hooks/use-visibility-change.ts new file mode 100644 index 0000000000..1ee929cf6c --- /dev/null +++ 
b/frontend/src/hooks/use-visibility-change.ts @@ -0,0 +1,64 @@ +import React from "react"; + +type VisibilityState = "visible" | "hidden"; + +interface UseVisibilityChangeOptions { + /** Callback fired when visibility changes to the specified state */ + onVisibilityChange?: (state: VisibilityState) => void; + /** Callback fired only when tab becomes visible */ + onVisible?: () => void; + /** Callback fired only when tab becomes hidden */ + onHidden?: () => void; + /** Whether to listen for visibility changes (default: true) */ + enabled?: boolean; +} + +/** + * Hook that listens for browser tab visibility changes. + * + * Useful for: + * - Resuming operations when user returns to the tab + * - Pausing expensive operations when tab is hidden + * - Tracking user engagement + * + * @param options.onVisibilityChange - Callback with the new visibility state + * @param options.onVisible - Callback fired only when tab becomes visible + * @param options.onHidden - Callback fired only when tab becomes hidden + * @param options.enabled - Whether to listen for changes (default: true) + * @returns isVisible - Current visibility state of the tab + */ +export function useVisibilityChange({ + onVisibilityChange, + onVisible, + onHidden, + enabled = true, +}: UseVisibilityChangeOptions = {}) { + const [isVisible, setIsVisible] = React.useState( + () => document.visibilityState === "visible", + ); + + React.useEffect(() => { + if (!enabled) return undefined; + + const handleVisibilityChange = () => { + const state = document.visibilityState as VisibilityState; + setIsVisible(state === "visible"); + + onVisibilityChange?.(state); + + if (state === "visible") { + onVisible?.(); + } else { + onHidden?.(); + } + }; + + document.addEventListener("visibilitychange", handleVisibilityChange); + + return () => { + document.removeEventListener("visibilitychange", handleVisibilityChange); + }; + }, [enabled, onVisibilityChange, onVisible, onHidden]); + + return { isVisible }; +} diff --git 
a/frontend/src/hooks/use-websocket-recovery.ts b/frontend/src/hooks/use-websocket-recovery.ts deleted file mode 100644 index d15358d12e..0000000000 --- a/frontend/src/hooks/use-websocket-recovery.ts +++ /dev/null @@ -1,110 +0,0 @@ -import React from "react"; -import { useQueryClient } from "@tanstack/react-query"; -import { useUnifiedResumeConversationSandbox } from "#/hooks/mutation/use-unified-start-conversation"; -import { useUserProviders } from "#/hooks/use-user-providers"; -import { useErrorMessageStore } from "#/stores/error-message-store"; -import { I18nKey } from "#/i18n/declaration"; - -const MAX_RECOVERY_ATTEMPTS = 3; -const RECOVERY_COOLDOWN_MS = 5000; -const RECOVERY_SETTLED_DELAY_MS = 2000; - -/** - * Hook that handles silent WebSocket recovery by resuming the sandbox - * when a WebSocket disconnection is detected. - * - * @param conversationId - The conversation ID to recover - * @returns reconnectKey - Key to force provider remount (resets connection state) - * @returns handleDisconnect - Callback to trigger recovery on WebSocket disconnect - */ -export function useWebSocketRecovery(conversationId: string) { - // Recovery state (refs to avoid re-renders) - const recoveryAttemptsRef = React.useRef(0); - const recoveryInProgressRef = React.useRef(false); - const lastRecoveryAttemptRef = React.useRef(null); - - // Key to force remount of provider after recovery (resets connection state to "CONNECTING") - const [reconnectKey, setReconnectKey] = React.useState(0); - - const queryClient = useQueryClient(); - const { mutate: resumeConversation } = useUnifiedResumeConversationSandbox(); - const { providers } = useUserProviders(); - const setErrorMessage = useErrorMessageStore( - (state) => state.setErrorMessage, - ); - - // Reset recovery state when conversation changes - React.useEffect(() => { - recoveryAttemptsRef.current = 0; - recoveryInProgressRef.current = false; - lastRecoveryAttemptRef.current = null; - }, [conversationId]); - - // Silent recovery 
callback - resumes sandbox when WebSocket disconnects - const handleDisconnect = React.useCallback(() => { - // Prevent concurrent recovery attempts - if (recoveryInProgressRef.current) return; - - // Check cooldown - const now = Date.now(); - if ( - lastRecoveryAttemptRef.current && - now - lastRecoveryAttemptRef.current < RECOVERY_COOLDOWN_MS - ) { - return; - } - - // Check max attempts - notify user when recovery is exhausted - if (recoveryAttemptsRef.current >= MAX_RECOVERY_ATTEMPTS) { - setErrorMessage(I18nKey.STATUS$CONNECTION_LOST); - return; - } - - // Start silent recovery - recoveryInProgressRef.current = true; - lastRecoveryAttemptRef.current = now; - recoveryAttemptsRef.current += 1; - - resumeConversation( - { conversationId, providers }, - { - onSuccess: async () => { - // Invalidate and wait for refetch to complete before remounting - // This ensures the provider remounts with fresh data (url: null during startup) - await queryClient.invalidateQueries({ - queryKey: ["user", "conversation", conversationId], - }); - - // Force remount to reset connection state to "CONNECTING" - setReconnectKey((k) => k + 1); - - // Reset recovery state on success - recoveryAttemptsRef.current = 0; - recoveryInProgressRef.current = false; - lastRecoveryAttemptRef.current = null; - }, - onError: () => { - // If this was the last attempt, show error to user - if (recoveryAttemptsRef.current >= MAX_RECOVERY_ATTEMPTS) { - setErrorMessage(I18nKey.STATUS$CONNECTION_LOST); - } - // recoveryInProgressRef will be reset by onSettled - }, - onSettled: () => { - // Allow next attempt after a delay (covers both success and error) - setTimeout(() => { - recoveryInProgressRef.current = false; - }, RECOVERY_SETTLED_DELAY_MS); - }, - }, - ); - }, [ - conversationId, - providers, - resumeConversation, - queryClient, - setErrorMessage, - ]); - - return { reconnectKey, handleDisconnect }; -} diff --git a/frontend/src/routes/conversation.tsx b/frontend/src/routes/conversation.tsx index 
3063c2d89c..c12ce948c7 100644 --- a/frontend/src/routes/conversation.tsx +++ b/frontend/src/routes/conversation.tsx @@ -18,7 +18,6 @@ import { useTaskPolling } from "#/hooks/query/use-task-polling"; import { displayErrorToast } from "#/utils/custom-toast-handlers"; import { useIsAuthed } from "#/hooks/query/use-is-authed"; import { ConversationSubscriptionsProvider } from "#/context/conversation-subscriptions-provider"; -import { useUserProviders } from "#/hooks/use-user-providers"; import { ConversationMain } from "#/components/features/conversation/conversation-main/conversation-main"; import { ConversationNameWithStatus } from "#/components/features/conversation/conversation-name-with-status"; @@ -26,7 +25,6 @@ import { ConversationNameWithStatus } from "#/components/features/conversation/c import { ConversationTabs } from "#/components/features/conversation/conversation-tabs/conversation-tabs"; import { WebSocketProviderWrapper } from "#/contexts/websocket-provider-wrapper"; import { useErrorMessageStore } from "#/stores/error-message-store"; -import { useUnifiedResumeConversationSandbox } from "#/hooks/mutation/use-unified-start-conversation"; import { I18nKey } from "#/i18n/declaration"; import { useEventStore } from "#/stores/use-event-store"; @@ -39,11 +37,8 @@ function AppContent() { // Handle both task IDs (task-{uuid}) and regular conversation IDs const { isTask, taskStatus, taskDetail } = useTaskPolling(); - const { data: conversation, isFetched, refetch } = useActiveConversation(); - const { mutate: startConversation, isPending: isStarting } = - useUnifiedResumeConversationSandbox(); + const { data: conversation, isFetched } = useActiveConversation(); const { data: isAuthed } = useIsAuthed(); - const { providers } = useUserProviders(); const { resetConversationState } = useConversationStore(); const navigate = useNavigate(); const clearTerminal = useCommandStore((state) => state.clearTerminal); @@ -54,9 +49,6 @@ function AppContent() { (state) => 
state.removeErrorMessage, ); - // Track which conversation ID we've auto-started to prevent auto-restart after manual stop - const processedConversationId = React.useRef(null); - // Fetch batch feedback data when conversation is loaded useBatchFeedback(); @@ -67,12 +59,6 @@ function AppContent() { setCurrentAgentState(AgentState.LOADING); removeErrorMessage(); clearEvents(); - - // Reset tracking ONLY if we're navigating to a DIFFERENT conversation - // Don't reset on StrictMode remounts (conversationId is the same) - if (processedConversationId.current !== conversationId) { - processedConversationId.current = null; - } }, [ conversationId, clearTerminal, @@ -91,7 +77,8 @@ function AppContent() { } }, [isTask, taskStatus, taskDetail, t]); - // 3. Auto-start Effect - handles conversation not found and auto-starting STOPPED conversations + // 3. Handle conversation not found + // NOTE: Resuming STOPPED conversations is handled by useSandboxRecovery in WebSocketProviderWrapper React.useEffect(() => { // Wait for data to be fetched if (!isFetched || !isAuthed) return; @@ -100,50 +87,8 @@ function AppContent() { if (!conversation) { displayErrorToast(t(I18nKey.CONVERSATION$NOT_EXIST_OR_NO_PERMISSION)); navigate("/"); - return; } - - const currentConversationId = conversation.conversation_id; - const currentStatus = conversation.status; - - // Skip if we've already processed this conversation - if (processedConversationId.current === currentConversationId) { - return; - } - - // Mark as processed immediately to prevent duplicate calls - processedConversationId.current = currentConversationId; - - // Auto-start STOPPED conversations on initial load only - if (currentStatus === "STOPPED" && !isStarting) { - startConversation( - { conversationId: currentConversationId, providers }, - { - onError: (error) => { - displayErrorToast( - t(I18nKey.CONVERSATION$FAILED_TO_START_WITH_ERROR, { - error: error.message, - }), - ); - refetch(); - }, - }, - ); - } - // NOTE: 
conversation?.status is intentionally NOT in dependencies - // We only want to run when conversation ID changes, not when status changes - // This prevents duplicate calls when stale cache data is replaced with fresh data - }, [ - conversation?.conversation_id, - isFetched, - isAuthed, - isStarting, - providers, - startConversation, - navigate, - refetch, - t, - ]); + }, [conversation, isFetched, isAuthed, navigate, t]); const isV0Conversation = conversation?.conversation_version === "V0"; From 00daaa41d327435f61a7e93e70debfa8d75f8650 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Tue, 17 Mar 2026 00:55:23 +0800 Subject: [PATCH 5/5] feat: Load workspace hooks for V1 conversations and add hooks viewer UI (#12773) Co-authored-by: openhands Co-authored-by: enyst Co-authored-by: Alona King --- .../context-menu/tools-context-menu.test.tsx | 1 + .../conversation-panel/hooks-modal.test.tsx | 207 +++++++++ .../v1-conversation-service.api.ts | 13 + .../v1-conversation-service.types.ts | 21 + .../hook-execution-event-message.tsx | 1 + .../chat/event-message-components/index.ts | 1 + .../features/controls/tools-context-menu.tsx | 19 + .../components/features/controls/tools.tsx | 12 + .../conversation-panel/hook-event-item.tsx | 75 ++++ .../hook-matcher-content.tsx | 61 +++ .../conversation-panel/hooks-empty-state.tsx | 21 + .../hooks-loading-state.tsx | 7 + .../conversation-panel/hooks-modal-header.tsx | 45 ++ .../conversation-panel/hooks-modal.tsx | 102 +++++ .../conversation-name-context-menu.tsx | 18 +- .../conversation/conversation-name.tsx | 11 + .../shared/hook-execution-event-message.tsx | 152 +++++++ .../should-render-event.ts | 6 + .../hook-execution-event-message.tsx | 1 + .../v1/chat/event-message-components/index.ts | 1 + .../src/components/v1/chat/event-message.tsx | 7 + .../src/hooks/query/use-conversation-hooks.ts | 36 ++ .../use-conversation-name-context-menu.ts | 16 + frontend/src/i18n/declaration.ts | 24 ++ frontend/src/i18n/translation.json | 392 
+++++++++++++++++- frontend/src/types/core/base.ts | 2 +- frontend/src/types/v1/core/base/common.ts | 2 +- .../v1/core/events/hook-execution-event.ts | 100 +++++ frontend/src/types/v1/core/events/index.ts | 1 + frontend/src/types/v1/core/openhands-event.ts | 3 + frontend/src/types/v1/type-guards.ts | 12 +- .../app_conversation_models.py | 29 ++ .../app_conversation_router.py | 343 +++++++++++---- .../app_conversation/hook_loader.py | 148 +++++++ .../live_status_app_conversation_service.py | 74 ++++ .../test_app_conversation_hooks_endpoint.py | 293 +++++++++++++ ...st_live_status_app_conversation_service.py | 279 +++++++++++++ 37 files changed, 2452 insertions(+), 84 deletions(-) create mode 100644 frontend/__tests__/components/features/conversation-panel/hooks-modal.test.tsx create mode 100644 frontend/src/components/features/chat/event-message-components/hook-execution-event-message.tsx create mode 100644 frontend/src/components/features/conversation-panel/hook-event-item.tsx create mode 100644 frontend/src/components/features/conversation-panel/hook-matcher-content.tsx create mode 100644 frontend/src/components/features/conversation-panel/hooks-empty-state.tsx create mode 100644 frontend/src/components/features/conversation-panel/hooks-loading-state.tsx create mode 100644 frontend/src/components/features/conversation-panel/hooks-modal-header.tsx create mode 100644 frontend/src/components/features/conversation-panel/hooks-modal.tsx create mode 100644 frontend/src/components/shared/hook-execution-event-message.tsx create mode 100644 frontend/src/components/v1/chat/event-message-components/hook-execution-event-message.tsx create mode 100644 frontend/src/hooks/query/use-conversation-hooks.ts create mode 100644 frontend/src/types/v1/core/events/hook-execution-event.ts create mode 100644 openhands/app_server/app_conversation/hook_loader.py create mode 100644 tests/unit/app_server/test_app_conversation_hooks_endpoint.py diff --git 
a/frontend/__tests__/components/context-menu/tools-context-menu.test.tsx b/frontend/__tests__/components/context-menu/tools-context-menu.test.tsx index 3e4f4b90b1..7377febdf5 100644 --- a/frontend/__tests__/components/context-menu/tools-context-menu.test.tsx +++ b/frontend/__tests__/components/context-menu/tools-context-menu.test.tsx @@ -44,6 +44,7 @@ describe("SystemMessage UI Rendering", () => { {}} onShowSkills={() => {}} + onShowHooks={() => {}} onShowAgentTools={() => {}} />, ); diff --git a/frontend/__tests__/components/features/conversation-panel/hooks-modal.test.tsx b/frontend/__tests__/components/features/conversation-panel/hooks-modal.test.tsx new file mode 100644 index 0000000000..7cb788068d --- /dev/null +++ b/frontend/__tests__/components/features/conversation-panel/hooks-modal.test.tsx @@ -0,0 +1,207 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { render, screen, within } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import React from "react"; +import { HookEventItem } from "#/components/features/conversation-panel/hook-event-item"; +import { HooksEmptyState } from "#/components/features/conversation-panel/hooks-empty-state"; +import { HooksLoadingState } from "#/components/features/conversation-panel/hooks-loading-state"; +import { HooksModalHeader } from "#/components/features/conversation-panel/hooks-modal-header"; +import { HookEvent } from "#/api/conversation-service/v1-conversation-service.types"; + +// Mock react-i18next +vi.mock("react-i18next", async () => { + const actual = await vi.importActual("react-i18next"); + return { + ...actual, + useTranslation: () => ({ + t: (key: string, params?: Record) => { + const translations: Record = { + HOOKS_MODAL$TITLE: "Available Hooks", + HOOKS_MODAL$HOOK_COUNT: `${params?.count ?? 
0} hooks`, + HOOKS_MODAL$EVENT_PRE_TOOL_USE: "Pre Tool Use", + HOOKS_MODAL$EVENT_POST_TOOL_USE: "Post Tool Use", + HOOKS_MODAL$EVENT_USER_PROMPT_SUBMIT: "User Prompt Submit", + HOOKS_MODAL$EVENT_SESSION_START: "Session Start", + HOOKS_MODAL$EVENT_SESSION_END: "Session End", + HOOKS_MODAL$EVENT_STOP: "Stop", + HOOKS_MODAL$MATCHER: "Matcher", + HOOKS_MODAL$COMMANDS: "Commands", + HOOKS_MODAL$TYPE: `Type: ${params?.type ?? ""}`, + HOOKS_MODAL$TIMEOUT: `Timeout: ${params?.timeout ?? 0}s`, + HOOKS_MODAL$ASYNC: "Async", + COMMON$FETCH_ERROR: "Failed to fetch data", + CONVERSATION$NO_HOOKS: "No hooks configured", + BUTTON$REFRESH: "Refresh", + }; + return translations[key] || key; + }, + i18n: { + changeLanguage: () => new Promise(() => {}), + }, + }), + }; +}); + +describe("HooksLoadingState", () => { + it("should render loading spinner", () => { + render(); + const spinner = document.querySelector(".animate-spin"); + expect(spinner).toBeInTheDocument(); + }); +}); + +describe("HooksEmptyState", () => { + it("should render no hooks message when not error", () => { + render(); + expect(screen.getByText("No hooks configured")).toBeInTheDocument(); + }); + + it("should render error message when isError is true", () => { + render(); + expect(screen.getByText("Failed to fetch data")).toBeInTheDocument(); + }); +}); + +describe("HooksModalHeader", () => { + const defaultProps = { + isAgentReady: true, + isLoading: false, + isRefetching: false, + onRefresh: vi.fn(), + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should render title", () => { + render(); + expect(screen.getByText("Available Hooks")).toBeInTheDocument(); + }); + + it("should render refresh button when agent is ready", () => { + render(); + expect(screen.getByTestId("refresh-hooks")).toBeInTheDocument(); + }); + + it("should not render refresh button when agent is not ready", () => { + render(); + expect(screen.queryByTestId("refresh-hooks")).not.toBeInTheDocument(); + }); + + it("should call 
onRefresh when refresh button is clicked", async () => { + const user = userEvent.setup(); + const onRefresh = vi.fn(); + render(); + + await user.click(screen.getByTestId("refresh-hooks")); + expect(onRefresh).toHaveBeenCalledTimes(1); + }); + + it("should disable refresh button when loading", () => { + render(); + expect(screen.getByTestId("refresh-hooks")).toBeDisabled(); + }); + + it("should disable refresh button when refetching", () => { + render(); + expect(screen.getByTestId("refresh-hooks")).toBeDisabled(); + }); +}); + +describe("HookEventItem", () => { + const mockHookEvent: HookEvent = { + event_type: "stop", + matchers: [ + { + matcher: "*", + hooks: [ + { + type: "command", + command: ".openhands/hooks/on_stop.sh", + timeout: 30, + async: true, + }, + ], + }, + ], + }; + + const defaultProps = { + hookEvent: mockHookEvent, + isExpanded: false, + onToggle: vi.fn(), + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should render event type label using i18n", () => { + render(); + expect(screen.getByText("Stop")).toBeInTheDocument(); + }); + + it("should render hook count", () => { + render(); + expect(screen.getByText("1 hooks")).toBeInTheDocument(); + }); + + it("should call onToggle when clicked", async () => { + const user = userEvent.setup(); + const onToggle = vi.fn(); + render(); + + await user.click(screen.getByRole("button")); + expect(onToggle).toHaveBeenCalledWith("stop"); + }); + + it("should show collapsed state by default", () => { + render(); + // Matcher content should not be visible when collapsed + expect(screen.queryByText("*")).not.toBeInTheDocument(); + }); + + it("should show expanded state with matcher content", () => { + render(); + // Matcher content should be visible when expanded + expect(screen.getByText("*")).toBeInTheDocument(); + }); + + it("should render async badge for async hooks", () => { + render(); + expect(screen.getByText("Async")).toBeInTheDocument(); + }); + + it("should render different event types 
with correct i18n labels", () => { + const eventTypes = [ + { type: "pre_tool_use", label: "Pre Tool Use" }, + { type: "post_tool_use", label: "Post Tool Use" }, + { type: "user_prompt_submit", label: "User Prompt Submit" }, + { type: "session_start", label: "Session Start" }, + { type: "session_end", label: "Session End" }, + { type: "stop", label: "Stop" }, + ]; + + eventTypes.forEach(({ type, label }) => { + const { unmount } = render( + , + ); + expect(screen.getByText(label)).toBeInTheDocument(); + unmount(); + }); + }); + + it("should fallback to event_type when no i18n key exists", () => { + render( + , + ); + expect(screen.getByText("unknown_event")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/api/conversation-service/v1-conversation-service.api.ts b/frontend/src/api/conversation-service/v1-conversation-service.api.ts index 30fdeb9369..a0e99abe0f 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.api.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.api.ts @@ -14,6 +14,7 @@ import type { V1AppConversation, V1AppConversationPage, GetSkillsResponse, + GetHooksResponse, V1RuntimeConversationInfo, } from "./v1-conversation-service.types"; @@ -400,6 +401,18 @@ class V1ConversationService { return data; } + /** + * Get all hooks associated with a V1 conversation + * @param conversationId The conversation ID + * @returns The available hooks associated with the conversation + */ + static async getHooks(conversationId: string): Promise { + const { data } = await openHands.get( + `/api/v1/app-conversations/${conversationId}/hooks`, + ); + return data; + } + /** * Get conversation info directly from the runtime for a V1 conversation * Uses the custom runtime URL from the conversation diff --git a/frontend/src/api/conversation-service/v1-conversation-service.types.ts b/frontend/src/api/conversation-service/v1-conversation-service.types.ts index b437e17bf1..50e904d553 100644 --- 
a/frontend/src/api/conversation-service/v1-conversation-service.types.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.types.ts @@ -135,6 +135,27 @@ export interface GetSkillsResponse { skills: Skill[]; } +export interface HookDefinition { + type: string; // 'command' or 'prompt' + command: string; + timeout: number; + async?: boolean; +} + +export interface HookMatcher { + matcher: string; // Pattern: '*', exact match, or regex + hooks: HookDefinition[]; +} + +export interface HookEvent { + event_type: string; // e.g., 'stop', 'pre_tool_use', 'post_tool_use' + matchers: HookMatcher[]; +} + +export interface GetHooksResponse { + hooks: HookEvent[]; +} + // Runtime conversation types (from agent server) export interface V1RuntimeConversationStats { usage_to_metrics: Record; diff --git a/frontend/src/components/features/chat/event-message-components/hook-execution-event-message.tsx b/frontend/src/components/features/chat/event-message-components/hook-execution-event-message.tsx new file mode 100644 index 0000000000..4bec021ab7 --- /dev/null +++ b/frontend/src/components/features/chat/event-message-components/hook-execution-event-message.tsx @@ -0,0 +1 @@ +export { HookExecutionEventMessage } from "#/components/shared/hook-execution-event-message"; diff --git a/frontend/src/components/features/chat/event-message-components/index.ts b/frontend/src/components/features/chat/event-message-components/index.ts index c2db5a5f5a..d439405e5d 100644 --- a/frontend/src/components/features/chat/event-message-components/index.ts +++ b/frontend/src/components/features/chat/event-message-components/index.ts @@ -8,3 +8,4 @@ export { ObservationPairEventMessage } from "./observation-pair-event-message"; export { GenericEventMessageWrapper } from "./generic-event-message-wrapper"; export { MicroagentStatusWrapper } from "./microagent-status-wrapper"; export { LikertScaleWrapper } from "./likert-scale-wrapper"; +export { HookExecutionEventMessage } from 
"./hook-execution-event-message"; diff --git a/frontend/src/components/features/controls/tools-context-menu.tsx b/frontend/src/components/features/controls/tools-context-menu.tsx index 2089f95111..31d61105fe 100644 --- a/frontend/src/components/features/controls/tools-context-menu.tsx +++ b/frontend/src/components/features/controls/tools-context-menu.tsx @@ -27,15 +27,19 @@ const contextMenuListItemClassName = cn( interface ToolsContextMenuProps { onClose: () => void; onShowSkills: (event: React.MouseEvent) => void; + onShowHooks: (event: React.MouseEvent) => void; onShowAgentTools: (event: React.MouseEvent) => void; shouldShowAgentTools?: boolean; + shouldShowHooks?: boolean; } export function ToolsContextMenu({ onClose, onShowSkills, + onShowHooks, onShowAgentTools, shouldShowAgentTools = true, + shouldShowHooks = false, }: ToolsContextMenuProps) { const { t } = useTranslation(); const { data: conversation } = useActiveConversation(); @@ -141,6 +145,21 @@ export function ToolsContextMenu({ /> + {/* Show Hooks - Only show for V1 conversations */} + {shouldShowHooks && ( + + } + text={t(I18nKey.CONVERSATION$SHOW_HOOKS)} + className={CONTEXT_MENU_ICON_TEXT_CLASSNAME} + /> + + )} + {/* Show Agent Tools and Metadata - Only show if system message is available */} {shouldShowAgentTools && ( setContextMenuOpen(false)} onShowSkills={handleShowSkills} + onShowHooks={handleShowHooks} onShowAgentTools={handleShowAgentTools} shouldShowAgentTools={shouldShowAgentTools} + shouldShowHooks={shouldShowHooks} /> )} @@ -68,6 +75,11 @@ export function Tools() { {skillsModalVisible && ( setSkillsModalVisible(false)} /> )} + + {/* Hooks Modal */} + {hooksModalVisible && ( + setHooksModalVisible(false)} /> + )}
); } diff --git a/frontend/src/components/features/conversation-panel/hook-event-item.tsx b/frontend/src/components/features/conversation-panel/hook-event-item.tsx new file mode 100644 index 0000000000..add99f891b --- /dev/null +++ b/frontend/src/components/features/conversation-panel/hook-event-item.tsx @@ -0,0 +1,75 @@ +import { useTranslation } from "react-i18next"; +import { ChevronDown, ChevronRight } from "lucide-react"; +import { Typography } from "#/ui/typography"; +import { HookEvent } from "#/api/conversation-service/v1-conversation-service.types"; +import { HookMatcherContent } from "./hook-matcher-content"; +import { I18nKey } from "#/i18n/declaration"; + +interface HookEventItemProps { + hookEvent: HookEvent; + isExpanded: boolean; + onToggle: (eventType: string) => void; +} + +const EVENT_TYPE_I18N_KEYS: Record = { + pre_tool_use: I18nKey.HOOKS_MODAL$EVENT_PRE_TOOL_USE, + post_tool_use: I18nKey.HOOKS_MODAL$EVENT_POST_TOOL_USE, + user_prompt_submit: I18nKey.HOOKS_MODAL$EVENT_USER_PROMPT_SUBMIT, + session_start: I18nKey.HOOKS_MODAL$EVENT_SESSION_START, + session_end: I18nKey.HOOKS_MODAL$EVENT_SESSION_END, + stop: I18nKey.HOOKS_MODAL$EVENT_STOP, +}; + +export function HookEventItem({ + hookEvent, + isExpanded, + onToggle, +}: HookEventItemProps) { + const { t } = useTranslation(); + const i18nKey = EVENT_TYPE_I18N_KEYS[hookEvent.event_type]; + const eventTypeLabel = i18nKey ? t(i18nKey) : hookEvent.event_type; + + const totalHooks = hookEvent.matchers.reduce( + (sum, matcher) => sum + matcher.hooks.length, + 0, + ); + + return ( +
+ + + {isExpanded && ( +
+ {hookEvent.matchers.map((matcher, index) => ( + + ))} +
+ )} +
+ ); +} diff --git a/frontend/src/components/features/conversation-panel/hook-matcher-content.tsx b/frontend/src/components/features/conversation-panel/hook-matcher-content.tsx new file mode 100644 index 0000000000..587653502f --- /dev/null +++ b/frontend/src/components/features/conversation-panel/hook-matcher-content.tsx @@ -0,0 +1,61 @@ +import { useTranslation } from "react-i18next"; +import { I18nKey } from "#/i18n/declaration"; +import { Typography } from "#/ui/typography"; +import { Pre } from "#/ui/pre"; +import { HookMatcher } from "#/api/conversation-service/v1-conversation-service.types"; + +interface HookMatcherContentProps { + matcher: HookMatcher; +} + +export function HookMatcherContent({ matcher }: HookMatcherContentProps) { + const { t } = useTranslation(); + + return ( +
+
+ + {t(I18nKey.HOOKS_MODAL$MATCHER)} + + + {matcher.matcher} + +
+ +
+ + {t(I18nKey.HOOKS_MODAL$COMMANDS)} + + {matcher.hooks.map((hook, index) => ( +
+
+              {hook.command}
+            
+
+ {t(I18nKey.HOOKS_MODAL$TYPE, { type: hook.type })} + + {t(I18nKey.HOOKS_MODAL$TIMEOUT, { timeout: hook.timeout })} + + {hook.async ? ( + + {t(I18nKey.HOOKS_MODAL$ASYNC)} + + ) : null} +
+
+ ))} +
+
+ ); +} diff --git a/frontend/src/components/features/conversation-panel/hooks-empty-state.tsx b/frontend/src/components/features/conversation-panel/hooks-empty-state.tsx new file mode 100644 index 0000000000..626561b52f --- /dev/null +++ b/frontend/src/components/features/conversation-panel/hooks-empty-state.tsx @@ -0,0 +1,21 @@ +import { useTranslation } from "react-i18next"; +import { I18nKey } from "#/i18n/declaration"; +import { Typography } from "#/ui/typography"; + +interface HooksEmptyStateProps { + isError: boolean; +} + +export function HooksEmptyState({ isError }: HooksEmptyStateProps) { + const { t } = useTranslation(); + + return ( +
+ + {isError + ? t(I18nKey.COMMON$FETCH_ERROR) + : t(I18nKey.CONVERSATION$NO_HOOKS)} + +
+ ); +} diff --git a/frontend/src/components/features/conversation-panel/hooks-loading-state.tsx b/frontend/src/components/features/conversation-panel/hooks-loading-state.tsx new file mode 100644 index 0000000000..2a915a3677 --- /dev/null +++ b/frontend/src/components/features/conversation-panel/hooks-loading-state.tsx @@ -0,0 +1,7 @@ +export function HooksLoadingState() { + return ( +
+
+
+ ); +} diff --git a/frontend/src/components/features/conversation-panel/hooks-modal-header.tsx b/frontend/src/components/features/conversation-panel/hooks-modal-header.tsx new file mode 100644 index 0000000000..ab65fa7386 --- /dev/null +++ b/frontend/src/components/features/conversation-panel/hooks-modal-header.tsx @@ -0,0 +1,45 @@ +import { useTranslation } from "react-i18next"; +import { RefreshCw } from "lucide-react"; +import { BaseModalTitle } from "#/components/shared/modals/confirmation-modals/base-modal"; +import { I18nKey } from "#/i18n/declaration"; +import { BrandButton } from "../settings/brand-button"; + +interface HooksModalHeaderProps { + isAgentReady: boolean; + isLoading: boolean; + isRefetching: boolean; + onRefresh: () => void; +} + +export function HooksModalHeader({ + isAgentReady, + isLoading, + isRefetching, + onRefresh, +}: HooksModalHeaderProps) { + const { t } = useTranslation(); + + return ( +
+
+ + {isAgentReady && ( + + + {t(I18nKey.BUTTON$REFRESH)} + + )} +
+
+ ); +} diff --git a/frontend/src/components/features/conversation-panel/hooks-modal.tsx b/frontend/src/components/features/conversation-panel/hooks-modal.tsx new file mode 100644 index 0000000000..6f2677ffb1 --- /dev/null +++ b/frontend/src/components/features/conversation-panel/hooks-modal.tsx @@ -0,0 +1,102 @@ +import { useState } from "react"; +import { useTranslation } from "react-i18next"; +import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop"; +import { ModalBody } from "#/components/shared/modals/modal-body"; +import { I18nKey } from "#/i18n/declaration"; +import { useConversationHooks } from "#/hooks/query/use-conversation-hooks"; +import { AgentState } from "#/types/agent-state"; +import { Typography } from "#/ui/typography"; +import { HooksModalHeader } from "./hooks-modal-header"; +import { HooksLoadingState } from "./hooks-loading-state"; +import { HooksEmptyState } from "./hooks-empty-state"; +import { HookEventItem } from "./hook-event-item"; +import { useAgentState } from "#/hooks/use-agent-state"; + +interface HooksModalProps { + onClose: () => void; +} + +export function HooksModal({ onClose }: HooksModalProps) { + const { t } = useTranslation(); + const { curAgentState } = useAgentState(); + const [expandedEvents, setExpandedEvents] = useState>( + {}, + ); + const { + data: hooks, + isLoading, + isError, + refetch, + isRefetching, + } = useConversationHooks(); + + const toggleEvent = (eventType: string) => { + setExpandedEvents((prev) => ({ + ...prev, + [eventType]: !prev[eventType], + })); + }; + + const isAgentReady = ![AgentState.LOADING, AgentState.INIT].includes( + curAgentState, + ); + + return ( + + + + + {isAgentReady && ( + + {t(I18nKey.HOOKS_MODAL$WARNING)} + + )} + +
+ {!isAgentReady && ( +
+ + {t(I18nKey.DIFF_VIEWER$WAITING_FOR_RUNTIME)} + +
+ )} + + {isLoading && } + + {!isLoading && + isAgentReady && + (isError || !hooks || hooks.length === 0) && ( + + )} + + {!isLoading && isAgentReady && hooks && hooks.length > 0 && ( +
+ {hooks.map((hookEvent) => { + const isExpanded = + expandedEvents[hookEvent.event_type] || false; + + return ( + + ); + })} +
+ )} +
+
+
+ ); +} diff --git a/frontend/src/components/features/conversation/conversation-name-context-menu.tsx b/frontend/src/components/features/conversation/conversation-name-context-menu.tsx index 1d1a7bb789..83d80eff39 100644 --- a/frontend/src/components/features/conversation/conversation-name-context-menu.tsx +++ b/frontend/src/components/features/conversation/conversation-name-context-menu.tsx @@ -35,6 +35,7 @@ interface ConversationNameContextMenuProps { onDisplayCost?: (event: React.MouseEvent) => void; onShowAgentTools?: (event: React.MouseEvent) => void; onShowSkills?: (event: React.MouseEvent) => void; + onShowHooks?: (event: React.MouseEvent) => void; onExportConversation?: (event: React.MouseEvent) => void; onDownloadViaVSCode?: (event: React.MouseEvent) => void; onTogglePublic?: (event: React.MouseEvent) => void; @@ -52,6 +53,7 @@ export function ConversationNameContextMenu({ onDisplayCost, onShowAgentTools, onShowSkills, + onShowHooks, onExportConversation, onDownloadViaVSCode, onTogglePublic, @@ -77,7 +79,7 @@ export function ConversationNameContextMenu({ const hasDownload = Boolean(onDownloadViaVSCode || onDownloadConversation); const hasExport = Boolean(onExportConversation); - const hasTools = Boolean(onShowAgentTools || onShowSkills); + const hasTools = Boolean(onShowAgentTools || onShowSkills || onShowHooks); const hasInfo = Boolean(onDisplayCost); const hasControl = Boolean(onStop || onDelete); @@ -119,6 +121,20 @@ export function ConversationNameContextMenu({ )} + {onShowHooks && ( + + } + text={t(I18nKey.CONVERSATION$SHOW_HOOKS)} + className={CONTEXT_MENU_ICON_TEXT_CLASSNAME} + /> + + )} + {onShowAgentTools && ( setSkillsModalVisible(false)} /> )} + {/* Hooks Modal */} + {hooksModalVisible && ( + setHooksModalVisible(false)} /> + )} + {/* Confirm Delete Modal */} {confirmDeleteModalVisible && ( 80) { + return `${command.slice(0, 77)}...`; + } + return command; +} + +function getStatusText(blocked: boolean, success: boolean): string { + if (blocked) 
return "blocked"; + if (success) return "ok"; + return "failed"; +} + +function getStatusClassName(blocked: boolean, success: boolean): string { + if (blocked) return "bg-amber-900/50 text-amber-300"; + if (success) return "bg-green-900/50 text-green-300"; + return "bg-red-900/50 text-red-300"; +} + +export function HookExecutionEventMessage({ + event, +}: HookExecutionEventMessageProps) { + const { t } = useTranslation(); + + if (!isHookExecutionEvent(event)) { + return null; + } + + const icon = getHookIcon(event.hook_event_type, event.blocked); + const statusText = getStatusText(event.blocked, event.success); + const statusClassName = getStatusClassName(event.blocked, event.success); + + // Determine the overall success indicator for GenericEventMessage. + // When blocked, suppress the success indicator entirely — the amber "blocked" + // badge in the title is the authoritative status signal. + const getSuccessStatus = (): "success" | "error" | undefined => { + if (event.blocked) return undefined; + return event.success ? "success" : "error"; + }; + const successStatus = getSuccessStatus(); + + const title = ( + + {icon} {t("HOOK$HOOK_LABEL")}: {event.hook_event_type} + {event.tool_name && ( + ({event.tool_name}) + )} + + {statusText} + + + ); + + const details = ( +
+
+ {t("HOOK$COMMAND")}:{" "} + + {formatHookCommand(event.hook_command)} + +
+ + {event.exit_code !== null && ( +
+ {t("HOOK$EXIT_CODE")}:{" "} + {event.exit_code} +
+ )} + + {event.blocked && event.reason && ( +
+ {t("HOOK$BLOCKED_REASON")}:{" "} + {event.reason} +
+ )} + + {event.additional_context && ( +
+ {t("HOOK$CONTEXT")}:{" "} + {event.additional_context} +
+ )} + + {event.error && ( +
+ {t("HOOK$ERROR")}:{" "} + {event.error} +
+ )} + + {event.stdout && ( +
+ {t("HOOK$OUTPUT")}: +
+            {event.stdout}
+          
+
+ )} + + {event.stderr && ( +
+ {t("HOOK$STDERR")}: +
+            {event.stderr}
+          
+
+ )} +
+ ); + + return ( + + ); +} diff --git a/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts b/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts index b20bedc4a9..91b206f840 100644 --- a/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts +++ b/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts @@ -5,6 +5,7 @@ import { isMessageEvent, isAgentErrorEvent, isConversationStateUpdateEvent, + isHookExecutionEvent, } from "#/types/v1/type-guards"; export const shouldRenderEvent = (event: OpenHandsEvent) => { @@ -50,6 +51,11 @@ export const shouldRenderEvent = (event: OpenHandsEvent) => { return true; } + // Render hook execution events + if (isHookExecutionEvent(event)) { + return true; + } + // Don't render any other event types (system events, etc.) return false; }; diff --git a/frontend/src/components/v1/chat/event-message-components/hook-execution-event-message.tsx b/frontend/src/components/v1/chat/event-message-components/hook-execution-event-message.tsx new file mode 100644 index 0000000000..4bec021ab7 --- /dev/null +++ b/frontend/src/components/v1/chat/event-message-components/hook-execution-event-message.tsx @@ -0,0 +1 @@ +export { HookExecutionEventMessage } from "#/components/shared/hook-execution-event-message"; diff --git a/frontend/src/components/v1/chat/event-message-components/index.ts b/frontend/src/components/v1/chat/event-message-components/index.ts index 3672255101..52df9ab096 100644 --- a/frontend/src/components/v1/chat/event-message-components/index.ts +++ b/frontend/src/components/v1/chat/event-message-components/index.ts @@ -4,3 +4,4 @@ export { ErrorEventMessage } from "./error-event-message"; export { FinishEventMessage } from "./finish-event-message"; export { GenericEventMessageWrapper } from "./generic-event-message-wrapper"; export { ThoughtEventMessage } from "./thought-event-message"; +export { HookExecutionEventMessage } from 
"./hook-execution-event-message"; diff --git a/frontend/src/components/v1/chat/event-message.tsx b/frontend/src/components/v1/chat/event-message.tsx index fa6299e57a..57b543bc8e 100644 --- a/frontend/src/components/v1/chat/event-message.tsx +++ b/frontend/src/components/v1/chat/event-message.tsx @@ -7,6 +7,7 @@ import { isAgentErrorEvent, isUserMessageEvent, isPlanningFileEditorObservationEvent, + isHookExecutionEvent, } from "#/types/v1/type-guards"; import { MicroagentStatus } from "#/types/microagent-status"; import { useConfig } from "#/hooks/query/use-config"; @@ -21,6 +22,7 @@ import { FinishEventMessage, GenericEventMessageWrapper, ThoughtEventMessage, + HookExecutionEventMessage, } from "./event-message-components"; import { createSkillReadyEvent } from "./event-content-helpers/create-skill-ready-event"; import { PlanPreview } from "../../features/chat/plan-preview"; @@ -188,6 +190,11 @@ export function EventMessage({ return ; } + // Hook execution events + if (isHookExecutionEvent(event)) { + return ; + } + // Finish actions if (isActionEvent(event) && event.action.kind === "FinishAction") { return ( diff --git a/frontend/src/hooks/query/use-conversation-hooks.ts b/frontend/src/hooks/query/use-conversation-hooks.ts new file mode 100644 index 0000000000..7f18fab0ba --- /dev/null +++ b/frontend/src/hooks/query/use-conversation-hooks.ts @@ -0,0 +1,36 @@ +import { useQuery } from "@tanstack/react-query"; +import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api"; +import { useConversationId } from "../use-conversation-id"; +import { AgentState } from "#/types/agent-state"; +import { useAgentState } from "#/hooks/use-agent-state"; +import { useSettings } from "./use-settings"; + +export const useConversationHooks = () => { + const { conversationId } = useConversationId(); + const { curAgentState } = useAgentState(); + const { data: settings } = useSettings(); + + return useQuery({ + queryKey: ["conversation", conversationId, 
"hooks", settings?.v1_enabled], + queryFn: async () => { + if (!conversationId) { + throw new Error("No conversation ID provided"); + } + + // Hooks are only available for V1 conversations + if (!settings?.v1_enabled) { + return []; + } + + const data = await V1ConversationService.getHooks(conversationId); + return data.hooks; + }, + enabled: + !!conversationId && + !!settings?.v1_enabled && + curAgentState !== AgentState.LOADING && + curAgentState !== AgentState.INIT, + staleTime: 1000 * 60 * 5, // 5 minutes + gcTime: 1000 * 60 * 15, // 15 minutes + }); +}; diff --git a/frontend/src/hooks/use-conversation-name-context-menu.ts b/frontend/src/hooks/use-conversation-name-context-menu.ts index d20348ef8f..ed49f81b59 100644 --- a/frontend/src/hooks/use-conversation-name-context-menu.ts +++ b/frontend/src/hooks/use-conversation-name-context-menu.ts @@ -53,6 +53,7 @@ export function useConversationNameContextMenu({ const [metricsModalVisible, setMetricsModalVisible] = React.useState(false); const [systemModalVisible, setSystemModalVisible] = React.useState(false); const [skillsModalVisible, setSkillsModalVisible] = React.useState(false); + const [hooksModalVisible, setHooksModalVisible] = React.useState(false); const [confirmDeleteModalVisible, setConfirmDeleteModalVisible] = React.useState(false); const [confirmStopModalVisible, setConfirmStopModalVisible] = @@ -187,6 +188,12 @@ export function useConversationNameContextMenu({ onContextMenuToggle?.(false); }; + const handleShowHooks = (event: React.MouseEvent) => { + event.stopPropagation(); + setHooksModalVisible(true); + onContextMenuToggle?.(false); + }; + const handleTogglePublic = (event: React.MouseEvent) => { event.preventDefault(); event.stopPropagation(); @@ -233,6 +240,7 @@ export function useConversationNameContextMenu({ handleDisplayCost, handleShowAgentTools, handleShowSkills, + handleShowHooks, handleTogglePublic, handleCopyShareLink, shareUrl, @@ -246,6 +254,8 @@ export function 
useConversationNameContextMenu({ setSystemModalVisible, skillsModalVisible, setSkillsModalVisible, + hooksModalVisible, + setHooksModalVisible, confirmDeleteModalVisible, setConfirmDeleteModalVisible, confirmStopModalVisible, @@ -267,5 +277,11 @@ export function useConversationNameContextMenu({ shouldShowDisplayCost: showOptions, shouldShowAgentTools: Boolean(showOptions && systemMessage), shouldShowSkills: Boolean(showOptions && conversationId), + shouldShowHooks: Boolean( + showOptions && + conversationId && + conversation?.conversation_version === "V1" && + conversationStatus === "RUNNING", + ), }; } diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index 648143fc2e..8aef562320 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -683,6 +683,8 @@ export enum I18nKey { TOS$ERROR_ACCEPTING = "TOS$ERROR_ACCEPTING", TIPS$CUSTOMIZE_MICROAGENT = "TIPS$CUSTOMIZE_MICROAGENT", CONVERSATION$NO_SKILLS = "CONVERSATION$NO_SKILLS", + CONVERSATION$NO_HOOKS = "CONVERSATION$NO_HOOKS", + CONVERSATION$SHOW_HOOKS = "CONVERSATION$SHOW_HOOKS", CONVERSATION$FAILED_TO_FETCH_MICROAGENTS = "CONVERSATION$FAILED_TO_FETCH_MICROAGENTS", MICROAGENTS_MODAL$TITLE = "MICROAGENTS_MODAL$TITLE", SKILLS_MODAL$WARNING = "SKILLS_MODAL$WARNING", @@ -1078,6 +1080,28 @@ export enum I18nKey { CONVERSATION$NO_HISTORY_AVAILABLE = "CONVERSATION$NO_HISTORY_AVAILABLE", CONVERSATION$SHARED_CONVERSATION = "CONVERSATION$SHARED_CONVERSATION", CONVERSATION$LINK_COPIED = "CONVERSATION$LINK_COPIED", + HOOKS_MODAL$TITLE = "HOOKS_MODAL$TITLE", + HOOKS_MODAL$WARNING = "HOOKS_MODAL$WARNING", + HOOKS_MODAL$MATCHER = "HOOKS_MODAL$MATCHER", + HOOKS_MODAL$COMMANDS = "HOOKS_MODAL$COMMANDS", + HOOKS_MODAL$HOOK_COUNT = "HOOKS_MODAL$HOOK_COUNT", + HOOKS_MODAL$TYPE = "HOOKS_MODAL$TYPE", + HOOKS_MODAL$TIMEOUT = "HOOKS_MODAL$TIMEOUT", + HOOKS_MODAL$ASYNC = "HOOKS_MODAL$ASYNC", + HOOKS_MODAL$EVENT_PRE_TOOL_USE = "HOOKS_MODAL$EVENT_PRE_TOOL_USE", + 
HOOKS_MODAL$EVENT_POST_TOOL_USE = "HOOKS_MODAL$EVENT_POST_TOOL_USE", + HOOKS_MODAL$EVENT_USER_PROMPT_SUBMIT = "HOOKS_MODAL$EVENT_USER_PROMPT_SUBMIT", + HOOKS_MODAL$EVENT_SESSION_START = "HOOKS_MODAL$EVENT_SESSION_START", + HOOKS_MODAL$EVENT_SESSION_END = "HOOKS_MODAL$EVENT_SESSION_END", + HOOKS_MODAL$EVENT_STOP = "HOOKS_MODAL$EVENT_STOP", + HOOK$HOOK_LABEL = "HOOK$HOOK_LABEL", + HOOK$COMMAND = "HOOK$COMMAND", + HOOK$EXIT_CODE = "HOOK$EXIT_CODE", + HOOK$BLOCKED_REASON = "HOOK$BLOCKED_REASON", + HOOK$CONTEXT = "HOOK$CONTEXT", + HOOK$ERROR = "HOOK$ERROR", + HOOK$OUTPUT = "HOOK$OUTPUT", + HOOK$STDERR = "HOOK$STDERR", COMMON$TYPE_EMAIL_AND_PRESS_SPACE = "COMMON$TYPE_EMAIL_AND_PRESS_SPACE", ORG$INVITE_ORG_MEMBERS = "ORG$INVITE_ORG_MEMBERS", ORG$MANAGE_ORGANIZATION = "ORG$MANAGE_ORGANIZATION", diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index 3437fde4a6..5556f891ca 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -7359,7 +7359,7 @@ "es": "Actualmente no hay un plan para este repositorio", "tr": "Şu anda bu depo için bir plan yok" }, - "SIDEBAR$NAVIGATION_LABEL": { + "SIDEBAR$NAVIGATION_LABEL": { "en": "Sidebar navigation", "zh-CN": "侧边栏导航", "zh-TW": "側邊欄導航", @@ -9327,7 +9327,6 @@ "de": "Abonnement kündigen", "uk": "Скасувати підписку" }, - "PAYMENT$SUBSCRIPTION_CANCELLED": { "en": "Subscription cancelled successfully", "ja": "サブスクリプションが正常にキャンセルされました", @@ -9344,7 +9343,6 @@ "de": "Abonnement erfolgreich gekündigt", "uk": "Підписку успішно скасовано" }, - "PAYMENT$NEXT_BILLING_DATE": { "en": "Next billing date: {{date}}", "ja": "次回請求日: {{date}}", @@ -10529,7 +10527,7 @@ "de": "klicken Sie hier für Anweisungen", "uk": "натисніть тут, щоб отримати інструкції" }, - "BITBUCKET_DATA_CENTER$TOKEN_LABEL": { + "BITBUCKET_DATA_CENTER$TOKEN_LABEL": { "en": "Bitbucket Data Center Token", "ja": "Bitbucket Data Centerトークン", "zh-CN": "Bitbucket Data Center令牌", @@ -10929,6 +10927,38 @@ "tr": "Bu sohbet 
için kullanılabilir yetenek bulunamadı.", "uk": "У цій розмові не знайдено доступних навичок." }, + "CONVERSATION$NO_HOOKS": { + "en": "No hooks configured for this conversation.", + "ja": "この会話にはフックが設定されていません。", + "zh-CN": "此会话未配置钩子。", + "zh-TW": "此對話未配置鉤子。", + "ko-KR": "이 대화에 구성된 훅이 없습니다.", + "no": "Ingen kroker konfigurert for denne samtalen.", + "ar": "لم يتم تكوين أي خطافات لهذه المحادثة.", + "de": "Keine Hooks für diese Unterhaltung konfiguriert.", + "fr": "Aucun hook configuré pour cette conversation.", + "it": "Nessun hook configurato per questa conversazione.", + "pt": "Nenhum hook configurado para esta conversa.", + "es": "No hay hooks configurados para esta conversación.", + "tr": "Bu sohbet için yapılandırılmış kanca yok.", + "uk": "Для цієї розмови не налаштовано хуків." + }, + "CONVERSATION$SHOW_HOOKS": { + "en": "Show Available Hooks", + "ja": "利用可能なフックを表示", + "zh-CN": "显示可用钩子", + "zh-TW": "顯示可用鉤子", + "ko-KR": "사용 가능한 훅 표시", + "no": "Vis tilgjengelige kroker", + "ar": "عرض الخطافات المتاحة", + "de": "Verfügbare Hooks anzeigen", + "fr": "Afficher les hooks disponibles", + "it": "Mostra hook disponibili", + "pt": "Mostrar hooks disponíveis", + "es": "Mostrar hooks disponibles", + "tr": "Kullanılabilir kancaları göster", + "uk": "Показати доступні хуки" + }, "CONVERSATION$FAILED_TO_FETCH_MICROAGENTS": { "en": "Failed to fetch available microagents", "ja": "利用可能なマイクロエージェントの取得に失敗しました", @@ -11777,7 +11807,6 @@ "tr": "Git sağlayıcısını bağla", "uk": "Підключити постачальник Git" }, - "TASKS$NO_GIT_PROVIDERS_DESCRIPTION": { "en": "Connect a Git provider to see suggested tasks from your repositories.", "ja": "Gitプロバイダーを接続して、リポジトリからの提案タスクを表示します。", @@ -11794,7 +11823,6 @@ "tr": "Depolarınızdan önerilen görevleri görmek için bir Git sağlayıcısı bağlayın.", "uk": "Підключіть постачальник Git, щоб бачити запропоновані завдання з ваших репозиторіїв." 
}, - "TASKS$NO_GIT_PROVIDERS_CTA": { "en": "Go to Integrations", "ja": "統合へ移動", @@ -17251,6 +17279,358 @@ "de": "Link in die Zwischenablage kopiert", "uk": "Посилання скопійовано в буфер обміну" }, + "HOOKS_MODAL$TITLE": { + "en": "Available Hooks", + "ja": "利用可能なフック", + "zh-CN": "可用钩子", + "zh-TW": "可用鉤子", + "ko-KR": "사용 가능한 훅", + "no": "Tilgjengelige kroker", + "ar": "الخطافات المتاحة", + "de": "Verfügbare Hooks", + "fr": "Hooks disponibles", + "it": "Hook disponibili", + "pt": "Hooks disponíveis", + "es": "Hooks disponibles", + "tr": "Kullanılabilir kancalar", + "uk": "Доступні хуки" + }, + "HOOKS_MODAL$WARNING": { + "en": "Hooks are loaded from your workspace. This view refreshes on demand and may differ from the hooks that were active when the conversation started. Stop and restart the conversation to apply changes.", + "ja": "フックはワークスペースから読み込まれます。この表示は要求時にワークスペースから再読み込みするため、会話開始時に有効だったフックと異なる場合があります。変更を適用するには会話を停止して再開してください。", + "zh-CN": "Hooks 从工作区读取。本视图会在请求时从工作区刷新,因此可能与会话启动时生效的 hooks 不一致。要应用更改,请停止并重新开始会话。", + "zh-TW": "Hooks 從工作區讀取。本視圖會在請求時從工作區重新整理,因此可能與會話啟動時生效的 hooks 不一致。要套用變更,請停止並重新開始會話。", + "ko-KR": "훅은 작업공간에서 로드됩니다. 이 화면은 요청 시 작업공간에서 다시 읽어 오므로 대화 시작 시 적용된 훅과 다를 수 있습니다. 변경을 적용하려면 대화를 중지한 뒤 다시 시작하세요.", + "no": "Hooks lastes fra arbeidsområdet. Denne visningen leser filen på nytt ved forespørsel og kan derfor avvike fra hookene som var aktive da samtalen startet. Stopp og start samtalen på nytt for å ta i bruk endringer.", + "ar": "يتم تحميل الخطافات من مساحة العمل. تقوم هذه الشاشة بإعادة قراءة الملف عند الطلب وقد تختلف عن الخطافات التي كانت فعّالة عند بدء المحادثة. لتطبيق التغييرات، أوقف المحادثة وأعد تشغيلها.", + "de": "Hooks werden aus dem Workspace geladen. Diese Ansicht liest die Datei bei Bedarf neu ein und kann daher von den Hooks abweichen, die beim Start der Unterhaltung aktiv waren. Stoppen und starten Sie die Unterhaltung neu, um Änderungen anzuwenden.", + "fr": "Les hooks sont chargés depuis votre espace de travail. 
Cette vue se rafraîchit à la demande depuis l’espace de travail et peut différer des hooks actifs au démarrage de la conversation. Arrêtez puis redémarrez la conversation pour appliquer les modifications.", + "it": "Gli hook vengono caricati dal tuo workspace. Questa vista si aggiorna su richiesta dal workspace e può differire dagli hook attivi all’avvio della conversazione. Interrompi e riavvia la conversazione per applicare le modifiche.", + "pt": "Os hooks são carregados do seu workspace. Esta visualização é atualizada sob demanda a partir do workspace e pode ser diferente dos hooks que estavam ativos quando a conversa foi iniciada. Pare e reinicie a conversa para aplicar as alterações.", + "es": "Los hooks se cargan desde tu espacio de trabajo. Esta vista se actualiza bajo demanda desde el workspace y puede diferir de los hooks que estaban activos cuando comenzó la conversación. Detén y reinicia la conversación para aplicar los cambios.", + "tr": "Kancalar çalışma alanınızdan yüklenir. Bu görünüm istek üzerine çalışma alanından yenilenir ve sohbet başlatıldığında etkin olan kancalardan farklı olabilir. Değişiklikleri uygulamak için sohbeti durdurup yeniden başlatın.", + "uk": "Хуки завантажуються з вашого робочого простору. Це подання оновлюється з робочого простору на вимогу й може відрізнятися від хуків, які були активні під час запуску розмови. Щоб застосувати зміни, зупиніть і перезапустіть розмову." 
+ }, + "HOOKS_MODAL$MATCHER": { + "en": "Matcher", + "ja": "マッチャー", + "zh-CN": "匹配器", + "zh-TW": "匹配器", + "ko-KR": "매처", + "no": "Matcher", + "ar": "المطابق", + "de": "Matcher", + "fr": "Matcher", + "it": "Matcher", + "pt": "Matcher", + "es": "Matcher", + "tr": "Eşleştirici", + "uk": "Матчер" + }, + "HOOKS_MODAL$COMMANDS": { + "en": "Commands", + "ja": "コマンド", + "zh-CN": "命令", + "zh-TW": "命令", + "ko-KR": "명령", + "no": "Kommandoer", + "ar": "الأوامر", + "de": "Befehle", + "fr": "Commandes", + "it": "Comandi", + "pt": "Comandos", + "es": "Comandos", + "tr": "Komutlar", + "uk": "Команди" + }, + "HOOKS_MODAL$HOOK_COUNT": { + "en": "{{count}} hook(s)", + "ja": "{{count}}個のフック", + "zh-CN": "{{count}}个钩子", + "zh-TW": "{{count}}個鉤子", + "ko-KR": "{{count}}개 훅", + "no": "{{count}} krok", + "ar": "{{count}} خطاف", + "de": "{{count}} Hook", + "fr": "{{count}} hook", + "it": "{{count}} hook", + "pt": "{{count}} hook", + "es": "{{count}} hook", + "tr": "{{count}} kanca", + "uk": "{{count}} хук" + }, + "HOOKS_MODAL$TYPE": { + "en": "Type: {{type}}", + "ja": "タイプ: {{type}}", + "zh-CN": "类型: {{type}}", + "zh-TW": "類型: {{type}}", + "ko-KR": "유형: {{type}}", + "no": "Type: {{type}}", + "ar": "النوع: {{type}}", + "de": "Typ: {{type}}", + "fr": "Type: {{type}}", + "it": "Tipo: {{type}}", + "pt": "Tipo: {{type}}", + "es": "Tipo: {{type}}", + "tr": "Tür: {{type}}", + "uk": "Тип: {{type}}" + }, + "HOOKS_MODAL$TIMEOUT": { + "en": "Timeout: {{timeout}}s", + "ja": "タイムアウト: {{timeout}}秒", + "zh-CN": "超时: {{timeout}}秒", + "zh-TW": "超時: {{timeout}}秒", + "ko-KR": "타임아웃: {{timeout}}초", + "no": "Tidsavbrudd: {{timeout}}s", + "ar": "المهلة: {{timeout}} ثانية", + "de": "Timeout: {{timeout}}s", + "fr": "Délai: {{timeout}}s", + "it": "Timeout: {{timeout}}s", + "pt": "Tempo limite: {{timeout}}s", + "es": "Tiempo de espera: {{timeout}}s", + "tr": "Zaman aşımı: {{timeout}}s", + "uk": "Таймаут: {{timeout}}с" + }, + "HOOKS_MODAL$ASYNC": { + "en": "Async", + "ja": "非同期", + "zh-CN": "异步", + "zh-TW": "非同步", + 
"ko-KR": "비동기", + "no": "Asynkron", + "ar": "غير متزامن", + "de": "Asynchron", + "fr": "Asynchrone", + "it": "Asincrono", + "pt": "Assíncrono", + "es": "Asíncrono", + "tr": "Asenkron", + "uk": "Асинхронний" + }, + "HOOKS_MODAL$EVENT_PRE_TOOL_USE": { + "en": "Pre Tool Use", + "ja": "ツール使用前", + "zh-CN": "工具使用前", + "zh-TW": "工具使用前", + "ko-KR": "도구 사용 전", + "no": "Før verktøybruk", + "ar": "قبل استخدام الأداة", + "de": "Vor Werkzeugnutzung", + "fr": "Avant utilisation de l'outil", + "it": "Prima dell'uso dello strumento", + "pt": "Antes do uso da ferramenta", + "es": "Antes del uso de la herramienta", + "tr": "Araç kullanımı öncesi", + "uk": "Перед використанням інструменту" + }, + "HOOKS_MODAL$EVENT_POST_TOOL_USE": { + "en": "Post Tool Use", + "ja": "ツール使用後", + "zh-CN": "工具使用后", + "zh-TW": "工具使用後", + "ko-KR": "도구 사용 후", + "no": "Etter verktøybruk", + "ar": "بعد استخدام الأداة", + "de": "Nach Werkzeugnutzung", + "fr": "Après utilisation de l'outil", + "it": "Dopo l'uso dello strumento", + "pt": "Após o uso da ferramenta", + "es": "Después del uso de la herramienta", + "tr": "Araç kullanımı sonrası", + "uk": "Після використання інструменту" + }, + "HOOKS_MODAL$EVENT_USER_PROMPT_SUBMIT": { + "en": "User Prompt Submit", + "ja": "ユーザープロンプト送信", + "zh-CN": "用户提示提交", + "zh-TW": "使用者提示提交", + "ko-KR": "사용자 프롬프트 제출", + "no": "Brukerforespørsel sendt", + "ar": "إرسال طلب المستخدم", + "de": "Benutzeranfrage gesendet", + "fr": "Soumission de l'invite utilisateur", + "it": "Invio prompt utente", + "pt": "Envio de prompt do usuário", + "es": "Envío de solicitud del usuario", + "tr": "Kullanıcı istemi gönderimi", + "uk": "Надсилання запиту користувача" + }, + "HOOKS_MODAL$EVENT_SESSION_START": { + "en": "Session Start", + "ja": "セッション開始", + "zh-CN": "会话开始", + "zh-TW": "會話開始", + "ko-KR": "세션 시작", + "no": "Øktstart", + "ar": "بدء الجلسة", + "de": "Sitzungsstart", + "fr": "Début de session", + "it": "Inizio sessione", + "pt": "Início da sessão", + "es": "Inicio de sesión", + "tr": 
"Oturum başlangıcı", + "uk": "Початок сесії" + }, + "HOOKS_MODAL$EVENT_SESSION_END": { + "en": "Session End", + "ja": "セッション終了", + "zh-CN": "会话结束", + "zh-TW": "會話結束", + "ko-KR": "세션 종료", + "no": "Øktslutt", + "ar": "نهاية الجلسة", + "de": "Sitzungsende", + "fr": "Fin de session", + "it": "Fine sessione", + "pt": "Fim da sessão", + "es": "Fin de sesión", + "tr": "Oturum sonu", + "uk": "Кінець сесії" + }, + "HOOKS_MODAL$EVENT_STOP": { + "en": "Stop", + "ja": "停止", + "zh-CN": "停止", + "zh-TW": "停止", + "ko-KR": "중지", + "no": "Stopp", + "ar": "إيقاف", + "de": "Stopp", + "fr": "Arrêt", + "it": "Stop", + "pt": "Parar", + "es": "Detener", + "tr": "Durdur", + "uk": "Зупинка" + }, + "HOOK$HOOK_LABEL": { + "en": "Hook", + "ja": "フック", + "zh-CN": "钩子", + "zh-TW": "鉤子", + "ko-KR": "훅", + "no": "Krok", + "ar": "خطاف", + "de": "Hook", + "fr": "Hook", + "it": "Hook", + "pt": "Hook", + "es": "Hook", + "tr": "Kanca", + "uk": "Хук" + }, + "HOOK$COMMAND": { + "en": "Command", + "ja": "コマンド", + "zh-CN": "命令", + "zh-TW": "命令", + "ko-KR": "명령", + "no": "Kommando", + "ar": "أمر", + "de": "Befehl", + "fr": "Commande", + "it": "Comando", + "pt": "Comando", + "es": "Comando", + "tr": "Komut", + "uk": "Команда" + }, + "HOOK$EXIT_CODE": { + "en": "Exit code", + "ja": "終了コード", + "zh-CN": "退出码", + "zh-TW": "退出碼", + "ko-KR": "종료 코드", + "no": "Avslutningskode", + "ar": "رمز الخروج", + "de": "Exit-Code", + "fr": "Code de sortie", + "it": "Codice di uscita", + "pt": "Código de saída", + "es": "Código de salida", + "tr": "Çıkış kodu", + "uk": "Код виходу" + }, + "HOOK$BLOCKED_REASON": { + "en": "Blocked reason", + "ja": "ブロック理由", + "zh-CN": "阻止原因", + "zh-TW": "阻止原因", + "ko-KR": "차단 이유", + "no": "Blokkert grunn", + "ar": "سبب الحظر", + "de": "Blockierungsgrund", + "fr": "Raison du blocage", + "it": "Motivo del blocco", + "pt": "Motivo do bloqueio", + "es": "Motivo del bloqueo", + "tr": "Engelleme nedeni", + "uk": "Причина блокування" + }, + "HOOK$CONTEXT": { + "en": "Context", + "ja": "コンテキスト", +
"zh-CN": "上下文", + "zh-TW": "上下文", + "ko-KR": "컨텍스트", + "no": "Kontekst", + "ar": "سياق", + "de": "Kontext", + "fr": "Contexte", + "it": "Contesto", + "pt": "Contexto", + "es": "Contexto", + "tr": "Bağlam", + "uk": "Контекст" + }, + "HOOK$ERROR": { + "en": "Error", + "ja": "エラー", + "zh-CN": "错误", + "zh-TW": "錯誤", + "ko-KR": "오류", + "no": "Feil", + "ar": "خطأ", + "de": "Fehler", + "fr": "Erreur", + "it": "Errore", + "pt": "Erro", + "es": "Error", + "tr": "Hata", + "uk": "Помилка" + }, + "HOOK$OUTPUT": { + "en": "Output", + "ja": "出力", + "zh-CN": "输出", + "zh-TW": "輸出", + "ko-KR": "출력", + "no": "Utdata", + "ar": "الإخراج", + "de": "Ausgabe", + "fr": "Sortie", + "it": "Output", + "pt": "Saída", + "es": "Salida", + "tr": "Çıktı", + "uk": "Вивід" + }, + "HOOK$STDERR": { + "en": "Stderr", + "ja": "標準エラー", + "zh-CN": "标准错误", + "zh-TW": "標準錯誤", + "ko-KR": "표준 오류", + "no": "Standardfeil", + "ar": "خطأ قياسي", + "de": "Standardfehler", + "fr": "Erreur standard", + "it": "Errore standard", + "pt": "Erro padrão", + "es": "Error estándar", + "tr": "Standart hata", + "uk": "Стандартна помилка" + }, "COMMON$TYPE_EMAIL_AND_PRESS_SPACE": { "en": "Type email and press Space", "ja": "メールアドレスを入力してスペースキーを押してください", diff --git a/frontend/src/types/core/base.ts b/frontend/src/types/core/base.ts index e305bf7d4d..97f56b245d 100644 --- a/frontend/src/types/core/base.ts +++ b/frontend/src/types/core/base.ts @@ -21,7 +21,7 @@ export type OpenHandsEventType = | "task_tracking" | "user_rejected"; -export type OpenHandsSourceType = "agent" | "user" | "environment"; +export type OpenHandsSourceType = "agent" | "user" | "environment" | "hook"; interface OpenHandsBaseEvent { id: number; diff --git a/frontend/src/types/v1/core/base/common.ts b/frontend/src/types/v1/core/base/common.ts index 3e03cc1484..56777d527d 100644 --- a/frontend/src/types/v1/core/base/common.ts +++ b/frontend/src/types/v1/core/base/common.ts @@ -53,7 +53,7 @@ export type EventID = string; export type ToolCallID = string; // 
Source type for events -export type SourceType = "agent" | "user" | "environment"; +export type SourceType = "agent" | "user" | "environment" | "hook"; // Security risk levels export enum SecurityRisk { diff --git a/frontend/src/types/v1/core/events/hook-execution-event.ts b/frontend/src/types/v1/core/events/hook-execution-event.ts new file mode 100644 index 0000000000..7495a754ee --- /dev/null +++ b/frontend/src/types/v1/core/events/hook-execution-event.ts @@ -0,0 +1,100 @@ +import { BaseEvent } from "../base/event"; + +/** + * Hook event types supported by the system + */ +export type HookEventType = + | "PreToolUse" + | "PostToolUse" + | "UserPromptSubmit" + | "SessionStart" + | "SessionEnd" + | "Stop"; + +/** + * HookExecutionEvent - emitted when a hook script executes + * + * Provides observability into hook execution for PreToolUse, PostToolUse, + * UserPromptSubmit, SessionStart, SessionEnd, and Stop hooks. + */ +export interface HookExecutionEvent extends BaseEvent { + /** + * Discriminator field for type guards + */ + kind: "HookExecutionEvent"; + + /** + * The source is always "hook" for hook execution events + */ + source: "hook"; + + /** + * Type of hook that was executed + */ + hook_event_type: HookEventType; + + /** + * The command that was executed + */ + hook_command: string; + + /** + * Whether the hook executed successfully + */ + success: boolean; + + /** + * Whether the hook blocked the action + */ + blocked: boolean; + + /** + * Exit code from the hook script (null if not applicable) + */ + exit_code: number | null; + + /** + * Reason provided by the hook for blocking (if blocked) + */ + reason: string | null; + + /** + * Name of the tool (for PreToolUse/PostToolUse hooks) + */ + tool_name: string | null; + + /** + * ID of the related action event (for tool hooks) + */ + action_id: string | null; + + /** + * ID of the related message event (for UserPromptSubmit hooks) + */ + message_id: string | null; + + /** + * Standard output from the hook 
script + */ + stdout: string | null; + + /** + * Standard error from the hook script + */ + stderr: string | null; + + /** + * Error message if the hook failed + */ + error: string | null; + + /** + * Additional context provided by the hook + */ + additional_context: string | null; + + /** + * Input data that was passed to the hook + */ + hook_input: Record | null; +} diff --git a/frontend/src/types/v1/core/events/index.ts b/frontend/src/types/v1/core/events/index.ts index e3d3ee6cb4..388002a52f 100644 --- a/frontend/src/types/v1/core/events/index.ts +++ b/frontend/src/types/v1/core/events/index.ts @@ -2,6 +2,7 @@ export * from "./action-event"; export * from "./condensation-event"; export * from "./conversation-state-event"; +export * from "./hook-execution-event"; export * from "./message-event"; export * from "./observation-event"; export * from "./pause-event"; diff --git a/frontend/src/types/v1/core/openhands-event.ts b/frontend/src/types/v1/core/openhands-event.ts index 4793c5a0ae..fc3a46f714 100644 --- a/frontend/src/types/v1/core/openhands-event.ts +++ b/frontend/src/types/v1/core/openhands-event.ts @@ -11,6 +11,7 @@ import { CondensationSummaryEvent, ConversationStateUpdateEvent, ConversationErrorEvent, + HookExecutionEvent, PauseEvent, } from "./events/index"; @@ -26,6 +27,8 @@ export type OpenHandsEvent = | UserRejectObservation | AgentErrorEvent | SystemPromptEvent + // Hook events + | HookExecutionEvent // Conversation management events | CondensationEvent | CondensationRequestEvent diff --git a/frontend/src/types/v1/type-guards.ts b/frontend/src/types/v1/type-guards.ts index dec1816209..b4fa1c9f5f 100644 --- a/frontend/src/types/v1/type-guards.ts +++ b/frontend/src/types/v1/type-guards.ts @@ -20,6 +20,7 @@ import { ConversationStateUpdateEventStats, ConversationErrorEvent, } from "./core/events/conversation-state-event"; +import { HookExecutionEvent } from "./core/events/hook-execution-event"; import { SystemPromptEvent } from 
"./core/events/system-event"; import type { OpenHandsParsedEvent } from "../core/index"; @@ -42,7 +43,8 @@ export function isBaseEvent(value: unknown): value is BaseEvent { typeof value.source === "string" && (value.source === "agent" || value.source === "user" || - value.source === "environment") + value.source === "environment" || + value.source === "hook") ); } @@ -191,6 +193,14 @@ export const isConversationErrorEvent = ( ): event is ConversationErrorEvent => "kind" in event && event.kind === "ConversationErrorEvent"; +/** + * Type guard function to check if an event is a hook execution event + */ +export const isHookExecutionEvent = ( + event: OpenHandsEvent, +): event is HookExecutionEvent => + "kind" in event && event.kind === "HookExecutionEvent"; + // ============================================================================= // TEMPORARY COMPATIBILITY TYPE GUARDS // These will be removed once we fully migrate to V1 events diff --git a/openhands/app_server/app_conversation/app_conversation_models.py b/openhands/app_server/app_conversation/app_conversation_models.py index b7a4cc4dce..f4a4467e8c 100644 --- a/openhands/app_server/app_conversation/app_conversation_models.py +++ b/openhands/app_server/app_conversation/app_conversation_models.py @@ -242,3 +242,32 @@ class SkillResponse(BaseModel): type: Literal['repo', 'knowledge', 'agentskills'] content: str triggers: list[str] = [] + + +class HookDefinitionResponse(BaseModel): + """Response model for a single hook definition.""" + + type: str # 'command' or 'prompt' + command: str + timeout: int = 60 + async_: bool = Field(default=False, serialization_alias='async') + + +class HookMatcherResponse(BaseModel): + """Response model for a hook matcher.""" + + matcher: str # Pattern: '*', exact match, or regex + hooks: list[HookDefinitionResponse] = [] + + +class HookEventResponse(BaseModel): + """Response model for hooks of a specific event type.""" + + event_type: str # e.g., 'stop', 'pre_tool_use', 
'post_tool_use' + matchers: list[HookMatcherResponse] = [] + + +class GetHooksResponse(BaseModel): + """Response model for hooks endpoint.""" + + hooks: list[HookEventResponse] = [] diff --git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index 50a8497a85..6babd41dc0 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -5,43 +5,29 @@ import logging import os import sys import tempfile +from dataclasses import dataclass from datetime import datetime from typing import Annotated, AsyncGenerator, Literal from uuid import UUID import httpx - -from openhands.app_server.services.db_session_injector import set_db_session_keep_open -from openhands.app_server.services.httpx_client_injector import ( - set_httpx_client_keep_open, -) -from openhands.app_server.services.injector import InjectorState -from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR -from openhands.app_server.user.user_context import UserContext -from openhands.server.dependencies import get_dependencies - -# Handle anext compatibility for Python < 3.10 -if sys.version_info >= (3, 10): - from builtins import anext -else: - - async def anext(async_iterator): - """Compatibility function for anext in Python < 3.10""" - return await async_iterator.__anext__() - - from fastapi import APIRouter, HTTPException, Query, Request, Response, status from fastapi.responses import JSONResponse, StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession from openhands.app_server.app_conversation.app_conversation_models import ( AppConversation, + AppConversationInfo, AppConversationPage, AppConversationStartRequest, AppConversationStartTask, AppConversationStartTaskPage, AppConversationStartTaskSortOrder, AppConversationUpdateRequest, + GetHooksResponse, + HookDefinitionResponse, + HookEventResponse, + 
HookMatcherResponse, SkillResponse, ) from openhands.app_server.app_conversation.app_conversation_service import ( @@ -66,15 +52,35 @@ from openhands.app_server.config import ( ) from openhands.app_server.sandbox.sandbox_models import ( AGENT_SERVER, + SandboxInfo, SandboxStatus, ) from openhands.app_server.sandbox.sandbox_service import SandboxService +from openhands.app_server.sandbox.sandbox_spec_models import SandboxSpecInfo from openhands.app_server.sandbox.sandbox_spec_service import SandboxSpecService +from openhands.app_server.services.db_session_injector import set_db_session_keep_open +from openhands.app_server.services.httpx_client_injector import ( + set_httpx_client_keep_open, +) +from openhands.app_server.services.injector import InjectorState +from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR +from openhands.app_server.user.user_context import UserContext from openhands.app_server.utils.docker_utils import ( replace_localhost_hostname_for_docker, ) from openhands.sdk.context.skills import KeywordTrigger, TaskTrigger from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace +from openhands.server.dependencies import get_dependencies + +# Handle anext compatibility for Python < 3.10 +if sys.version_info >= (3, 10): + from builtins import anext +else: + + async def anext(async_iterator): + """Compatibility function for anext in Python < 3.10""" + return await async_iterator.__anext__() + # We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint # is protected. 
The actual protection is provided by SetAuthCookieMiddleware @@ -92,6 +98,96 @@ httpx_client_dependency = depends_httpx_client() sandbox_service_dependency = depends_sandbox_service() sandbox_spec_service_dependency = depends_sandbox_spec_service() + +@dataclass +class AgentServerContext: + """Context for accessing the agent server for a conversation.""" + + conversation: AppConversationInfo + sandbox: SandboxInfo + sandbox_spec: SandboxSpecInfo + agent_server_url: str + session_api_key: str | None + + +async def _get_agent_server_context( + conversation_id: UUID, + app_conversation_service: AppConversationService, + sandbox_service: SandboxService, + sandbox_spec_service: SandboxSpecService, +) -> AgentServerContext | JSONResponse: + """Get the agent server context for a conversation. + + This helper retrieves all necessary information to communicate with the + agent server for a given conversation, including the sandbox info, + sandbox spec, and agent server URL. + + Args: + conversation_id: The conversation ID + app_conversation_service: Service for conversation operations + sandbox_service: Service for sandbox operations + sandbox_spec_service: Service for sandbox spec operations + + Returns: + AgentServerContext if successful, or JSONResponse with error details. 
+ """ + # Get the conversation info + conversation = await app_conversation_service.get_app_conversation(conversation_id) + if not conversation: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': f'Conversation {conversation_id} not found'}, + ) + + # Get the sandbox info + sandbox = await sandbox_service.get_sandbox(conversation.sandbox_id) + if not sandbox or sandbox.status != SandboxStatus.RUNNING: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={ + 'error': f'Sandbox not found or not running for conversation {conversation_id}' + }, + ) + + # Get the sandbox spec to find the working directory + sandbox_spec = await sandbox_spec_service.get_sandbox_spec(sandbox.sandbox_spec_id) + if not sandbox_spec: + # TODO: This is a temporary work around for the fact that we don't store previous + # sandbox spec versions when updating OpenHands. When the SandboxSpecServices + # transition to truly multi sandbox spec model this should raise a 404 error + logger.warning('Sandbox spec not found - using default.') + sandbox_spec = await sandbox_spec_service.get_default_sandbox_spec() + + # Get the agent server URL + if not sandbox.exposed_urls: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': 'No agent server URL found for sandbox'}, + ) + + agent_server_url = None + for exposed_url in sandbox.exposed_urls: + if exposed_url.name == AGENT_SERVER: + agent_server_url = exposed_url.url + break + + if not agent_server_url: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': 'Agent server URL not found in sandbox'}, + ) + + agent_server_url = replace_localhost_hostname_for_docker(agent_server_url) + + return AgentServerContext( + conversation=conversation, + sandbox=sandbox, + sandbox_spec=sandbox_spec, + agent_server_url=agent_server_url, + session_api_key=sandbox.session_api_key, + ) + + # Read methods @@ -493,57 +589,15 @@ async def get_conversation_skills( 
JSONResponse: A JSON response containing the list of skills. """ try: - # Get the conversation info - conversation = await app_conversation_service.get_app_conversation( - conversation_id + # Get agent server context (conversation, sandbox, sandbox_spec, agent_server_url) + ctx = await _get_agent_server_context( + conversation_id, + app_conversation_service, + sandbox_service, + sandbox_spec_service, ) - if not conversation: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': f'Conversation {conversation_id} not found'}, - ) - - # Get the sandbox info - sandbox = await sandbox_service.get_sandbox(conversation.sandbox_id) - if not sandbox or sandbox.status != SandboxStatus.RUNNING: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={ - 'error': f'Sandbox not found or not running for conversation {conversation_id}' - }, - ) - - # Get the sandbox spec to find the working directory - sandbox_spec = await sandbox_spec_service.get_sandbox_spec( - sandbox.sandbox_spec_id - ) - if not sandbox_spec: - # TODO: This is a temporary work around for the fact that we don't store previous - # sandbox spec versions when updating OpenHands. 
When the SandboxSpecServices - # transition to truly multi sandbox spec model this should raise a 404 error - logger.warning('Sandbox spec not found - using default.') - sandbox_spec = await sandbox_spec_service.get_default_sandbox_spec() - - # Get the agent server URL - if not sandbox.exposed_urls: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': 'No agent server URL found for sandbox'}, - ) - - agent_server_url = None - for exposed_url in sandbox.exposed_urls: - if exposed_url.name == AGENT_SERVER: - agent_server_url = exposed_url.url - break - - if not agent_server_url: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': 'Agent server URL not found in sandbox'}, - ) - - agent_server_url = replace_localhost_hostname_for_docker(agent_server_url) + if isinstance(ctx, JSONResponse): + return ctx # Load skills from all sources logger.info(f'Loading skills for conversation {conversation_id}') @@ -552,13 +606,13 @@ async def get_conversation_skills( all_skills: list = [] if isinstance(app_conversation_service, AppConversationServiceBase): project_dir = get_project_dir( - sandbox_spec.working_dir, conversation.selected_repository + ctx.sandbox_spec.working_dir, ctx.conversation.selected_repository ) all_skills = await app_conversation_service.load_and_merge_all_skills( - sandbox, - conversation.selected_repository, + ctx.sandbox, + ctx.conversation.selected_repository, project_dir, - agent_server_url, + ctx.agent_server_url, ) logger.info( @@ -608,6 +662,147 @@ async def get_conversation_skills( ) +@router.get('/{conversation_id}/hooks') +async def get_conversation_hooks( + conversation_id: UUID, + app_conversation_service: AppConversationService = ( + app_conversation_service_dependency + ), + sandbox_service: SandboxService = sandbox_service_dependency, + sandbox_spec_service: SandboxSpecService = sandbox_spec_service_dependency, + httpx_client: httpx.AsyncClient = httpx_client_dependency, +) -> 
JSONResponse: + """Get hooks currently configured in the workspace for this conversation. + + This endpoint loads hooks from the conversation's project directory in the + workspace (i.e. `{project_dir}/.openhands/hooks.json`) at request time. + + Note: + This is intentionally a "live" view of the workspace configuration. + If `.openhands/hooks.json` changes over time, this endpoint reflects the + latest file content and may not match the hooks that were used when the + conversation originally started. + + Returns: + JSONResponse: A JSON response containing the list of hook event types. + """ + try: + # Get agent server context (conversation, sandbox, sandbox_spec, agent_server_url) + ctx = await _get_agent_server_context( + conversation_id, + app_conversation_service, + sandbox_service, + sandbox_spec_service, + ) + if isinstance(ctx, JSONResponse): + return ctx + + from openhands.app_server.app_conversation.hook_loader import ( + fetch_hooks_from_agent_server, + get_project_dir_for_hooks, + ) + + project_dir = get_project_dir_for_hooks( + ctx.sandbox_spec.working_dir, + ctx.conversation.selected_repository, + ) + + # Load hooks from agent-server (using the error-raising variant so + # HTTP/connection failures are surfaced to the user, not hidden). 
+ logger.debug( + f'Loading hooks for conversation {conversation_id}, ' + f'agent_server_url={ctx.agent_server_url}, ' + f'project_dir={project_dir}' + ) + + try: + hook_config = await fetch_hooks_from_agent_server( + agent_server_url=ctx.agent_server_url, + session_api_key=ctx.session_api_key, + project_dir=project_dir, + httpx_client=httpx_client, + ) + except httpx.HTTPStatusError as e: + logger.warning( + f'Agent-server returned {e.response.status_code} when loading hooks ' + f'for conversation {conversation_id}: {e.response.text}' + ) + return JSONResponse( + status_code=status.HTTP_502_BAD_GATEWAY, + content={ + 'error': f'Agent-server returned status {e.response.status_code} when loading hooks' + }, + ) + except httpx.RequestError as e: + logger.warning( + f'Failed to reach agent-server when loading hooks ' + f'for conversation {conversation_id}: {e}' + ) + return JSONResponse( + status_code=status.HTTP_502_BAD_GATEWAY, + content={'error': 'Failed to reach agent-server when loading hooks'}, + ) + + # Transform hook_config to response format + hooks_response: list[HookEventResponse] = [] + + if hook_config: + # Define the event types to check + event_types = [ + 'pre_tool_use', + 'post_tool_use', + 'user_prompt_submit', + 'session_start', + 'session_end', + 'stop', + ] + + for field_name in event_types: + matchers = getattr(hook_config, field_name, []) + if matchers: + matcher_responses = [] + for matcher in matchers: + hook_defs = [ + HookDefinitionResponse( + type=hook.type.value + if hasattr(hook.type, 'value') + else str(hook.type), + command=hook.command, + timeout=hook.timeout, + async_=hook.async_, + ) + for hook in matcher.hooks + ] + matcher_responses.append( + HookMatcherResponse( + matcher=matcher.matcher, + hooks=hook_defs, + ) + ) + hooks_response.append( + HookEventResponse( + event_type=field_name, + matchers=matcher_responses, + ) + ) + + logger.debug( + f'Loaded {len(hooks_response)} hook event types for conversation {conversation_id}' + ) + 
+ return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetHooksResponse(hooks=hooks_response).model_dump(by_alias=True), + ) + + except Exception as e: + logger.error(f'Error getting hooks for conversation {conversation_id}: {e}') + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': f'Error getting hooks: {str(e)}'}, + ) + + @router.get('/{conversation_id}/download') async def export_conversation( conversation_id: UUID, diff --git a/openhands/app_server/app_conversation/hook_loader.py b/openhands/app_server/app_conversation/hook_loader.py new file mode 100644 index 0000000000..619ce68d43 --- /dev/null +++ b/openhands/app_server/app_conversation/hook_loader.py @@ -0,0 +1,148 @@ +"""Utilities for loading hooks for V1 conversations. + +This module provides functions to load hooks from the agent-server, +which centralizes all hook loading logic. The app-server acts as a +thin proxy that calls the agent-server's /api/hooks endpoint. + +All hook loading is handled by the agent-server. +""" + +import logging + +import httpx + +from openhands.sdk.hooks import HookConfig + +_logger = logging.getLogger(__name__) + + +def get_project_dir_for_hooks( + working_dir: str, + selected_repository: str | None = None, +) -> str: + """Get the project directory path for loading hooks. + + When a repository is selected, hooks are loaded from + {working_dir}/{repo_name}/.openhands/hooks.json. + Otherwise, hooks are loaded from {working_dir}/.openhands/hooks.json. + + Args: + working_dir: Base working directory path in the sandbox + selected_repository: Repository name (e.g., 'OpenHands/software-agent-sdk') + If provided, the repo name is appended to working_dir. + + Returns: + The project directory path where hooks.json should be located. 
+ """ + if selected_repository: + repo_name = selected_repository.split('/')[-1] + return f'{working_dir}/{repo_name}' + return working_dir + + +async def fetch_hooks_from_agent_server( + agent_server_url: str, + session_api_key: str | None, + project_dir: str, + httpx_client: httpx.AsyncClient, +) -> HookConfig | None: + """Fetch hooks from the agent-server, raising on HTTP/connection errors. + + This is the low-level function that makes a single API call to the + agent-server's /api/hooks endpoint. It raises on HTTP and connection + errors so callers can decide how to handle failures. + + Args: + agent_server_url: URL of the agent server (e.g., 'http://localhost:8000') + session_api_key: Session API key for authentication (optional) + project_dir: Workspace directory path for project hooks + httpx_client: Shared HTTP client for making the request + + Returns: + HookConfig if hooks.json exists and is valid, None if no hooks found. + + Raises: + httpx.HTTPStatusError: If the agent-server returns a non-2xx status. + httpx.RequestError: If the agent-server is unreachable. 
+ """ + _logger.debug( + f'fetch_hooks_from_agent_server called: ' + f'agent_server_url={agent_server_url}, project_dir={project_dir}' + ) + payload = {'project_dir': project_dir} + + headers = {'Content-Type': 'application/json'} + if session_api_key: + headers['X-Session-API-Key'] = session_api_key + + response = await httpx_client.post( + f'{agent_server_url}/api/hooks', + json=payload, + headers=headers, + timeout=30.0, + ) + response.raise_for_status() + + data = response.json() + + hook_config_data = data.get('hook_config') + if hook_config_data is None: + _logger.debug('No hooks found in workspace') + return None + + hook_config = HookConfig.from_dict(hook_config_data) + + if hook_config.is_empty(): + _logger.debug('Hooks config is empty') + return None + + _logger.debug(f'Loaded hooks from agent-server for {project_dir}') + return hook_config + + +async def load_hooks_from_agent_server( + agent_server_url: str, + session_api_key: str | None, + project_dir: str, + httpx_client: httpx.AsyncClient, +) -> HookConfig | None: + """Load hooks from the agent-server, swallowing errors gracefully. + + Wrapper around fetch_hooks_from_agent_server that catches all errors + and returns None. Use this for the conversation-start path where hooks + are optional and failures should not block startup. + + For the hooks viewer endpoint, use fetch_hooks_from_agent_server directly + so errors can be surfaced to the user. + + Args: + agent_server_url: URL of the agent server (e.g., 'http://localhost:8000') + session_api_key: Session API key for authentication (optional) + project_dir: Workspace directory path for project hooks + httpx_client: Shared HTTP client for making the request + + Returns: + HookConfig if hooks.json exists and is valid, None otherwise. 
+ """ + try: + return await fetch_hooks_from_agent_server( + agent_server_url, session_api_key, project_dir, httpx_client + ) + except httpx.HTTPStatusError as e: + _logger.warning( + f'Agent-server at {agent_server_url} returned error status {e.response.status_code} ' + f'when loading hooks from {project_dir}: {e.response.text}' + ) + return None + except httpx.RequestError as e: + _logger.warning( + f'Failed to connect to agent-server at {agent_server_url} ' + f'when loading hooks from {project_dir}: {e}' + ) + return None + except Exception as e: + _logger.warning( + f'Failed to load hooks from agent-server at {agent_server_url} ' + f'for project {project_dir}: {e}' + ) + return None diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index fe07f205c1..902cde7771 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -46,6 +46,9 @@ from openhands.app_server.app_conversation.app_conversation_service_base import from openhands.app_server.app_conversation.app_conversation_start_task_service import ( AppConversationStartTaskService, ) +from openhands.app_server.app_conversation.hook_loader import ( + load_hooks_from_agent_server, +) from openhands.app_server.app_conversation.sql_app_conversation_info_service import ( SQLAppConversationInfoService, ) @@ -84,6 +87,7 @@ from openhands.app_server.utils.llm_metadata import ( from openhands.integrations.provider import ProviderType from openhands.integrations.service_types import SuggestedTask from openhands.sdk import Agent, AgentContext, LocalWorkspace +from openhands.sdk.hooks import HookConfig from openhands.sdk.llm import LLM from openhands.sdk.plugin import PluginSource from openhands.sdk.secret import LookupSecret, SecretValue, StaticSecret @@ -312,6 +316,12 @@ class 
LiveStatusAppConversationService(AppConversationServiceBase): body_json = start_conversation_request.model_dump( mode='json', context={'expose_secrets': True} ) + # Log hook_config to verify it's being passed + hook_config_in_request = body_json.get('hook_config') + _logger.debug( + f'Sending StartConversationRequest with hook_config: ' + f'{hook_config_in_request}' + ) response = await self.httpx_client.post( f'{agent_server_url}/api/conversations', json=body_json, @@ -1295,6 +1305,46 @@ class LiveStatusAppConversationService(AppConversationServiceBase): run=initial_message.run, ) + async def _load_hooks_from_workspace( + self, + remote_workspace: AsyncRemoteWorkspace, + project_dir: str, + ) -> HookConfig | None: + """Load hooks from .openhands/hooks.json in the remote workspace. + + This enables project-level hooks to be automatically loaded when starting + a conversation, similar to how OpenHands-CLI loads hooks from the workspace. + + Uses the agent-server's /api/hooks endpoint, consistent with how skills + are loaded via /api/skills. + + Args: + remote_workspace: AsyncRemoteWorkspace for accessing the agent server + project_dir: Project root directory path in the sandbox. This should + already be the resolved project directory (e.g., + {working_dir}/{repo_name} when a repo is selected). + + Returns: + HookConfig if hooks.json exists and is valid, None otherwise. + Returns None in the following cases: + - hooks.json file does not exist + - hooks.json contains invalid JSON + - hooks.json contains an empty hooks configuration + - Agent server is unreachable or returns an error + + Note: + This method implements graceful degradation - if hooks cannot be loaded + for any reason, it returns None rather than raising an exception. This + ensures that conversation startup is not blocked by hook loading failures. + Errors are logged as warnings for debugging purposes. 
+ """ + return await load_hooks_from_agent_server( + agent_server_url=remote_workspace.host, + session_api_key=remote_workspace._headers.get('X-Session-API-Key'), + project_dir=project_dir, + httpx_client=self.httpx_client, + ) + async def _finalize_conversation_request( self, agent: Agent, @@ -1334,6 +1384,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): agent = self._update_agent_with_llm_metadata(agent, conversation_id, user.id) # Load and merge skills if remote workspace is available + hook_config: HookConfig | None = None if remote_workspace: try: agent = await self._load_skills_and_update_agent( @@ -1343,6 +1394,28 @@ class LiveStatusAppConversationService(AppConversationServiceBase): _logger.warning(f'Failed to load skills: {e}', exc_info=True) # Continue without skills - don't fail conversation startup + # Load hooks from workspace (.openhands/hooks.json) + # Note: working_dir is already the resolved project_dir + # (includes repo name when a repo is selected), so we pass + # it directly without appending the repo name again. 
+ try: + _logger.debug( + f'Attempting to load hooks from workspace: ' + f'project_dir={working_dir}' + ) + hook_config = await self._load_hooks_from_workspace( + remote_workspace, working_dir + ) + if hook_config: + _logger.debug( + f'Successfully loaded hooks: {hook_config.model_dump()}' + ) + else: + _logger.debug('No hooks found in workspace') + except Exception as e: + _logger.warning(f'Failed to load hooks: {e}', exc_info=True) + # Continue without hooks - don't fail conversation startup + # Incorporate plugin parameters into initial message if specified final_initial_message = self._construct_initial_message_with_plugin_params( initial_message, plugins @@ -1371,6 +1444,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): initial_message=final_initial_message, secrets=secrets, plugins=sdk_plugins, + hook_config=hook_config, ) async def _build_start_conversation_request_for_user( diff --git a/tests/unit/app_server/test_app_conversation_hooks_endpoint.py b/tests/unit/app_server/test_app_conversation_hooks_endpoint.py new file mode 100644 index 0000000000..ba67c4b488 --- /dev/null +++ b/tests/unit/app_server/test_app_conversation_hooks_endpoint.py @@ -0,0 +1,293 @@ +"""Unit tests for the V1 hooks endpoint in app_conversation_router. + +This module tests the GET /{conversation_id}/hooks endpoint functionality. 
+""" + +from unittest.mock import AsyncMock, MagicMock, Mock +from uuid import uuid4 + +import httpx +import pytest +from fastapi import status + +from openhands.app_server.app_conversation.app_conversation_models import ( + AppConversation, +) +from openhands.app_server.app_conversation.app_conversation_router import ( + get_conversation_hooks, +) +from openhands.app_server.sandbox.sandbox_models import ( + AGENT_SERVER, + ExposedUrl, + SandboxInfo, + SandboxStatus, +) +from openhands.app_server.sandbox.sandbox_spec_models import SandboxSpecInfo + + +@pytest.mark.asyncio +class TestGetConversationHooks: + async def test_get_hooks_returns_hook_events(self): + conversation_id = uuid4() + sandbox_id = str(uuid4()) + working_dir = '/workspace' + + mock_conversation = AppConversation( + id=conversation_id, + created_by_user_id='test-user', + sandbox_id=sandbox_id, + selected_repository='owner/repo', + sandbox_status=SandboxStatus.RUNNING, + ) + + mock_sandbox = SandboxInfo( + id=sandbox_id, + created_by_user_id='test-user', + status=SandboxStatus.RUNNING, + sandbox_spec_id=str(uuid4()), + session_api_key='test-api-key', + exposed_urls=[ + ExposedUrl(name=AGENT_SERVER, url='http://agent-server:8000', port=8000) + ], + ) + + mock_sandbox_spec = SandboxSpecInfo( + id=str(uuid4()), command=None, working_dir=working_dir + ) + + mock_app_conversation_service = MagicMock() + mock_app_conversation_service.get_app_conversation = AsyncMock( + return_value=mock_conversation + ) + + mock_sandbox_service = MagicMock() + mock_sandbox_service.get_sandbox = AsyncMock(return_value=mock_sandbox) + + mock_sandbox_spec_service = MagicMock() + mock_sandbox_spec_service.get_sandbox_spec = AsyncMock( + return_value=mock_sandbox_spec + ) + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + 'hook_config': { + 'stop': [ + { + 'matcher': '*', + 'hooks': [ + { + 'type': 'command', + 'command': 
'.openhands/hooks/on_stop.sh', + 'timeout': 60, + 'async': True, + } + ], + } + ] + } + } + + mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + mock_httpx_client.post = AsyncMock(return_value=mock_response) + + response = await get_conversation_hooks( + conversation_id=conversation_id, + app_conversation_service=mock_app_conversation_service, + sandbox_service=mock_sandbox_service, + sandbox_spec_service=mock_sandbox_spec_service, + httpx_client=mock_httpx_client, + ) + + assert response.status_code == status.HTTP_200_OK + + data = __import__('json').loads(response.body.decode('utf-8')) + assert 'hooks' in data + assert data['hooks'] + assert data['hooks'][0]['event_type'] == 'stop' + assert data['hooks'][0]['matchers'][0]['matcher'] == '*' + assert data['hooks'][0]['matchers'][0]['hooks'][0]['type'] == 'command' + assert ( + data['hooks'][0]['matchers'][0]['hooks'][0]['command'] + == '.openhands/hooks/on_stop.sh' + ) + assert data['hooks'][0]['matchers'][0]['hooks'][0]['async'] is True + assert 'async_' not in data['hooks'][0]['matchers'][0]['hooks'][0] + + mock_httpx_client.post.assert_called_once() + called_url = mock_httpx_client.post.call_args[0][0] + assert called_url == 'http://agent-server:8000/api/hooks' + + async def test_get_hooks_returns_502_when_agent_server_unreachable(self): + conversation_id = uuid4() + sandbox_id = str(uuid4()) + + mock_conversation = AppConversation( + id=conversation_id, + created_by_user_id='test-user', + sandbox_id=sandbox_id, + selected_repository=None, + sandbox_status=SandboxStatus.RUNNING, + ) + + mock_sandbox = SandboxInfo( + id=sandbox_id, + created_by_user_id='test-user', + status=SandboxStatus.RUNNING, + sandbox_spec_id=str(uuid4()), + session_api_key='test-api-key', + exposed_urls=[ + ExposedUrl(name=AGENT_SERVER, url='http://agent-server:8000', port=8000) + ], + ) + + mock_sandbox_spec = SandboxSpecInfo( + id=str(uuid4()), command=None, working_dir='/workspace' + ) + + mock_app_conversation_service = MagicMock() 
+ mock_app_conversation_service.get_app_conversation = AsyncMock( + return_value=mock_conversation + ) + + mock_sandbox_service = MagicMock() + mock_sandbox_service.get_sandbox = AsyncMock(return_value=mock_sandbox) + + mock_sandbox_spec_service = MagicMock() + mock_sandbox_spec_service.get_sandbox_spec = AsyncMock( + return_value=mock_sandbox_spec + ) + + mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + + def _raise_request_error(*args, **_kwargs): + request = httpx.Request('POST', args[0]) + raise httpx.RequestError('Connection error', request=request) + + mock_httpx_client.post = AsyncMock(side_effect=_raise_request_error) + + response = await get_conversation_hooks( + conversation_id=conversation_id, + app_conversation_service=mock_app_conversation_service, + sandbox_service=mock_sandbox_service, + sandbox_spec_service=mock_sandbox_spec_service, + httpx_client=mock_httpx_client, + ) + + assert response.status_code == status.HTTP_502_BAD_GATEWAY + data = __import__('json').loads(response.body.decode('utf-8')) + assert 'error' in data + + async def test_get_hooks_returns_502_when_agent_server_returns_error(self): + conversation_id = uuid4() + sandbox_id = str(uuid4()) + + mock_conversation = AppConversation( + id=conversation_id, + created_by_user_id='test-user', + sandbox_id=sandbox_id, + selected_repository=None, + sandbox_status=SandboxStatus.RUNNING, + ) + + mock_sandbox = SandboxInfo( + id=sandbox_id, + created_by_user_id='test-user', + status=SandboxStatus.RUNNING, + sandbox_spec_id=str(uuid4()), + session_api_key='test-api-key', + exposed_urls=[ + ExposedUrl(name=AGENT_SERVER, url='http://agent-server:8000', port=8000) + ], + ) + + mock_sandbox_spec = SandboxSpecInfo( + id=str(uuid4()), command=None, working_dir='/workspace' + ) + + mock_app_conversation_service = MagicMock() + mock_app_conversation_service.get_app_conversation = AsyncMock( + return_value=mock_conversation + ) + + mock_sandbox_service = MagicMock() + mock_sandbox_service.get_sandbox 
= AsyncMock(return_value=mock_sandbox) + + mock_sandbox_spec_service = MagicMock() + mock_sandbox_spec_service.get_sandbox_spec = AsyncMock( + return_value=mock_sandbox_spec + ) + + mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + + mock_response = Mock() + mock_response.status_code = 500 + + def _raise_http_status_error(*args, **_kwargs): + request = httpx.Request('POST', args[0]) + response = httpx.Response(status_code=500, text='Internal Server Error') + raise httpx.HTTPStatusError( + 'Server error', request=request, response=response + ) + + mock_httpx_client.post = AsyncMock(side_effect=_raise_http_status_error) + + response = await get_conversation_hooks( + conversation_id=conversation_id, + app_conversation_service=mock_app_conversation_service, + sandbox_service=mock_sandbox_service, + sandbox_spec_service=mock_sandbox_spec_service, + httpx_client=mock_httpx_client, + ) + + assert response.status_code == status.HTTP_502_BAD_GATEWAY + data = __import__('json').loads(response.body.decode('utf-8')) + assert 'error' in data + + async def test_get_hooks_returns_404_when_conversation_not_found(self): + conversation_id = uuid4() + + mock_app_conversation_service = MagicMock() + mock_app_conversation_service.get_app_conversation = AsyncMock( + return_value=None + ) + + response = await get_conversation_hooks( + conversation_id=conversation_id, + app_conversation_service=mock_app_conversation_service, + sandbox_service=MagicMock(), + sandbox_spec_service=MagicMock(), + httpx_client=AsyncMock(spec=httpx.AsyncClient), + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + async def test_get_hooks_returns_404_when_sandbox_not_running(self): + conversation_id = uuid4() + sandbox_id = str(uuid4()) + + mock_conversation = AppConversation( + id=conversation_id, + created_by_user_id='test-user', + sandbox_id=sandbox_id, + sandbox_status=SandboxStatus.RUNNING, + ) + + mock_app_conversation_service = MagicMock() + 
mock_app_conversation_service.get_app_conversation = AsyncMock( + return_value=mock_conversation + ) + + mock_sandbox_service = MagicMock() + mock_sandbox_service.get_sandbox = AsyncMock(return_value=None) + + response = await get_conversation_hooks( + conversation_id=conversation_id, + app_conversation_service=mock_app_conversation_service, + sandbox_service=mock_sandbox_service, + sandbox_spec_service=MagicMock(), + httpx_client=AsyncMock(spec=httpx.AsyncClient), + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/unit/app_server/test_live_status_app_conversation_service.py b/tests/unit/app_server/test_live_status_app_conversation_service.py index fcb251797f..27b0c704c9 100644 --- a/tests/unit/app_server/test_live_status_app_conversation_service.py +++ b/tests/unit/app_server/test_live_status_app_conversation_service.py @@ -123,6 +123,10 @@ class TestLiveStatusAppConversationService: self.mock_sandbox.id = uuid4() self.mock_sandbox.status = SandboxStatus.RUNNING + # Default mock for hooks loading - returns None (no hooks found) + # Tests that specifically test hooks loading can override this mock + self.service._load_hooks_from_workspace = AsyncMock(return_value=None) + def test_apply_suggested_task_sets_prompt_and_trigger(self): """Test suggested task prompts populate initial message and trigger.""" suggested_task = SuggestedTask( @@ -179,6 +183,7 @@ class TestLiveStatusAppConversationService: with pytest.raises(ValueError, match='empty prompt'): self.service._apply_suggested_task(request) + @pytest.mark.asyncio async def test_setup_secrets_for_git_providers_no_provider_tokens(self): """Test _setup_secrets_for_git_providers with no provider tokens.""" # Arrange @@ -1139,6 +1144,8 @@ class TestLiveStatusAppConversationService: side_effect=Exception('Skills loading failed') ) + # Note: hooks loading is already mocked in setup_method() to return None + # Act with patch( 
class TestLoadHooksFromWorkspace:
    """Test cases for _load_hooks_from_workspace method.

    Each test stubs ``httpx_client.post`` on the service built in
    ``setup_method`` and then checks what ``_load_hooks_from_workspace``
    returns and, for the success paths, which request it issued. The class
    also holds the pure-function tests for ``get_project_dir_for_hooks``.
    """

    @staticmethod
    def _stop_hook_payload() -> dict:
        """Return a response body containing a single ``stop`` command hook."""
        return {
            'hook_config': {
                'stop': [
                    {
                        'matcher': '*',
                        'hooks': [{'type': 'command', 'command': 'echo "stop hook"'}],
                    }
                ]
            }
        }

    @staticmethod
    def _make_workspace(headers=None) -> Mock:
        """Build a mocked ``AsyncRemoteWorkspace`` for the fake agent server.

        Args:
            headers: Value for the workspace's ``_headers`` attribute; an
                empty dict is used when omitted.
        """
        workspace = Mock(spec=AsyncRemoteWorkspace)
        workspace.host = 'http://agent-server:8000'
        # Copy so that no two tests ever share one mutable headers dict.
        workspace._headers = {} if headers is None else dict(headers)
        return workspace

    def _stub_post_response(self, payload) -> None:
        """Make ``httpx_client.post`` return a 200 response whose JSON is *payload*."""
        response = Mock()
        response.status_code = 200
        response.json.return_value = payload
        response.raise_for_status = Mock()
        self.mock_httpx_client.post = AsyncMock(return_value=response)

    def _assert_hooks_requested(self, project_dir) -> None:
        """Assert exactly one POST to ``/api/hooks`` was made for *project_dir*.

        The expected headers are the JSON content type plus the session API
        key, so this helper is only valid for workspaces created with the
        ``X-Session-API-Key`` header.
        """
        self.mock_httpx_client.post.assert_called_once_with(
            'http://agent-server:8000/api/hooks',
            json={'project_dir': project_dir},
            headers={
                'Content-Type': 'application/json',
                'X-Session-API-Key': 'test-key',
            },
            timeout=30.0,
        )

    def setup_method(self):
        """Set up test fixtures."""
        # Create mock dependencies
        self.mock_user_context = Mock(spec=UserContext)
        self.mock_jwt_service = Mock()
        self.mock_sandbox_service = Mock()
        self.mock_sandbox_spec_service = Mock()
        self.mock_app_conversation_info_service = Mock()
        self.mock_app_conversation_start_task_service = Mock()
        self.mock_event_callback_service = Mock()
        self.mock_event_service = Mock()
        self.mock_httpx_client = AsyncMock()

        # Create service instance
        self.service = LiveStatusAppConversationService(
            init_git_in_empty_workspace=True,
            user_context=self.mock_user_context,
            app_conversation_info_service=self.mock_app_conversation_info_service,
            app_conversation_start_task_service=self.mock_app_conversation_start_task_service,
            event_callback_service=self.mock_event_callback_service,
            event_service=self.mock_event_service,
            sandbox_service=self.mock_sandbox_service,
            sandbox_spec_service=self.mock_sandbox_spec_service,
            jwt_service=self.mock_jwt_service,
            sandbox_startup_timeout=30,
            sandbox_startup_poll_frequency=1,
            httpx_client=self.mock_httpx_client,
            web_url='https://test.example.com',
            openhands_provider_base_url='https://provider.example.com',
            access_token_hard_timeout=None,
            app_mode='test',
        )

    @pytest.mark.asyncio
    async def test_load_hooks_from_workspace_success(self):
        """Test loading hooks from workspace when hooks.json exists."""
        # Arrange
        workspace = self._make_workspace({'X-Session-API-Key': 'test-key'})
        self._stub_post_response(self._stop_hook_payload())

        # Act
        result = await self.service._load_hooks_from_workspace(
            workspace, '/workspace'
        )

        # Assert
        assert result is not None
        assert not result.is_empty()
        self._assert_hooks_requested('/workspace')

    @pytest.mark.asyncio
    async def test_load_hooks_from_workspace_file_not_found(self):
        """Test loading hooks when hooks.json does not exist."""
        # Arrange
        workspace = self._make_workspace()
        # Agent server returns hook_config: None when file not found
        self._stub_post_response({'hook_config': None})

        # Act
        result = await self.service._load_hooks_from_workspace(
            workspace, '/workspace'
        )

        # Assert
        assert result is None

    @pytest.mark.asyncio
    async def test_load_hooks_from_workspace_empty_hooks(self):
        """Test loading hooks when hooks.json is empty or has no hooks."""
        # Arrange
        workspace = self._make_workspace()
        # Agent server returns empty hook_config
        self._stub_post_response({'hook_config': {}})

        # Act
        result = await self.service._load_hooks_from_workspace(
            workspace, '/workspace'
        )

        # Assert
        assert result is None

    @pytest.mark.asyncio
    async def test_load_hooks_from_workspace_http_error(self):
        """Test loading hooks when HTTP request fails."""
        # Arrange
        workspace = self._make_workspace()
        self.mock_httpx_client.post = AsyncMock(
            side_effect=Exception('Connection error')
        )

        # Act
        result = await self.service._load_hooks_from_workspace(
            workspace, '/workspace'
        )

        # Assert - the failure is expected to surface as "no hooks", not raise
        assert result is None

    def test_get_project_dir_for_hooks_with_selected_repository(self):
        """Test get_project_dir_for_hooks with a selected repository."""
        from openhands.app_server.app_conversation.hook_loader import (
            get_project_dir_for_hooks,
        )

        result = get_project_dir_for_hooks(
            '/workspace/project',
            'OpenHands/software-agent-sdk',
        )
        assert result == '/workspace/project/software-agent-sdk'

    def test_get_project_dir_for_hooks_without_selected_repository(self):
        """Test get_project_dir_for_hooks without a selected repository."""
        from openhands.app_server.app_conversation.hook_loader import (
            get_project_dir_for_hooks,
        )

        result = get_project_dir_for_hooks('/workspace/project', None)
        assert result == '/workspace/project'

    def test_get_project_dir_for_hooks_with_empty_string(self):
        """Test get_project_dir_for_hooks with empty string repository."""
        from openhands.app_server.app_conversation.hook_loader import (
            get_project_dir_for_hooks,
        )

        # Empty string should be treated as no repository
        result = get_project_dir_for_hooks('/workspace/project', '')
        assert result == '/workspace/project'

    @pytest.mark.asyncio
    async def test_load_hooks_from_workspace_with_project_dir(self):
        """Test loading hooks with a pre-resolved project_dir.

        The caller is responsible for computing the project_dir (which
        already includes the repo name when a repo is selected).
        _load_hooks_from_workspace should use the project_dir as-is.
        """
        # Arrange
        workspace = self._make_workspace({'X-Session-API-Key': 'test-key'})
        self._stub_post_response(self._stop_hook_payload())

        # Act - project_dir already includes repo name
        result = await self.service._load_hooks_from_workspace(
            workspace,
            '/workspace/project/software-agent-sdk',
        )

        # Assert
        assert result is not None
        assert not result.is_empty()
        # The project_dir should be passed as-is without doubling
        self._assert_hooks_requested('/workspace/project/software-agent-sdk')

    @pytest.mark.asyncio
    async def test_load_hooks_from_workspace_base_dir(self):
        """Test loading hooks with a base workspace directory (no repo selected)."""
        # Arrange
        workspace = self._make_workspace({'X-Session-API-Key': 'test-key'})
        self._stub_post_response(self._stop_hook_payload())

        # Act - no repo selected, project_dir is base working_dir
        result = await self.service._load_hooks_from_workspace(
            workspace,
            '/workspace/project',
        )

        # Assert
        assert result is not None
        # Same payload as the other success tests, so hold it to the same bar.
        assert not result.is_empty()
        self._assert_hooks_requested('/workspace/project')