From 6d86803f4154894db19230d66c83e2d18fd3f9dd Mon Sep 17 00:00:00 2001 From: Varun Chawla <34209028+veeceey@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:26:27 -0700 Subject: [PATCH 01/28] Add loading feedback to git changes refresh button (#12792) Co-authored-by: hieptl --- frontend/__tests__/routes/changes-tab.test.tsx | 2 ++ .../conversation-tabs/conversation-tab-title.tsx | 12 +++++++++--- .../src/hooks/query/use-unified-get-git-changes.ts | 1 + 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/frontend/__tests__/routes/changes-tab.test.tsx b/frontend/__tests__/routes/changes-tab.test.tsx index 178bb28c40..1cf2513d18 100644 --- a/frontend/__tests__/routes/changes-tab.test.tsx +++ b/frontend/__tests__/routes/changes-tab.test.tsx @@ -32,6 +32,7 @@ describe("Changes Tab", () => { vi.mocked(useUnifiedGetGitChanges).mockReturnValue({ data: [], isLoading: false, + isFetching: false, isSuccess: true, isError: false, error: null, @@ -50,6 +51,7 @@ describe("Changes Tab", () => { vi.mocked(useUnifiedGetGitChanges).mockReturnValue({ data: [{ path: "src/file.ts", status: "M" }], isLoading: false, + isFetching: false, isSuccess: true, isError: false, error: null, diff --git a/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-title.tsx b/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-title.tsx index 75dbb23f8e..ad3bc98c41 100644 --- a/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-title.tsx +++ b/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-title.tsx @@ -20,7 +20,7 @@ export function ConversationTabTitle({ conversationKey, }: ConversationTabTitleProps) { const { t } = useTranslation(); - const { refetch } = useUnifiedGetGitChanges(); + const { refetch, isFetching } = useUnifiedGetGitChanges(); const { handleBuildPlanClick } = useHandleBuildPlanClick(); const { curAgentState } = useAgentState(); const { planContent } = useConversationStore(); @@ -41,10 +41,16 @@ export function ConversationTabTitle({ {conversationKey === "editor" && ( )} {conversationKey === "planner" && ( diff --git a/frontend/src/hooks/query/use-unified-get-git-changes.ts b/frontend/src/hooks/query/use-unified-get-git-changes.ts index 70bc5f451f..801b1a067a 100644 --- a/frontend/src/hooks/query/use-unified-get-git-changes.ts +++ b/frontend/src/hooks/query/use-unified-get-git-changes.ts @@ -100,6 +100,7 @@ export const useUnifiedGetGitChanges = () => { return { data: orderedChanges, isLoading: result.isLoading, + isFetching: result.isFetching, isSuccess: result.isSuccess, isError: result.isError, error: result.error, From 39a4ca422f35b276d8b508bf52e9fddd9ae5a0d5 Mon Sep 17 00:00:00 2001 From: Robert Brennan Date: Wed, 18 Mar 2026 11:42:46 -0700 Subject: [PATCH 02/28] fix: use sentence case for 'Waiting for sandbox' text (#12958) Co-authored-by: openhands --- frontend/__tests__/utils/utils.test.ts | 4 ++-- frontend/src/utils/utils.ts | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/frontend/__tests__/utils/utils.test.ts b/frontend/__tests__/utils/utils.test.ts index 91e9ba031b..8fc286d4b1 100644 --- a/frontend/__tests__/utils/utils.test.ts +++ b/frontend/__tests__/utils/utils.test.ts @@ -11,7 +11,7 @@ import { I18nKey } from "#/i18n/declaration"; // Mock translations const t = (key: string) => { const translations: { [key: string]: string } = { - COMMON$WAITING_FOR_SANDBOX: "Waiting For Sandbox", + COMMON$WAITING_FOR_SANDBOX: "Waiting for sandbox", COMMON$STOPPING: "Stopping", COMMON$STARTING: "Starting", COMMON$SERVER_STOPPED: "Server stopped", @@ -69,7 +69,7 @@ describe("getStatusText", () => { t, }); - expect(result).toBe(t(I18nKey.COMMON$WAITING_FOR_SANDBOX)); + expect(result).toBe("Waiting for sandbox"); }); it("returns task detail when task status is ERROR and detail exists", () => { diff --git a/frontend/src/utils/utils.ts b/frontend/src/utils/utils.ts index 849d65fbdf..80f40158f6 100644 --- a/frontend/src/utils/utils.ts +++ b/frontend/src/utils/utils.ts @@ -838,7 +838,7 @@ interface GetStatusTextArgs { * isStartingStatus: false, * isStopStatus: false, * curAgentState: AgentState.RUNNING - * }) // Returns "Waiting For Sandbox" + * }) // Returns "Waiting for sandbox" */ export function getStatusText({ isPausing = false, @@ -866,13 +866,13 @@ export function getStatusText({ return t(I18nKey.CONVERSATION$READY); } - // Format status text: "WAITING_FOR_SANDBOX" -> "Waiting for sandbox" + // Format status text with sentence case: "WAITING_FOR_SANDBOX" -> "Waiting for sandbox" return ( taskDetail || taskStatus .toLowerCase() .replace(/_/g, " ") - .replace(/\b\w/g, (c) => c.toUpperCase()) + .replace(/^\w/, (c) => c.toUpperCase()) ); } From db41148396188a7566f412f1053211a655aaea7d Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Thu, 19 Mar 2026 01:46:23 +0700 Subject: [PATCH 03/28] feat(backend): expose API key org_id via new GET /api/keys/current endpoint (org project) (#13469) --- enterprise/server/auth/saas_user_auth.py | 18 +++- enterprise/server/routes/api_keys.py | 57 ++++++++++++- enterprise/storage/api_key_store.py | 22 ++++- .../storage/saas_conversation_validator.py | 6 +- .../tests/unit/server/routes/test_api_keys.py | 85 +++++++++++++++++++ enterprise/tests/unit/test_api_key_store.py | 14 ++- enterprise/tests/unit/test_saas_user_auth.py | 17 +++- 7 files changed, 203 insertions(+), 16 deletions(-) diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index c2b3e1fbe9..6c8aefea7a 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -1,6 +1,7 @@ import time from dataclasses import dataclass from types import MappingProxyType +from uuid import UUID import jwt from fastapi import Request @@ -59,6 +60,10 @@ class SaasUserAuth(UserAuth): _secrets: Secrets | None = None accepted_tos: bool | None = None auth_type: AuthType = AuthType.COOKIE + # API key context fields - populated when authenticated via API key + api_key_org_id: UUID | None = None + api_key_id: int | None = None + api_key_name: str | None = None async def get_user_id(self) -> str | None: return self.user_id @@ -283,14 +288,19 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None: return None api_key_store = ApiKeyStore.get_instance() - user_id = await api_key_store.validate_api_key(api_key) - if not user_id: + validation_result = await api_key_store.validate_api_key(api_key) + if not validation_result: return None - offline_token = await token_manager.load_offline_token(user_id) + offline_token = await token_manager.load_offline_token( + validation_result.user_id + ) saas_user_auth = SaasUserAuth( - user_id=user_id, + user_id=validation_result.user_id, refresh_token=SecretStr(offline_token), auth_type=AuthType.BEARER, + api_key_org_id=validation_result.org_id, + api_key_id=validation_result.key_id, + api_key_name=validation_result.key_name, ) await saas_user_auth.refresh() return saas_user_auth diff --git a/enterprise/server/routes/api_keys.py b/enterprise/server/routes/api_keys.py index d5f30f87cf..31320966da 100644 --- a/enterprise/server/routes/api_keys.py +++ b/enterprise/server/routes/api_keys.py @@ -1,7 +1,9 @@ from datetime import UTC, datetime +from typing import cast -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel, field_validator +from server.auth.saas_user_auth import SaasUserAuth from storage.api_key import ApiKey from storage.api_key_store import ApiKeyStore from storage.lite_llm_manager import LiteLlmManager @@ -11,7 +13,8 @@ from storage.org_service import OrgService from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger -from openhands.server.user_auth import get_user_id +from openhands.server.user_auth import get_user_auth, get_user_id +from openhands.server.user_auth.user_auth import AuthType # Helper functions for BYOR API key management @@ -150,6 +153,16 @@ class MessageResponse(BaseModel): message: str +class CurrentApiKeyResponse(BaseModel): + """Response model for the current API key endpoint.""" + + id: int + name: str | None + org_id: str + user_id: str + auth_type: str + + def api_key_to_response(key: ApiKey) -> ApiKeyResponse: """Convert an ApiKey model to an ApiKeyResponse.""" return ApiKeyResponse( @@ -262,6 +275,46 @@ async def delete_api_key( ) +@api_router.get('/current', tags=['Keys']) +async def get_current_api_key( + request: Request, + user_id: str = Depends(get_user_id), +) -> CurrentApiKeyResponse: + """Get information about the currently authenticated API key. + + This endpoint returns metadata about the API key used for the current request, + including the org_id associated with the key. This is useful for API key + callers who need to know which organization context their key operates in. + + Returns 400 if not authenticated via API key (e.g., using cookie auth). + """ + user_auth = await get_user_auth(request) + + # Check if authenticated via API key + if user_auth.get_auth_type() != AuthType.BEARER: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='This endpoint requires API key authentication. Not available for cookie-based auth.', + ) + + # In SaaS context, bearer auth always produces SaasUserAuth + saas_user_auth = cast(SaasUserAuth, user_auth) + + if saas_user_auth.api_key_org_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='This API key was created before organization support. Please regenerate your API key to use this endpoint.', + ) + + return CurrentApiKeyResponse( + id=saas_user_auth.api_key_id, + name=saas_user_auth.api_key_name, + org_id=str(saas_user_auth.api_key_org_id), + user_id=user_id, + auth_type=saas_user_auth.auth_type.value, + ) + + @api_router.get('/llm/byor', tags=['Keys']) async def get_llm_api_key_for_byor( user_id: str = Depends(get_user_id), diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index 74a2d3d73e..3090b8da07 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -4,6 +4,7 @@ import secrets import string from dataclasses import dataclass from datetime import UTC, datetime +from uuid import UUID from sqlalchemy import select, update from storage.api_key import ApiKey @@ -13,6 +14,16 @@ from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger +@dataclass +class ApiKeyValidationResult: + """Result of API key validation containing user and org context.""" + + user_id: str + org_id: UUID | None + key_id: int + key_name: str | None + + @dataclass class ApiKeyStore: API_KEY_PREFIX = 'sk-oh-' @@ -60,8 +71,8 @@ class ApiKeyStore: return api_key - async def validate_api_key(self, api_key: str) -> str | None: - """Validate an API key and return the associated user_id if valid.""" + async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None: + """Validate an API key and return the associated user_id and org_id if valid.""" now = datetime.now(UTC) async with a_session_maker() as session: @@ -89,7 +100,12 @@ class ApiKeyStore: ) await session.commit() - return key_record.user_id + return ApiKeyValidationResult( + user_id=key_record.user_id, + org_id=key_record.org_id, + key_id=key_record.id, + key_name=key_record.name, + ) async def delete_api_key(self, api_key: str) -> bool: """Delete an API key by the key value.""" diff --git a/enterprise/storage/saas_conversation_validator.py b/enterprise/storage/saas_conversation_validator.py index bff4468011..51a5302dfc 100644 --- a/enterprise/storage/saas_conversation_validator.py +++ b/enterprise/storage/saas_conversation_validator.py @@ -28,12 +28,14 @@ class SaasConversationValidator(ConversationValidator): # Validate the API key and get the user_id api_key_store = ApiKeyStore.get_instance() - user_id = await api_key_store.validate_api_key(api_key) + validation_result = await api_key_store.validate_api_key(api_key) - if not user_id: + if not validation_result: logger.warning('Invalid API key') return None + user_id = validation_result.user_id + # Get the offline token for the user offline_token = await token_manager.load_offline_token(user_id) if not offline_token: diff --git a/enterprise/tests/unit/server/routes/test_api_keys.py b/enterprise/tests/unit/server/routes/test_api_keys.py index 57a9cb465d..4c35e9d5be 100644 --- a/enterprise/tests/unit/server/routes/test_api_keys.py +++ b/enterprise/tests/unit/server/routes/test_api_keys.py @@ -1,19 +1,26 @@ """Unit tests for API keys routes, focusing on BYOR key validation and retrieval.""" +import uuid from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest from fastapi import HTTPException +from pydantic import SecretStr +from server.auth.saas_user_auth import SaasUserAuth from server.routes.api_keys import ( ByorPermittedResponse, + CurrentApiKeyResponse, LlmApiKeyResponse, check_byor_permitted, delete_byor_key_from_litellm, + get_current_api_key, get_llm_api_key_for_byor, ) from storage.lite_llm_manager import LiteLlmManager +from openhands.server.user_auth.user_auth import AuthType + class TestVerifyByorKeyInLitellm: """Test the verify_byor_key_in_litellm function.""" @@ -512,3 +519,81 @@ class TestCheckByorPermitted: assert exc_info.value.status_code == 500 assert 'Failed to check BYOR export permission' in exc_info.value.detail + + +class TestGetCurrentApiKey: + """Test the get_current_api_key endpoint.""" + + @pytest.mark.asyncio + @patch('server.routes.api_keys.get_user_auth') + async def test_returns_api_key_info_for_bearer_auth(self, mock_get_user_auth): + """Test that API key metadata including org_id is returned for bearer token auth.""" + # Arrange + user_id = 'user-123' + org_id = uuid.uuid4() + mock_request = MagicMock() + + user_auth = SaasUserAuth( + refresh_token=SecretStr('mock-token'), + user_id=user_id, + auth_type=AuthType.BEARER, + api_key_org_id=org_id, + api_key_id=42, + api_key_name='My Production Key', + ) + mock_get_user_auth.return_value = user_auth + + # Act + result = await get_current_api_key(request=mock_request, user_id=user_id) + + # Assert + assert isinstance(result, CurrentApiKeyResponse) + assert result.org_id == str(org_id) + assert result.id == 42 + assert result.name == 'My Production Key' + assert result.user_id == user_id + assert result.auth_type == 'bearer' + + @pytest.mark.asyncio + @patch('server.routes.api_keys.get_user_auth') + async def test_returns_400_for_cookie_auth(self, mock_get_user_auth): + """Test that 400 Bad Request is returned when using cookie authentication.""" + # Arrange + user_id = 'user-123' + mock_request = MagicMock() + + mock_user_auth = MagicMock() + mock_user_auth.get_auth_type.return_value = AuthType.COOKIE + mock_get_user_auth.return_value = mock_user_auth + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_current_api_key(request=mock_request, user_id=user_id) + + assert exc_info.value.status_code == 400 + assert 'API key authentication' in exc_info.value.detail + + @pytest.mark.asyncio + @patch('server.routes.api_keys.get_user_auth') + async def test_returns_400_when_api_key_org_id_is_none(self, mock_get_user_auth): + """Test that 400 is returned when API key has no org_id (legacy key).""" + # Arrange + user_id = 'user-123' + mock_request = MagicMock() + + user_auth = SaasUserAuth( + refresh_token=SecretStr('mock-token'), + user_id=user_id, + auth_type=AuthType.BEARER, + api_key_org_id=None, # No org_id - legacy key + api_key_id=42, + api_key_name='Legacy Key', + ) + mock_get_user_auth.return_value = user_auth + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_current_api_key(request=mock_request, user_id=user_id) + + assert exc_info.value.status_code == 400 + assert 'created before organization support' in exc_info.value.detail diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index d3a2d13d1e..baffe5893c 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -126,13 +126,18 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker): ) session.add(key_record) await session.commit() + key_id = key_record.id # Execute - patch a_session_maker to use test's async session maker with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - assert result == user_id + # Verify - result is now ApiKeyValidationResult + assert result is not None + assert result.user_id == user_id + assert result.org_id == org_id + assert result.key_id == key_id + assert result.key_name == 'Test Key' @pytest.mark.asyncio @@ -218,8 +223,9 @@ async def test_validate_api_key_valid_timezone_naive( with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - assert result == user_id + # Verify - result is now ApiKeyValidationResult + assert result is not None + assert result.user_id == user_id @pytest.mark.asyncio diff --git a/enterprise/tests/unit/test_saas_user_auth.py b/enterprise/tests/unit/test_saas_user_auth.py index 92552de3ad..2fb1b68445 100644 --- a/enterprise/tests/unit/test_saas_user_auth.py +++ b/enterprise/tests/unit/test_saas_user_auth.py @@ -1,4 +1,5 @@ import time +import uuid from unittest.mock import AsyncMock, MagicMock, patch import jwt @@ -18,6 +19,7 @@ from server.auth.saas_user_auth import ( saas_user_auth_from_cookie, saas_user_auth_from_signed_token, ) +from storage.api_key_store import ApiKeyValidationResult from storage.user_authorization import UserAuthorizationType from openhands.integrations.provider import ProviderToken, ProviderType @@ -468,12 +470,22 @@ async def test_saas_user_auth_from_bearer_success(): algorithm='HS256', ) + mock_org_id = uuid.uuid4() + mock_validation_result = ApiKeyValidationResult( + user_id='test_user_id', + org_id=mock_org_id, + key_id=42, + key_name='Test Key', + ) + with ( patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls, patch('server.auth.saas_user_auth.token_manager') as mock_token_manager, ): mock_api_key_store = MagicMock() - mock_api_key_store.validate_api_key = AsyncMock(return_value='test_user_id') + mock_api_key_store.validate_api_key = AsyncMock( + return_value=mock_validation_result + ) mock_api_key_store_cls.get_instance.return_value = mock_api_key_store mock_token_manager.load_offline_token = AsyncMock(return_value=offline_token) @@ -485,6 +497,9 @@ async def test_saas_user_auth_from_bearer_success(): assert isinstance(result, SaasUserAuth) assert result.user_id == 'test_user_id' + assert result.api_key_org_id == mock_org_id + assert result.api_key_id == 42 + assert result.api_key_name == 'Test Key' mock_api_key_store.validate_api_key.assert_called_once_with('test_api_key') mock_token_manager.load_offline_token.assert_called_once_with('test_user_id') mock_token_manager.refresh.assert_called_once_with(offline_token) From 1d1ffc2be0454db99cfc77dc05cba5da5e74688c Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Wed, 18 Mar 2026 15:07:36 -0400 Subject: [PATCH 04/28] feat(enterprise): Add service API for automation API key creation (#13467) Co-authored-by: openhands --- enterprise/saas_server.py | 2 + enterprise/server/middleware.py | 4 + enterprise/server/routes/service.py | 270 ++++++++++++++ enterprise/storage/api_key_store.py | 199 ++++++++++- enterprise/tests/unit/routes/__init__.py | 0 enterprise/tests/unit/routes/test_service.py | 331 ++++++++++++++++++ .../tests/unit/storage/test_api_key_store.py | 314 +++++++++++++++++ 7 files changed, 1110 insertions(+), 10 deletions(-) create mode 100644 enterprise/server/routes/service.py create mode 100644 enterprise/tests/unit/routes/__init__.py create mode 100644 enterprise/tests/unit/routes/test_service.py create mode 100644 enterprise/tests/unit/storage/test_api_key_store.py diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index 8bb576a55b..434652befd 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -46,6 +46,7 @@ from server.routes.org_invitations import ( # noqa: E402 ) from server.routes.orgs import org_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402 +from server.routes.service import service_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402 from server.routes.user_app_settings import user_app_settings_router # noqa: E402 from server.sharing.shared_conversation_router import ( # noqa: E402 @@ -112,6 +113,7 @@ if GITLAB_APP_CLIENT_ID: base_app.include_router(gitlab_integration_router) base_app.include_router(api_keys_router) # Add routes for API key management +base_app.include_router(service_router) # Add routes for internal service API base_app.include_router(org_router) # Add routes for organization management base_app.include_router( verified_models_router diff --git a/enterprise/server/middleware.py b/enterprise/server/middleware.py index 659a66046a..c014864b0b 100644 --- a/enterprise/server/middleware.py +++ b/enterprise/server/middleware.py @@ -182,6 +182,10 @@ class SetAuthCookieMiddleware: if path.startswith('/api/v1/webhooks/'): return False + # Service API uses its own authentication (X-Service-API-Key header) + if path.startswith('/api/service/'): + return False + is_mcp = path.startswith('/mcp') is_api_route = path.startswith('/api') return is_api_route or is_mcp diff --git a/enterprise/server/routes/service.py b/enterprise/server/routes/service.py new file mode 100644 index 0000000000..87e470dd7c --- /dev/null +++ b/enterprise/server/routes/service.py @@ -0,0 +1,270 @@ +""" +Service API routes for internal service-to-service communication. + +This module provides endpoints for trusted internal services (e.g., automations service) +to perform privileged operations like creating API keys on behalf of users. + +Authentication is via a shared secret (X-Service-API-Key header) configured +through the AUTOMATIONS_SERVICE_API_KEY environment variable. +""" + +import os +from uuid import UUID + +from fastapi import APIRouter, Header, HTTPException, status +from pydantic import BaseModel, field_validator +from storage.api_key_store import ApiKeyStore +from storage.org_member_store import OrgMemberStore +from storage.user_store import UserStore + +from openhands.core.logger import openhands_logger as logger + +# Environment variable for the service API key +AUTOMATIONS_SERVICE_API_KEY = os.getenv('AUTOMATIONS_SERVICE_API_KEY', '').strip() + +service_router = APIRouter(prefix='/api/service', tags=['Service']) + + +class CreateUserApiKeyRequest(BaseModel): + """Request model for creating an API key on behalf of a user.""" + + name: str # Required - used to identify the key + + @field_validator('name') + @classmethod + def validate_name(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError('name is required and cannot be empty') + return v.strip() + + +class CreateUserApiKeyResponse(BaseModel): + """Response model for created API key.""" + + key: str + user_id: str + org_id: str + name: str + + +class ServiceInfoResponse(BaseModel): + """Response model for service info endpoint.""" + + service: str + authenticated: bool + + +async def validate_service_api_key( + x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'), +) -> str: + """ + Validate the service API key from the request header. + + Args: + x_service_api_key: The service API key from the X-Service-API-Key header + + Returns: + str: Service identifier for audit logging + + Raises: + HTTPException: 401 if key is missing or invalid + HTTPException: 503 if service auth is not configured + """ + if not AUTOMATIONS_SERVICE_API_KEY: + logger.warning( + 'Service authentication not configured (AUTOMATIONS_SERVICE_API_KEY not set)' + ) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail='Service authentication not configured', + ) + + if not x_service_api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='X-Service-API-Key header is required', + ) + + if x_service_api_key != AUTOMATIONS_SERVICE_API_KEY: + logger.warning('Invalid service API key attempted') + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Invalid service API key', + ) + + return 'automations-service' + + +@service_router.get('/health') +async def service_health() -> dict: + """Health check endpoint for the service API. + + This endpoint does not require authentication and can be used + to verify the service routes are accessible. + """ + return { + 'status': 'ok', + 'service_auth_configured': bool(AUTOMATIONS_SERVICE_API_KEY), + } + + +@service_router.post('/users/{user_id}/orgs/{org_id}/api-keys') +async def get_or_create_api_key_for_user( + user_id: str, + org_id: UUID, + request: CreateUserApiKeyRequest, + x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'), +) -> CreateUserApiKeyResponse: + """ + Get or create an API key for a user on behalf of the automations service. + + If a key with the given name already exists for the user/org and is not expired, + returns the existing key. Otherwise, creates a new key. + + The created/returned keys are system keys and are: + - Not visible to the user in their API keys list + - Not deletable by the user + - Never expire + + Args: + user_id: The user ID + org_id: The organization ID + request: Request body containing name (required) + x_service_api_key: Service API key header for authentication + + Returns: + CreateUserApiKeyResponse: The API key and metadata + + Raises: + HTTPException: 401 if service key is invalid + HTTPException: 404 if user not found + HTTPException: 403 if user is not a member of the specified org + """ + # Validate service API key + service_id = await validate_service_api_key(x_service_api_key) + + # Verify user exists + user = await UserStore.get_user_by_id(user_id) + if not user: + logger.warning( + 'Service attempted to create key for non-existent user', + extra={'user_id': user_id}, + ) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f'User {user_id} not found', + ) + + # Verify user is a member of the specified org + org_member = await OrgMemberStore.get_org_member(org_id, UUID(user_id)) + if not org_member: + logger.warning( + 'Service attempted to create key for user not in org', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + }, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f'User {user_id} is not a member of org {org_id}', + ) + + # Get or create the system API key + api_key_store = ApiKeyStore.get_instance() + + try: + api_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=request.name, + ) + except Exception as e: + logger.exception( + 'Failed to get or create system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'error': str(e), + }, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to get or create API key', + ) + + logger.info( + 'Service created API key for user', + extra={ + 'service_id': service_id, + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': request.name, + }, + ) + + return CreateUserApiKeyResponse( + key=api_key, + user_id=user_id, + org_id=str(org_id), + name=request.name, + ) + + +@service_router.delete('/users/{user_id}/orgs/{org_id}/api-keys/{key_name}') +async def delete_user_api_key( + user_id: str, + org_id: UUID, + key_name: str, + x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'), +) -> dict: + """ + Delete a system API key created by the service. + + This endpoint allows the automations service to clean up API keys + it previously created for users. + + Args: + user_id: The user ID + org_id: The organization ID + key_name: The name of the key to delete (without __SYSTEM__: prefix) + x_service_api_key: Service API key header for authentication + + Returns: + dict: Success message + + Raises: + HTTPException: 401 if service key is invalid + HTTPException: 404 if key not found + """ + # Validate service API key + service_id = await validate_service_api_key(x_service_api_key) + + api_key_store = ApiKeyStore.get_instance() + + # Delete the key by name (wrap with system key prefix since service creates system keys) + system_key_name = api_key_store.make_system_key_name(key_name) + success = await api_key_store.delete_api_key_by_name( + user_id=user_id, + org_id=org_id, + name=system_key_name, + allow_system=True, + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f'API key with name "{key_name}" not found for user {user_id} in org {org_id}', + ) + + logger.info( + 'Service deleted API key for user', + extra={ + 'service_id': service_id, + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': key_name, + }, + ) + + return {'message': 'API key deleted successfully'} diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index 3090b8da07..ecbb375592 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -27,6 +27,9 @@ class ApiKeyValidationResult: @dataclass class ApiKeyStore: API_KEY_PREFIX = 'sk-oh-' + # Prefix for system keys created by internal services (e.g., automations) + # Keys with this prefix are hidden from users and cannot be deleted by users + SYSTEM_KEY_NAME_PREFIX = '__SYSTEM__:' def generate_api_key(self, length: int = 32) -> str: """Generate a random API key with the sk-oh- prefix.""" @@ -34,6 +37,19 @@ class ApiKeyStore: random_part = ''.join(secrets.choice(alphabet) for _ in range(length)) return f'{self.API_KEY_PREFIX}{random_part}' + @classmethod + def is_system_key_name(cls, name: str | None) -> bool: + """Check if a key name indicates a system key.""" + return name is not None and name.startswith(cls.SYSTEM_KEY_NAME_PREFIX) + + @classmethod + def make_system_key_name(cls, name: str) -> str: + """Create a system key name with the appropriate prefix. + + Format: __SYSTEM__: + """ + return f'{cls.SYSTEM_KEY_NAME_PREFIX}{name}' + async def create_api_key( self, user_id: str, name: str | None = None, expires_at: datetime | None = None ) -> str: @@ -71,6 +87,113 @@ class ApiKeyStore: return api_key + async def get_or_create_system_api_key( + self, + user_id: str, + org_id: UUID, + name: str, + ) -> str: + """Get or create a system API key for a user on behalf of an internal service. + + If a key with the given name already exists for this user/org and is not expired, + returns the existing key. Otherwise, creates a new key (and deletes any expired one). + + System keys are: + - Not visible to users in their API keys list (filtered by name prefix) + - Not deletable by users (protected by name prefix check) + - Associated with a specific org (not the user's current org) + - Never expire (no expiration date) + + Args: + user_id: The ID of the user to create the key for + org_id: The organization ID to associate the key with + name: Required name for the key (will be prefixed with __SYSTEM__:) + + Returns: + The API key (existing or newly created) + """ + # Create system key name with prefix + system_key_name = self.make_system_key_name(name) + + async with a_session_maker() as session: + # Check if key already exists for this user/org/name + result = await session.execute( + select(ApiKey).filter( + ApiKey.user_id == user_id, + ApiKey.org_id == org_id, + ApiKey.name == system_key_name, + ) + ) + existing_key = result.scalars().first() + + if existing_key: + # Check if expired + if existing_key.expires_at: + now = datetime.now(UTC) + expires_at = existing_key.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + + if expires_at < now: + # Key is expired, delete it and create new one + logger.info( + 'System API key expired, re-issuing', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + await session.delete(existing_key) + await session.commit() + else: + # Key exists and is not expired, return it + logger.debug( + 'Returning existing system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + return existing_key.key + else: + # Key exists and has no expiration, return it + logger.debug( + 'Returning existing system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + return existing_key.key + + # Create new key (no expiration) + api_key = self.generate_api_key() + + async with a_session_maker() as session: + key_record = ApiKey( + key=api_key, + user_id=user_id, + org_id=org_id, + name=system_key_name, + expires_at=None, # System keys never expire + ) + session.add(key_record) + await session.commit() + + logger.info( + 'Created system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + + return api_key + async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None: """Validate an API key and return the associated user_id and org_id if valid.""" now = datetime.now(UTC) @@ -121,8 +244,18 @@ class ApiKeyStore: return True - async def delete_api_key_by_id(self, key_id: int) -> bool: - """Delete an API key by its ID.""" + async def delete_api_key_by_id( + self, key_id: int, allow_system: bool = False + ) -> bool: + """Delete an API key by its ID. + + Args: + key_id: The ID of the key to delete + allow_system: If False (default), system keys cannot be deleted + + Returns: + True if the key was deleted, False if not found or is a protected system key + """ async with a_session_maker() as session: result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) key_record = result.scalars().first() @@ -130,13 +263,26 @@ class ApiKeyStore: if not key_record: return False + # Protect system keys from deletion unless explicitly allowed + if self.is_system_key_name(key_record.name) and not allow_system: + logger.warning( + 'Attempted to delete system API key', + extra={'key_id': key_id, 'user_id': key_record.user_id}, + ) + return False + await session.delete(key_record) await session.commit() return True async def list_api_keys(self, user_id: str) -> list[ApiKey]: - """List all API keys for a user.""" + """List all user-visible API keys for a user. + + This excludes: + - System keys (name starts with __SYSTEM__:) - created by internal services + - MCP_API_KEY - internal MCP key + """ user = await UserStore.get_user_by_id(user_id) if user is None: raise ValueError(f'User not found: {user_id}') @@ -145,11 +291,17 @@ class ApiKeyStore: async with a_session_maker() as session: result = await session.execute( select(ApiKey).filter( - ApiKey.user_id == user_id, ApiKey.org_id == org_id + ApiKey.user_id == user_id, + ApiKey.org_id == org_id, ) ) keys = result.scalars().all() - return [key for key in keys if key.name != 'MCP_API_KEY'] + # Filter out system keys and MCP_API_KEY + return [ + key + for key in keys + if key.name != 'MCP_API_KEY' and not self.is_system_key_name(key.name) + ] async def retrieve_mcp_api_key(self, user_id: str) -> str | None: user = await UserStore.get_user_by_id(user_id) @@ -179,17 +331,44 @@ class ApiKeyStore: key_record = result.scalars().first() return key_record.key if key_record else None - async def delete_api_key_by_name(self, user_id: str, name: str) -> bool: - """Delete an API key by name for a specific user.""" + async def delete_api_key_by_name( + self, + user_id: str, + name: str, + org_id: UUID | None = None, + allow_system: bool = False, + ) -> bool: + """Delete an API key by name for a specific user. + + Args: + user_id: The ID of the user whose key to delete + name: The name of the key to delete + org_id: Optional organization ID to filter by (required for system keys) + allow_system: If False (default), system keys cannot be deleted + + Returns: + True if the key was deleted, False if not found or is a protected system key + """ async with a_session_maker() as session: - result = await session.execute( - select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name) - ) + # Build the query filters + filters = [ApiKey.user_id == user_id, ApiKey.name == name] + if org_id is not None: + filters.append(ApiKey.org_id == org_id) + + result = await session.execute(select(ApiKey).filter(*filters)) key_record = result.scalars().first() if not key_record: return False + # Protect system keys from deletion unless explicitly allowed + if self.is_system_key_name(key_record.name) and not allow_system: + logger.warning( + 'Attempted to delete system API key', + extra={'user_id': user_id, 'key_name': name}, + ) + return False + await session.delete(key_record) await session.commit() diff --git a/enterprise/tests/unit/routes/__init__.py b/enterprise/tests/unit/routes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/enterprise/tests/unit/routes/test_service.py b/enterprise/tests/unit/routes/test_service.py new file mode 100644 index 0000000000..a7156ec117 --- /dev/null +++ b/enterprise/tests/unit/routes/test_service.py @@ -0,0 +1,331 @@ +"""Unit tests for service API routes.""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException +from server.routes.service import ( + CreateUserApiKeyRequest, + delete_user_api_key, + get_or_create_api_key_for_user, + validate_service_api_key, +) + + +class TestValidateServiceApiKey: + """Test cases for validate_service_api_key.""" + + @pytest.mark.asyncio + async def test_valid_service_key(self): + """Test validation with valid service API key.""" + with patch( + 'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key' + ): + result = await validate_service_api_key('test-service-key') + assert result == 'automations-service' + + @pytest.mark.asyncio + async def test_missing_service_key(self): + """Test validation with missing service API key header.""" + with patch( + 'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key' + ): + with pytest.raises(HTTPException) as exc_info: + await validate_service_api_key(None) + assert exc_info.value.status_code == 401 + assert 'X-Service-API-Key header is required' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_invalid_service_key(self): + """Test validation with invalid service API key.""" + with patch( + 'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key' + ): + with pytest.raises(HTTPException) as exc_info: + await validate_service_api_key('wrong-key') + assert exc_info.value.status_code == 401 + assert 'Invalid service API key' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_service_auth_not_configured(self): + """Test validation when service auth is not configured.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', ''): + with pytest.raises(HTTPException) as exc_info: + await validate_service_api_key('any-key') + assert exc_info.value.status_code == 503 + assert 'Service authentication not configured' in exc_info.value.detail + + +class TestCreateUserApiKeyRequest: + """Test cases for CreateUserApiKeyRequest validation.""" + + def test_valid_request(self): + """Test valid request with all fields.""" + request = CreateUserApiKeyRequest( + name='automation', + ) + assert request.name == 'automation' + + def test_name_is_required(self): + """Test that name field is required.""" + with pytest.raises(ValueError): + CreateUserApiKeyRequest( + name='', # Empty name should fail + ) + + def test_name_is_stripped(self): + """Test that name field is stripped of whitespace.""" + request = CreateUserApiKeyRequest( + name=' automation ', + ) + assert request.name == 'automation' + + def test_whitespace_only_name_fails(self): + """Test that whitespace-only name fails validation.""" + with pytest.raises(ValueError): + CreateUserApiKeyRequest( + name=' ', + ) + + +class TestGetOrCreateApiKeyForUser: + """Test cases for get_or_create_api_key_for_user endpoint.""" + + @pytest.fixture + def valid_user_id(self): + """Return a valid user ID.""" + return '5594c7b6-f959-4b81-92e9-b09c206f5081' + + @pytest.fixture + def valid_org_id(self): + """Return a valid org ID.""" + return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + @pytest.fixture + def valid_request(self): + """Create a valid request object.""" + return CreateUserApiKeyRequest( + name='automation', + ) + + @pytest.mark.asyncio + async def test_user_not_found(self, valid_user_id, valid_org_id, valid_request): + """Test error when user doesn't exist.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + mock_get_user.return_value = None + with pytest.raises(HTTPException) as exc_info: + await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + assert exc_info.value.status_code == 404 + assert 'not found' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_user_not_in_org(self, valid_user_id, valid_org_id, valid_request): + """Test error when user is not a member of the org.""" + mock_user = MagicMock() + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + with patch( + 'server.routes.service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + ) as mock_get_member: + mock_get_user.return_value = mock_user + mock_get_member.return_value = None + with pytest.raises(HTTPException) as exc_info: + await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + assert exc_info.value.status_code == 403 + assert 'not a member of org' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_successful_key_creation( + self, valid_user_id, valid_org_id, valid_request + ): + """Test successful API key creation.""" + mock_user = MagicMock() + mock_org_member = MagicMock() + mock_api_key_store = MagicMock() + mock_api_key_store.get_or_create_system_api_key = AsyncMock( + return_value='sk-oh-test-key-12345678901234567890' + ) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + with patch( + 'server.routes.service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + ) as mock_get_member: + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_user.return_value = mock_user + mock_get_member.return_value = mock_org_member + mock_get_store.return_value = mock_api_key_store + + response = await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + + assert response.key == 'sk-oh-test-key-12345678901234567890' + assert response.user_id == valid_user_id + assert response.org_id == str(valid_org_id) + assert response.name == 'automation' + + # Verify the store was called with correct arguments + mock_api_key_store.get_or_create_system_api_key.assert_called_once_with( + user_id=valid_user_id, + org_id=valid_org_id, + name='automation', + ) + + @pytest.mark.asyncio + async def test_store_exception_handling( + self, valid_user_id, valid_org_id, valid_request + ): + """Test error handling when store raises exception.""" + mock_user = MagicMock() + mock_org_member = MagicMock() + mock_api_key_store = MagicMock() + mock_api_key_store.get_or_create_system_api_key = AsyncMock( + side_effect=Exception('Database error') + ) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + with patch( + 'server.routes.service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + ) as mock_get_member: + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_user.return_value = mock_user + mock_get_member.return_value = mock_org_member + mock_get_store.return_value = mock_api_key_store + + with pytest.raises(HTTPException) as exc_info: + await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to get or create API key' in exc_info.value.detail + + +class TestDeleteUserApiKey: + """Test cases for delete_user_api_key endpoint.""" + + @pytest.fixture + def valid_org_id(self): + """Return a valid org ID.""" + return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + @pytest.mark.asyncio + async def test_successful_delete(self, valid_org_id): + """Test successful deletion of a system API key.""" + mock_api_key_store = MagicMock() + mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:automation' + mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=True) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_store.return_value = mock_api_key_store + + response = await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='automation', + x_service_api_key='test-key', + ) + + assert response == {'message': 'API key deleted successfully'} + + # Verify the store was called with correct arguments + mock_api_key_store.make_system_key_name.assert_called_once_with('automation') + mock_api_key_store.delete_api_key_by_name.assert_called_once_with( + user_id='user-123', + org_id=valid_org_id, + name='__SYSTEM__:automation', + allow_system=True, + ) + + @pytest.mark.asyncio + async def test_delete_key_not_found(self, valid_org_id): + """Test error when key to delete is not found.""" + mock_api_key_store = MagicMock() + mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:nonexistent' + mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=False) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_store.return_value = mock_api_key_store + + with pytest.raises(HTTPException) as exc_info: + await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='nonexistent', + x_service_api_key='test-key', + ) + + assert exc_info.value.status_code == 404 + assert 'not found' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_delete_invalid_service_key(self, valid_org_id): + """Test error when service API key is invalid.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with pytest.raises(HTTPException) as exc_info: + await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='automation', + x_service_api_key='wrong-key', + ) + + assert exc_info.value.status_code == 401 + assert 'Invalid service API key' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_delete_missing_service_key(self, valid_org_id): + """Test error when service API key header is missing.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with pytest.raises(HTTPException) as exc_info: + await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='automation', + x_service_api_key=None, + ) + + assert exc_info.value.status_code == 401 + assert 'X-Service-API-Key header is required' in exc_info.value.detail diff --git a/enterprise/tests/unit/storage/test_api_key_store.py b/enterprise/tests/unit/storage/test_api_key_store.py new file mode 100644 index 0000000000..0db2d8bb96 --- /dev/null +++ b/enterprise/tests/unit/storage/test_api_key_store.py @@ -0,0 +1,314 @@ +"""Unit tests for ApiKeyStore system key functionality.""" + +import uuid +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy import select +from storage.api_key import ApiKey +from storage.api_key_store import ApiKeyStore + + +@pytest.fixture +def api_key_store(): + """Create ApiKeyStore instance.""" + return ApiKeyStore() + + +class TestApiKeyStoreSystemKeys: + """Test cases for system API key functionality.""" + + def test_is_system_key_name_with_prefix(self, api_key_store): + """Test that names with __SYSTEM__: prefix are identified as system keys.""" + assert api_key_store.is_system_key_name('__SYSTEM__:automation') is True + assert api_key_store.is_system_key_name('__SYSTEM__:test-key') is True + assert api_key_store.is_system_key_name('__SYSTEM__:') is True + + def test_is_system_key_name_without_prefix(self, api_key_store): + """Test that names without __SYSTEM__: prefix are not system keys.""" + assert api_key_store.is_system_key_name('my-key') is False + assert api_key_store.is_system_key_name('automation') is False + assert api_key_store.is_system_key_name('MCP_API_KEY') is False + assert api_key_store.is_system_key_name('') is False + + def test_is_system_key_name_none(self, api_key_store): + """Test that None is not a system key.""" + assert api_key_store.is_system_key_name(None) is False + + def test_make_system_key_name(self, api_key_store): + """Test system key name generation.""" + assert ( + api_key_store.make_system_key_name('automation') == '__SYSTEM__:automation' + ) + assert api_key_store.make_system_key_name('test-key') == '__SYSTEM__:test-key' + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_creates_new( + self, api_key_store, async_session_maker + ): + """Test creating a new system API key when none exists.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + key_name = 'automation' + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + api_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + assert api_key.startswith('sk-oh-') + assert len(api_key) == len('sk-oh-') + 32 + + # Verify the key was created in the database + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key)) + key_record = result.scalars().first() + assert key_record is not None + assert key_record.user_id == user_id + assert key_record.org_id == org_id + assert key_record.name == '__SYSTEM__:automation' + assert key_record.expires_at is None # System keys never expire + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_returns_existing( + self, api_key_store, async_session_maker + ): + """Test that existing valid system key is returned.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + key_name = 'automation' + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Create the first key + first_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + # Request again - should return the same key + second_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + assert first_key == second_key + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_different_names( + self, api_key_store, async_session_maker + ): + """Test that different names create different keys.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + key1 = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name='automation-1', + ) + + key2 = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name='automation-2', + ) + + assert key1 != key2 + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_reissues_expired( + self, api_key_store, async_session_maker + ): + """Test that expired system key is replaced with a new one.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + key_name = 'automation' + system_key_name = '__SYSTEM__:automation' + + # First, manually create an expired key + expired_time = datetime.now(UTC) - timedelta(hours=1) + async with async_session_maker() as session: + expired_key = ApiKey( + key='sk-oh-expired-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name=system_key_name, + expires_at=expired_time.replace(tzinfo=None), + ) + session.add(expired_key) + await session.commit() + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Request the key - should create a new one + new_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + assert new_key != 'sk-oh-expired-key-12345678901234567890' + assert new_key.startswith('sk-oh-') + + # Verify old key was deleted and new key exists + async with async_session_maker() as session: + result = await session.execute( + select(ApiKey).filter(ApiKey.name == system_key_name) + ) + keys = result.scalars().all() + assert len(keys) == 1 + assert keys[0].key == new_key + assert keys[0].expires_at is None + + @pytest.mark.asyncio + async def test_list_api_keys_excludes_system_keys( + self, api_key_store, async_session_maker + ): + """Test that list_api_keys excludes system keys.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a user key and a system key + async with async_session_maker() as session: + user_key = ApiKey( + key='sk-oh-user-key-123456789012345678901', + user_id=user_id, + org_id=org_id, + name='my-user-key', + ) + system_key = ApiKey( + key='sk-oh-system-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name='__SYSTEM__:automation', + ) + mcp_key = ApiKey( + key='sk-oh-mcp-key-1234567890123456789012', + user_id=user_id, + org_id=org_id, + name='MCP_API_KEY', + ) + session.add(user_key) + session.add(system_key) + session.add(mcp_key) + await session.commit() + + # Mock UserStore.get_user_by_id to return a user with the correct org + mock_user = MagicMock() + mock_user.current_org_id = org_id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + with patch( + 'storage.api_key_store.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + mock_get_user.return_value = mock_user + keys = await api_key_store.list_api_keys(user_id) + + # Should only return the user key + assert len(keys) == 1 + assert keys[0].name == 'my-user-key' + + @pytest.mark.asyncio + async def test_delete_api_key_by_id_protects_system_keys( + self, api_key_store, async_session_maker + ): + """Test that system keys cannot be deleted by users.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a system key + async with async_session_maker() as session: + system_key = ApiKey( + key='sk-oh-system-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name='__SYSTEM__:automation', + ) + session.add(system_key) + await session.commit() + key_id = system_key.id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Attempt to delete without allow_system flag + result = await api_key_store.delete_api_key_by_id( + key_id, allow_system=False + ) + + assert result is False + + # Verify the key still exists + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) + key_record = result.scalars().first() + assert key_record is not None + + @pytest.mark.asyncio + async def test_delete_api_key_by_id_allows_system_with_flag( + self, api_key_store, async_session_maker + ): + """Test that system keys can be deleted with allow_system=True.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a system key + async with async_session_maker() as session: + system_key = ApiKey( + key='sk-oh-system-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name='__SYSTEM__:automation', + ) + session.add(system_key) + await session.commit() + key_id = system_key.id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Delete with allow_system=True + result = await api_key_store.delete_api_key_by_id(key_id, allow_system=True) + + assert result is True + + # Verify the key was deleted + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) + key_record = result.scalars().first() + assert key_record is None + + @pytest.mark.asyncio + async def test_delete_api_key_by_id_allows_regular_keys( + self, api_key_store, async_session_maker + ): + """Test that regular keys can be deleted normally.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a regular key + async with async_session_maker() as session: + regular_key = ApiKey( + key='sk-oh-regular-key-1234567890123456789', + user_id=user_id, + org_id=org_id, + name='my-regular-key', + ) + session.add(regular_key) + await session.commit() + key_id = regular_key.id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Delete without allow_system flag - should work for regular keys + result = await api_key_store.delete_api_key_by_id( + key_id, allow_system=False + ) + + assert result is True + + # Verify the key was deleted + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) + key_record = result.scalars().first() + assert key_record is None From 2879e587813b8a128c155d9e2bbcece862b6153c Mon Sep 17 00:00:00 2001 From: aivong-openhands Date: Wed, 18 Mar 2026 15:00:06 -0500 Subject: [PATCH 05/28] Fix CVE-2026-30922: Update pyasn1 to 0.6.3 (#13452) Co-authored-by: OpenHands CVE Fix Bot --- enterprise/poetry.lock | 6 +++--- poetry.lock | 6 +++--- uv.lock | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/enterprise/poetry.lock b/enterprise/poetry.lock index 589be34bb0..1bb48f24c6 100644 --- a/enterprise/poetry.lock +++ b/enterprise/poetry.lock @@ -7597,14 +7597,14 @@ wrappers-encryption = ["cryptography (>=45.0.0)"] [[package]] name = "pyasn1" -version = "0.6.2" +version = "0.6.3" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf"}, - {file = "pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b"}, + {file = "pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde"}, + {file = "pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf"}, ] [[package]] diff --git a/poetry.lock b/poetry.lock index 5b0b30f61d..bccd0eea80 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7589,14 +7589,14 @@ wrappers-encryption = ["cryptography (>=45.0.0)"] [[package]] name = "pyasn1" -version = "0.6.2" +version = "0.6.3" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf"}, - {file = "pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b"}, + {file = "pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde"}, + {file = "pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf"}, ] [[package]] diff --git a/uv.lock b/uv.lock index aec35e87db..67c7965698 100644 --- a/uv.lock +++ b/uv.lock @@ -4643,11 +4643,11 @@ memory = [ [[package]] name = "pyasn1" -version = "0.6.2" +version = "0.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = "2026-01-16T18:04:18.534Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5f/6583902b6f79b399c9c40674ac384fd9cd77805f9e6205075f828ef11fb2/pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf", size = 148685, upload-time = "2026-03-17T01:06:53.382Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, + { url = "https://files.pythonhosted.org/packages/5d/a0/7d793dce3fa811fe047d6ae2431c672364b462850c6235ae306c0efd025f/pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde", size = 83997, upload-time = "2026-03-17T01:06:52.036Z" }, ] [[package]] From abd1f9948f888c73440b94334e57771aabee5556 Mon Sep 17 00:00:00 2001 From: HeyItsChloe <54480367+HeyItsChloe@users.noreply.github.com> Date: Wed, 18 Mar 2026 13:46:00 -0700 Subject: [PATCH 06/28] fix: return empty skills list instead of 404 for stopped sandboxes (#13429) Co-authored-by: openhands --- .../app_conversation_router.py | 26 ++++++++--- .../test_app_conversation_hooks_endpoint.py | 43 ++++++++++++++++++- .../test_app_conversation_skills_endpoint.py | 21 +++------ 3 files changed, 69 insertions(+), 21 deletions(-) diff --git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index d3ad901db7..582de93761 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -115,7 +115,7 @@ async def _get_agent_server_context( app_conversation_service: AppConversationService, sandbox_service: SandboxService, sandbox_spec_service: SandboxSpecService, -) -> AgentServerContext | JSONResponse: +) -> AgentServerContext | JSONResponse | None: """Get the agent server context for a conversation. This helper retrieves all necessary information to communicate with the @@ -129,7 +129,8 @@ async def _get_agent_server_context( sandbox_spec_service: Service for sandbox spec operations Returns: - AgentServerContext if successful, or JSONResponse with error details. + AgentServerContext if successful, JSONResponse(404) if conversation + not found, or None if sandbox is not running (e.g. closed conversation). """ # Get the conversation info conversation = await app_conversation_service.get_app_conversation(conversation_id) @@ -141,12 +142,19 @@ async def _get_agent_server_context( # Get the sandbox info sandbox = await sandbox_service.get_sandbox(conversation.sandbox_id) - if not sandbox or sandbox.status != SandboxStatus.RUNNING: + if not sandbox: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content={ - 'error': f'Sandbox not found or not running for conversation {conversation_id}' - }, + content={'error': f'Sandbox not found for conversation {conversation_id}'}, + ) + # Return None for paused sandboxes (closed conversation) + if sandbox.status == SandboxStatus.PAUSED: + return None + # Return 404 for other non-running states (STARTING, ERROR, MISSING) + if sandbox.status != SandboxStatus.RUNNING: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': f'Sandbox not ready for conversation {conversation_id}'}, ) # Get the sandbox spec to find the working directory @@ -587,6 +595,7 @@ async def get_conversation_skills( Returns: JSONResponse: A JSON response containing the list of skills. + Returns an empty list if the sandbox is not running. """ try: # Get agent server context (conversation, sandbox, sandbox_spec, agent_server_url) @@ -598,6 +607,8 @@ async def get_conversation_skills( ) if isinstance(ctx, JSONResponse): return ctx + if ctx is None: + return JSONResponse(status_code=status.HTTP_200_OK, content={'skills': []}) # Load skills from all sources logger.info(f'Loading skills for conversation {conversation_id}') @@ -685,6 +696,7 @@ async def get_conversation_hooks( Returns: JSONResponse: A JSON response containing the list of hook event types. + Returns an empty list if the sandbox is not running. """ try: # Get agent server context (conversation, sandbox, sandbox_spec, agent_server_url) @@ -696,6 +708,8 @@ async def get_conversation_hooks( ) if isinstance(ctx, JSONResponse): return ctx + if ctx is None: + return JSONResponse(status_code=status.HTTP_200_OK, content={'hooks': []}) from openhands.app_server.app_conversation.hook_loader import ( fetch_hooks_from_agent_server, diff --git a/tests/unit/app_server/test_app_conversation_hooks_endpoint.py b/tests/unit/app_server/test_app_conversation_hooks_endpoint.py index ba67c4b488..ffc8c54d37 100644 --- a/tests/unit/app_server/test_app_conversation_hooks_endpoint.py +++ b/tests/unit/app_server/test_app_conversation_hooks_endpoint.py @@ -263,7 +263,7 @@ class TestGetConversationHooks: assert response.status_code == status.HTTP_404_NOT_FOUND - async def test_get_hooks_returns_404_when_sandbox_not_running(self): + async def test_get_hooks_returns_404_when_sandbox_not_found(self): conversation_id = uuid4() sandbox_id = str(uuid4()) @@ -291,3 +291,44 @@ class TestGetConversationHooks: ) assert response.status_code == status.HTTP_404_NOT_FOUND + + async def test_get_hooks_returns_empty_list_when_sandbox_paused(self): + conversation_id = uuid4() + sandbox_id = str(uuid4()) + + mock_conversation = AppConversation( + id=conversation_id, + created_by_user_id='test-user', + sandbox_id=sandbox_id, + sandbox_status=SandboxStatus.PAUSED, + ) + + mock_sandbox = SandboxInfo( + id=sandbox_id, + created_by_user_id='test-user', + status=SandboxStatus.PAUSED, + sandbox_spec_id=str(uuid4()), + session_api_key='test-api-key', + ) + + mock_app_conversation_service = MagicMock() + mock_app_conversation_service.get_app_conversation = AsyncMock( + return_value=mock_conversation + ) + + mock_sandbox_service = MagicMock() + mock_sandbox_service.get_sandbox = AsyncMock(return_value=mock_sandbox) + + response = await get_conversation_hooks( + conversation_id=conversation_id, + app_conversation_service=mock_app_conversation_service, + sandbox_service=mock_sandbox_service, + sandbox_spec_service=MagicMock(), + httpx_client=AsyncMock(spec=httpx.AsyncClient), + ) + + assert response.status_code == status.HTTP_200_OK + import json + + data = json.loads(response.body.decode('utf-8')) + assert data == {'hooks': []} diff --git a/tests/unit/app_server/test_app_conversation_skills_endpoint.py b/tests/unit/app_server/test_app_conversation_skills_endpoint.py index ed7fedd43d..6b601cf9db 100644 --- a/tests/unit/app_server/test_app_conversation_skills_endpoint.py +++ b/tests/unit/app_server/test_app_conversation_skills_endpoint.py @@ -203,7 +203,7 @@ class TestGetConversationSkills: Arrange: Setup conversation but no sandbox Act: Call get_conversation_skills endpoint - Assert: Response is 404 with sandbox error message + Assert: Response is 404 """ # Arrange conversation_id = uuid4() @@ -237,19 +237,13 @@ class TestGetConversationSkills: # Assert assert response.status_code == status.HTTP_404_NOT_FOUND - content = response.body.decode('utf-8') - import json - data = json.loads(content) - assert 'error' in data - assert 'Sandbox not found' in data['error'] + async def test_get_skills_returns_empty_list_when_sandbox_paused(self): + """Test endpoint returns empty skills when sandbox is PAUSED (closed conversation). - async def test_get_skills_returns_404_when_sandbox_not_running(self): - """Test endpoint returns 404 when sandbox is not in RUNNING state. - - Arrange: Setup conversation with stopped sandbox + Arrange: Setup conversation with paused sandbox Act: Call get_conversation_skills endpoint - Assert: Response is 404 with sandbox not running message + Assert: Response is 200 with empty skills list """ # Arrange conversation_id = uuid4() @@ -290,13 +284,12 @@ class TestGetConversationSkills: ) # Assert - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_200_OK content = response.body.decode('utf-8') import json data = json.loads(content) - assert 'error' in data - assert 'not running' in data['error'] + assert data == {'skills': []} async def test_get_skills_handles_task_trigger_skills(self): """Test endpoint correctly handles skills with TaskTrigger. From 7edebcbc0c09fd0edc6c666b327919f267e12197 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Wed, 18 Mar 2026 16:49:32 -0600 Subject: [PATCH 07/28] fix: use atomic write in LocalFileStore to prevent race conditions (#13480) Co-authored-by: openhands Co-authored-by: OpenHands Bot --- openhands/storage/local.py | 17 ++++++++-- tests/unit/storage/test_storage.py | 52 ++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/openhands/storage/local.py b/openhands/storage/local.py index fcb766c0ef..e646f3137e 100644 --- a/openhands/storage/local.py +++ b/openhands/storage/local.py @@ -1,5 +1,6 @@ import os import shutil +import threading from openhands.core.logger import openhands_logger as logger from openhands.storage.files import FileStore @@ -23,8 +24,20 @@ class LocalFileStore(FileStore): full_path = self.get_full_path(path) os.makedirs(os.path.dirname(full_path), exist_ok=True) mode = 'w' if isinstance(contents, str) else 'wb' - with open(full_path, mode) as f: - f.write(contents) + + # Use atomic write: write to temp file, then rename + # This prevents race conditions where concurrent writes could corrupt the file + temp_path = f'{full_path}.tmp.{os.getpid()}.{threading.get_ident()}' + try: + with open(temp_path, mode) as f: + f.write(contents) + f.flush() + os.fsync(f.fileno()) + os.replace(temp_path, full_path) + except Exception: + if os.path.exists(temp_path): + os.remove(temp_path) + raise def read(self, path: str) -> str: full_path = self.get_full_path(path) diff --git a/tests/unit/storage/test_storage.py b/tests/unit/storage/test_storage.py index a78c12df98..5d2508705f 100644 --- a/tests/unit/storage/test_storage.py +++ b/tests/unit/storage/test_storage.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging import shutil import tempfile +import threading from abc import ABC from dataclasses import dataclass, field from io import BytesIO, StringIO @@ -122,6 +123,57 @@ class TestLocalFileStore(TestCase, _StorageTest): f'Failed to remove temporary directory {self.temp_dir}: {e}' ) + def test_concurrent_writes_no_corruption(self): + """Test that concurrent writes don't corrupt file content. + + This test verifies the atomic write fix by having 9 threads write + progressively shorter strings to the same file simultaneously. + Without atomic writes, a shorter write following a longer write + could result in corrupted content (e.g., "123" followed by garbage + from the previous longer write). + + The final content must be exactly one of the valid strings written, + with no trailing garbage from other writes. + """ + filename = 'concurrent_test.txt' + # Strings from longest to shortest: "123456789", "12345678", ..., "1" + valid_contents = ['123456789'[:i] for i in range(9, 0, -1)] + errors: list[Exception] = [] + barrier = threading.Barrier(len(valid_contents)) + + def write_content(content: str): + try: + # Wait for all threads to be ready before writing + barrier.wait() + self.store.write(filename, content) + except Exception as e: + errors.append(e) + + # Start all threads + threads = [ + threading.Thread(target=write_content, args=(content,)) + for content in valid_contents + ] + for t in threads: + t.start() + for t in threads: + t.join() + + # Check for errors during writes + self.assertEqual( + errors, [], f'Errors occurred during concurrent writes: {errors}' + ) + + # Read final content and verify it's one of the valid strings + final_content = self.store.read(filename) + self.assertIn( + final_content, + valid_contents, + f"File content '{final_content}' is not one of the valid strings. " + f'Length: {len(final_content)}. This indicates file corruption from ' + f'concurrent writes (e.g., shorter write did not fully replace longer write).', + ) + class TestInMemoryFileStore(TestCase, _StorageTest): def setUp(self): From dcb2e21b87b87d9a52c23c8a44f3d907f0648f60 Mon Sep 17 00:00:00 2001 From: Saurya Velagapudi Date: Wed, 18 Mar 2026 17:07:19 -0700 Subject: [PATCH 08/28] feat: Auto-forward LLM_* env vars to agent-server and fix host network config (#13192) Co-authored-by: openhands --- .../sandbox/docker_sandbox_service.py | 21 +- .../sandbox/sandbox_spec_service.py | 53 ++- .../test_agent_server_env_override.py | 422 +++++++++++++++++- .../app_server/test_docker_sandbox_service.py | 53 +++ 4 files changed, 528 insertions(+), 21 deletions(-) diff --git a/openhands/app_server/sandbox/docker_sandbox_service.py b/openhands/app_server/sandbox/docker_sandbox_service.py index 6c692a680a..f5a302fa73 100644 --- a/openhands/app_server/sandbox/docker_sandbox_service.py +++ b/openhands/app_server/sandbox/docker_sandbox_service.py @@ -43,6 +43,16 @@ _logger = logging.getLogger(__name__) STARTUP_GRACE_SECONDS = 15 +def _get_use_host_network_default() -> bool: + """Get the default value for use_host_network from environment variables. + + This function is called at runtime (not at class definition time) to ensure + that environment variable changes are picked up correctly. + """ + value = os.getenv('AGENT_SERVER_USE_HOST_NETWORK', '') + return value.lower() in ('true', '1', 'yes') + + class VolumeMount(BaseModel): """Mounted volume within the container.""" @@ -591,18 +601,13 @@ class DockerSandboxServiceInjector(SandboxServiceInjector): ), ) use_host_network: bool = Field( - default=os.getenv('SANDBOX_USE_HOST_NETWORK', '').lower() - in ( - 'true', - '1', - 'yes', - ), + default_factory=_get_use_host_network_default, description=( - 'Whether to use host networking mode for sandbox containers. ' + 'Whether to use host networking mode for agent-server containers. ' 'When enabled, containers share the host network namespace, ' 'making all container ports directly accessible on the host. ' 'This is useful for reverse proxy setups where dynamic port mapping ' - 'is problematic. Configure via OH_SANDBOX_USE_HOST_NETWORK environment variable.' + 'is problematic. Configure via AGENT_SERVER_USE_HOST_NETWORK environment variable.' ), ) diff --git a/openhands/app_server/sandbox/sandbox_spec_service.py b/openhands/app_server/sandbox/sandbox_spec_service.py index 4034af1f5b..5025bdee6b 100644 --- a/openhands/app_server/sandbox/sandbox_spec_service.py +++ b/openhands/app_server/sandbox/sandbox_spec_service.py @@ -69,27 +69,58 @@ def get_agent_server_image() -> str: return AGENT_SERVER_IMAGE +# Prefixes for environment variables that should be auto-forwarded to agent-server +# These are typically configuration variables that affect the agent's behavior +AUTO_FORWARD_PREFIXES = ('LLM_',) + + def get_agent_server_env() -> dict[str, str]: """Get environment variables to be injected into agent server sandbox environments. - This function reads environment variable overrides from the OH_AGENT_SERVER_ENV - environment variable, which should contain a JSON string mapping variable names - to their values. + This function combines two sources of environment variables: + + 1. **Auto-forwarded variables**: Environment variables with certain prefixes + (e.g., LLM_*) are automatically forwarded to the agent-server container. + This ensures that LLM configuration like timeouts and retry settings + work correctly in the two-container V1 architecture. + + 2. **Explicit overrides via OH_AGENT_SERVER_ENV**: A JSON string that allows + setting arbitrary environment variables in the agent-server container. + Values set here take precedence over auto-forwarded variables. + + Auto-forwarded prefixes: + - LLM_* : LLM configuration (timeout, retries, model settings, etc.) Usage: - Set OH_AGENT_SERVER_ENV to a JSON string: - OH_AGENT_SERVER_ENV='{"DEBUG": "true", "LOG_LEVEL": "info", "CUSTOM_VAR": "value"}' + # Auto-forwarding (no action needed): + export LLM_TIMEOUT=3600 + export LLM_NUM_RETRIES=10 + # These will automatically be available in the agent-server - This will inject the following environment variables into all sandbox environments: - - DEBUG=true - - LOG_LEVEL=info - - CUSTOM_VAR=value + # Explicit override via JSON: + OH_AGENT_SERVER_ENV='{"DEBUG": "true", "CUSTOM_VAR": "value"}' + + # Override an auto-forwarded variable: + export LLM_TIMEOUT=3600 # Would be auto-forwarded as 3600 + OH_AGENT_SERVER_ENV='{"LLM_TIMEOUT": "7200"}' # Overrides to 7200 Returns: dict[str, str]: Dictionary of environment variable names to values. - Returns empty dict if OH_AGENT_SERVER_ENV is not set or invalid. + Returns empty dict if no variables are found. Raises: JSONDecodeError: If OH_AGENT_SERVER_ENV contains invalid JSON. """ - return env_parser.from_env(dict[str, str], 'OH_AGENT_SERVER_ENV') + result: dict[str, str] = {} + + # Step 1: Auto-forward environment variables with recognized prefixes + for key, value in os.environ.items(): + if any(key.startswith(prefix) for prefix in AUTO_FORWARD_PREFIXES): + result[key] = value + + # Step 2: Apply explicit overrides from OH_AGENT_SERVER_ENV + # These take precedence over auto-forwarded variables + explicit_env = env_parser.from_env(dict[str, str], 'OH_AGENT_SERVER_ENV') + result.update(explicit_env) + + return result diff --git a/tests/unit/app_server/test_agent_server_env_override.py b/tests/unit/app_server/test_agent_server_env_override.py index 61d851590e..5c1c1ea208 100644 --- a/tests/unit/app_server/test_agent_server_env_override.py +++ b/tests/unit/app_server/test_agent_server_env_override.py @@ -2,10 +2,11 @@ This module tests the environment variable override functionality that allows users to inject custom environment variables into sandbox environments via -OH_AGENT_SERVER_ENV_* environment variables. +OH_AGENT_SERVER_ENV environment variable and auto-forwarding of LLM_* variables. The functionality includes: -- Parsing OH_AGENT_SERVER_ENV_* environment variables +- Auto-forwarding of LLM_* environment variables to agent-server containers +- Explicit overrides via OH_AGENT_SERVER_ENV JSON - Merging them into sandbox specifications - Integration across different sandbox types (Docker, Process, Remote) """ @@ -25,6 +26,7 @@ from openhands.app_server.sandbox.remote_sandbox_spec_service import ( get_default_sandbox_specs as get_default_remote_sandbox_specs, ) from openhands.app_server.sandbox.sandbox_spec_service import ( + AUTO_FORWARD_PREFIXES, get_agent_server_env, ) @@ -185,6 +187,114 @@ class TestGetAgentServerEnv: assert result == expected +class TestLLMAutoForwarding: + """Test cases for automatic forwarding of LLM_* environment variables.""" + + def test_auto_forward_prefixes_contains_llm(self): + """Test that LLM_ is in the auto-forward prefixes.""" + assert 'LLM_' in AUTO_FORWARD_PREFIXES + + def test_llm_timeout_auto_forwarded(self): + """Test that LLM_TIMEOUT is automatically forwarded.""" + env_vars = { + 'LLM_TIMEOUT': '3600', + 'OTHER_VAR': 'should_not_be_included', + } + + with patch.dict(os.environ, env_vars, clear=True): + result = get_agent_server_env() + assert 'LLM_TIMEOUT' in result + assert result['LLM_TIMEOUT'] == '3600' + assert 'OTHER_VAR' not in result + + def test_llm_num_retries_auto_forwarded(self): + """Test that LLM_NUM_RETRIES is automatically forwarded.""" + env_vars = { + 'LLM_NUM_RETRIES': '10', + } + + with patch.dict(os.environ, env_vars, clear=True): + result = get_agent_server_env() + assert 'LLM_NUM_RETRIES' in result + assert result['LLM_NUM_RETRIES'] == '10' + + def test_multiple_llm_vars_auto_forwarded(self): + """Test that multiple LLM_* variables are automatically forwarded.""" + env_vars = { + 'LLM_TIMEOUT': '3600', + 'LLM_NUM_RETRIES': '10', + 'LLM_MODEL': 'gpt-4', + 'LLM_BASE_URL': 'https://api.example.com', + 'LLM_API_KEY': 'secret-key', + 'NON_LLM_VAR': 'should_not_be_included', + } + + with patch.dict(os.environ, env_vars, clear=True): + result = get_agent_server_env() + assert result['LLM_TIMEOUT'] == '3600' + assert result['LLM_NUM_RETRIES'] == '10' + assert result['LLM_MODEL'] == 'gpt-4' + assert result['LLM_BASE_URL'] == 'https://api.example.com' + assert result['LLM_API_KEY'] == 'secret-key' + assert 'NON_LLM_VAR' not in result + + def test_explicit_override_takes_precedence(self): + """Test that OH_AGENT_SERVER_ENV overrides auto-forwarded variables.""" + env_vars = { + 'LLM_TIMEOUT': '3600', # Auto-forwarded value + 'OH_AGENT_SERVER_ENV': '{"LLM_TIMEOUT": "7200"}', # Explicit override + } + + with patch.dict(os.environ, env_vars, clear=True): + result = get_agent_server_env() + # Explicit override should win + assert result['LLM_TIMEOUT'] == '7200' + + def test_combined_auto_forward_and_explicit(self): + """Test combining auto-forwarded and explicit variables.""" + env_vars = { + 'LLM_TIMEOUT': '3600', # Auto-forwarded + 'LLM_NUM_RETRIES': '10', # Auto-forwarded + 'OH_AGENT_SERVER_ENV': '{"DEBUG": "true", "CUSTOM_VAR": "value"}', # Explicit + } + + with patch.dict(os.environ, env_vars, clear=True): + result = get_agent_server_env() + # Auto-forwarded + assert result['LLM_TIMEOUT'] == '3600' + assert result['LLM_NUM_RETRIES'] == '10' + # Explicit + assert result['DEBUG'] == 'true' + assert result['CUSTOM_VAR'] == 'value' + + def test_no_llm_vars_returns_empty_without_explicit(self): + """Test that no LLM_* vars and no explicit env returns empty dict.""" + env_vars = { + 'SOME_OTHER_VAR': 'value', + 'ANOTHER_VAR': 'another_value', + } + + with patch.dict(os.environ, env_vars, clear=True): + result = get_agent_server_env() + assert result == {} + + def test_llm_prefix_is_case_sensitive(self): + """Test that LLM_ prefix matching is case-sensitive.""" + env_vars = { + 'LLM_TIMEOUT': '3600', # Should be included + 'llm_timeout': 'lowercase', # Should NOT be included (wrong case) + 'Llm_Timeout': 'mixed', # Should NOT be included (wrong case) + } + + with patch.dict(os.environ, env_vars, clear=True): + result = get_agent_server_env() + assert 'LLM_TIMEOUT' in result + assert result['LLM_TIMEOUT'] == '3600' + # Lowercase variants should not be included + assert 'llm_timeout' not in result + assert 'Llm_Timeout' not in result + + class TestDockerSandboxSpecEnvironmentOverride: """Test environment variable override integration in Docker sandbox specs.""" @@ -476,3 +586,311 @@ class TestEnvironmentOverrideIntegration: # Should not have the old variables assert 'VAR1' not in spec_2.initial_env assert 'VAR2' not in spec_2.initial_env + + +class TestDockerSandboxServiceEnvIntegration: + """Integration tests for environment variable propagation to Docker sandbox containers. + + These tests verify that environment variables are correctly propagated through + the entire flow from the app-server environment to the agent-server container. + """ + + def test_llm_env_vars_propagated_to_container_run(self): + """Test that LLM_* env vars are included in docker container.run() environment argument.""" + from unittest.mock import patch + + # Set up environment with LLM_* variables + env_vars = { + 'LLM_TIMEOUT': '3600', + 'LLM_NUM_RETRIES': '10', + 'LLM_MODEL': 'gpt-4', + 'OTHER_VAR': 'should_not_be_forwarded', + } + + with patch.dict(os.environ, env_vars, clear=True): + # Create a sandbox spec using the actual factory to get LLM_* vars + specs = get_default_docker_sandbox_specs() + sandbox_spec = specs[0] + + # Verify the sandbox spec has the LLM_* variables + assert 'LLM_TIMEOUT' in sandbox_spec.initial_env + assert sandbox_spec.initial_env['LLM_TIMEOUT'] == '3600' + assert 'LLM_NUM_RETRIES' in sandbox_spec.initial_env + assert sandbox_spec.initial_env['LLM_NUM_RETRIES'] == '10' + assert 'LLM_MODEL' in sandbox_spec.initial_env + assert sandbox_spec.initial_env['LLM_MODEL'] == 'gpt-4' + # Non-LLM_* variables should not be included + assert 'OTHER_VAR' not in sandbox_spec.initial_env + + def test_explicit_oh_agent_server_env_overrides_llm_vars(self): + """Test that OH_AGENT_SERVER_ENV can override auto-forwarded LLM_* variables.""" + env_vars = { + 'LLM_TIMEOUT': '3600', # Auto-forwarded value + 'OH_AGENT_SERVER_ENV': '{"LLM_TIMEOUT": "7200"}', # Override value + } + + with patch.dict(os.environ, env_vars, clear=True): + specs = get_default_docker_sandbox_specs() + sandbox_spec = specs[0] + + # OH_AGENT_SERVER_ENV should take precedence + assert sandbox_spec.initial_env['LLM_TIMEOUT'] == '7200' + + def test_multiple_llm_vars_combined_with_explicit_overrides(self): + """Test complex scenario with multiple LLM_* vars and explicit overrides.""" + env_vars = { + 'LLM_TIMEOUT': '3600', + 'LLM_NUM_RETRIES': '10', + 'LLM_MODEL': 'gpt-4', + 'LLM_TEMPERATURE': '0.7', + 'OH_AGENT_SERVER_ENV': '{"LLM_MODEL": "gpt-3.5-turbo", "CUSTOM_VAR": "custom_value"}', + } + + with patch.dict(os.environ, env_vars, clear=True): + specs = get_default_docker_sandbox_specs() + sandbox_spec = specs[0] + + # Auto-forwarded LLM_* vars that weren't overridden + assert sandbox_spec.initial_env['LLM_TIMEOUT'] == '3600' + assert sandbox_spec.initial_env['LLM_NUM_RETRIES'] == '10' + assert sandbox_spec.initial_env['LLM_TEMPERATURE'] == '0.7' + + # LLM_MODEL should be overridden by OH_AGENT_SERVER_ENV + assert sandbox_spec.initial_env['LLM_MODEL'] == 'gpt-3.5-turbo' + + # Custom variable from OH_AGENT_SERVER_ENV + assert sandbox_spec.initial_env['CUSTOM_VAR'] == 'custom_value' + + def test_sandbox_spec_env_passed_to_docker_container_run(self): + """Test that sandbox spec's initial_env is passed to docker container run.""" + from unittest.mock import AsyncMock, MagicMock, patch + + import httpx + + from openhands.app_server.sandbox.docker_sandbox_service import ( + DockerSandboxService, + ExposedPort, + ) + + # Create mock docker client + mock_docker_client = MagicMock() + mock_container = MagicMock() + mock_container.name = 'oh-test-abc123' + mock_container.image.tags = ['test-image:latest'] + mock_container.attrs = { + 'Created': '2024-01-01T00:00:00Z', + 'Config': { + 'Env': ['SESSION_API_KEY=test-key'], + 'WorkingDir': '/workspace', + }, + 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': '32768'}]}}, + 'HostConfig': {'NetworkMode': 'bridge'}, + } + mock_container.status = 'running' + mock_docker_client.containers.run.return_value = mock_container + mock_docker_client.containers.list.return_value = [] + + # Create mock sandbox spec service + mock_spec_service = MagicMock() + + # Create sandbox spec with LLM_* environment variables + env_vars = { + 'LLM_TIMEOUT': '3600', + 'LLM_NUM_RETRIES': '10', + } + + with patch.dict(os.environ, env_vars, clear=True): + specs = get_default_docker_sandbox_specs() + sandbox_spec = specs[0] + + mock_spec_service.get_default_sandbox_spec = AsyncMock( + return_value=sandbox_spec + ) + + # Create service + service = DockerSandboxService( + sandbox_spec_service=mock_spec_service, + container_name_prefix='oh-test-', + host_port=3000, + container_url_pattern='http://localhost:{port}', + mounts=[], + exposed_ports=[ + ExposedPort( + name='AGENT_SERVER', + description='Agent server', + container_port=8000, + ) + ], + health_check_path='/health', + httpx_client=MagicMock(spec=httpx.AsyncClient), + max_num_sandboxes=5, + docker_client=mock_docker_client, + ) + + # Start sandbox + import asyncio + + asyncio.get_event_loop().run_until_complete(service.start_sandbox()) + + # Verify docker was called with environment variables including LLM_* + call_kwargs = mock_docker_client.containers.run.call_args[1] + container_env = call_kwargs['environment'] + + # LLM_* variables should be in the container environment + assert 'LLM_TIMEOUT' in container_env + assert container_env['LLM_TIMEOUT'] == '3600' + assert 'LLM_NUM_RETRIES' in container_env + assert container_env['LLM_NUM_RETRIES'] == '10' + + # Default variables should also be present + assert 'OPENVSCODE_SERVER_ROOT' in container_env + assert 'LOG_JSON' in container_env + + def test_host_network_mode_with_env_var(self): + """Test that AGENT_SERVER_USE_HOST_NETWORK affects container network mode.""" + from unittest.mock import AsyncMock, MagicMock, patch + + import httpx + + from openhands.app_server.sandbox.docker_sandbox_service import ( + DockerSandboxService, + ExposedPort, + _get_use_host_network_default, + ) + + # Test with environment variable set + with patch.dict( + os.environ, {'AGENT_SERVER_USE_HOST_NETWORK': 'true'}, clear=True + ): + assert _get_use_host_network_default() is True + + # Create mock docker client + mock_docker_client = MagicMock() + mock_container = MagicMock() + mock_container.name = 'oh-test-abc123' + mock_container.image.tags = ['test-image:latest'] + mock_container.attrs = { + 'Created': '2024-01-01T00:00:00Z', + 'Config': { + 'Env': ['SESSION_API_KEY=test-key'], + 'WorkingDir': '/workspace', + }, + 'NetworkSettings': {'Ports': {}}, + 'HostConfig': {'NetworkMode': 'host'}, + } + mock_container.status = 'running' + mock_docker_client.containers.run.return_value = mock_container + mock_docker_client.containers.list.return_value = [] + + # Create mock sandbox spec service + mock_spec_service = MagicMock() + specs = get_default_docker_sandbox_specs() + mock_spec_service.get_default_sandbox_spec = AsyncMock( + return_value=specs[0] + ) + + # Create service with host network enabled + service = DockerSandboxService( + sandbox_spec_service=mock_spec_service, + container_name_prefix='oh-test-', + host_port=3000, + container_url_pattern='http://localhost:{port}', + mounts=[], + exposed_ports=[ + ExposedPort( + name='AGENT_SERVER', + description='Agent server', + container_port=8000, + ) + ], + health_check_path='/health', + httpx_client=MagicMock(spec=httpx.AsyncClient), + max_num_sandboxes=5, + docker_client=mock_docker_client, + use_host_network=True, + ) + + # Start sandbox + import asyncio + + asyncio.get_event_loop().run_until_complete(service.start_sandbox()) + + # Verify docker was called with host network mode + call_kwargs = mock_docker_client.containers.run.call_args[1] + assert call_kwargs['network_mode'] == 'host' + # Port mappings should be None in host network mode + assert call_kwargs['ports'] is None + + def test_bridge_network_mode_without_env_var(self): + """Test that default (bridge) network mode is used when env var is not set.""" + from unittest.mock import AsyncMock, MagicMock, patch + + import httpx + + from openhands.app_server.sandbox.docker_sandbox_service import ( + DockerSandboxService, + ExposedPort, + _get_use_host_network_default, + ) + + # Test without environment variable + with patch.dict(os.environ, {}, clear=True): + assert _get_use_host_network_default() is False + + # Create mock docker client + mock_docker_client = MagicMock() + mock_container = MagicMock() + mock_container.name = 'oh-test-abc123' + mock_container.image.tags = ['test-image:latest'] + mock_container.attrs = { + 'Created': '2024-01-01T00:00:00Z', + 'Config': { + 'Env': ['SESSION_API_KEY=test-key'], + 'WorkingDir': '/workspace', + }, + 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': '32768'}]}}, + 'HostConfig': {'NetworkMode': 'bridge'}, + } + mock_container.status = 'running' + mock_docker_client.containers.run.return_value = mock_container + mock_docker_client.containers.list.return_value = [] + + # Create mock sandbox spec service + mock_spec_service = MagicMock() + specs = get_default_docker_sandbox_specs() + mock_spec_service.get_default_sandbox_spec = AsyncMock( + return_value=specs[0] + ) + + # Create service with bridge network (default) + service = DockerSandboxService( + sandbox_spec_service=mock_spec_service, + container_name_prefix='oh-test-', + host_port=3000, + container_url_pattern='http://localhost:{port}', + mounts=[], + exposed_ports=[ + ExposedPort( + name='AGENT_SERVER', + description='Agent server', + container_port=8000, + ) + ], + health_check_path='/health', + httpx_client=MagicMock(spec=httpx.AsyncClient), + max_num_sandboxes=5, + docker_client=mock_docker_client, + use_host_network=False, + ) + + # Start sandbox + import asyncio + + asyncio.get_event_loop().run_until_complete(service.start_sandbox()) + + # Verify docker was called with bridge network mode (network_mode=None) + call_kwargs = mock_docker_client.containers.run.call_args[1] + assert call_kwargs['network_mode'] is None + # Port mappings should be present in bridge mode + assert call_kwargs['ports'] is not None + assert 8000 in call_kwargs['ports'] diff --git a/tests/unit/app_server/test_docker_sandbox_service.py b/tests/unit/app_server/test_docker_sandbox_service.py index 23a6d51b04..f6ae716eef 100644 --- a/tests/unit/app_server/test_docker_sandbox_service.py +++ b/tests/unit/app_server/test_docker_sandbox_service.py @@ -1254,6 +1254,59 @@ class TestDockerSandboxServiceInjector: injector = DockerSandboxServiceInjector(use_host_network=True) assert injector.use_host_network is True + def test_use_host_network_from_agent_server_env_var(self): + """Test that AGENT_SERVER_USE_HOST_NETWORK env var enables host network mode.""" + import os + from unittest.mock import patch + + from openhands.app_server.sandbox.docker_sandbox_service import ( + DockerSandboxServiceInjector, + ) + + env_vars = { + 'AGENT_SERVER_USE_HOST_NETWORK': 'true', + } + + with patch.dict(os.environ, env_vars, clear=True): + injector = DockerSandboxServiceInjector() + assert injector.use_host_network is True + + def test_use_host_network_env_var_accepts_various_true_values(self): + """Test that use_host_network accepts various truthy values.""" + import os + from unittest.mock import patch + + from openhands.app_server.sandbox.docker_sandbox_service import ( + DockerSandboxServiceInjector, + ) + + for true_value in ['true', 'TRUE', 'True', '1', 'yes', 'YES', 'Yes']: + env_vars = {'AGENT_SERVER_USE_HOST_NETWORK': true_value} + with patch.dict(os.environ, env_vars, clear=True): + injector = DockerSandboxServiceInjector() + assert injector.use_host_network is True, ( + f'Failed for value: {true_value}' + ) + + def test_use_host_network_env_var_defaults_to_false(self): + """Test that unset or empty env var defaults to False.""" + import os + from unittest.mock import patch + + from openhands.app_server.sandbox.docker_sandbox_service import ( + DockerSandboxServiceInjector, + ) + + # Empty environment + with patch.dict(os.environ, {}, clear=True): + injector = DockerSandboxServiceInjector() + assert injector.use_host_network is False + + # Empty string + with patch.dict(os.environ, {'AGENT_SERVER_USE_HOST_NETWORK': ''}, clear=True): + injector = DockerSandboxServiceInjector() + assert injector.use_host_network is False + class TestDockerSandboxServiceInjectorFromEnv: """Test cases for DockerSandboxServiceInjector environment variable configuration.""" From a96760eea70fa99b1bcb73bd6a35171c596366d1 Mon Sep 17 00:00:00 2001 From: Saurya Velagapudi Date: Wed, 18 Mar 2026 17:16:43 -0700 Subject: [PATCH 09/28] fix: ensure LiteLLM user exists before generating API keys (#12667) Co-authored-by: openhands --- enterprise/storage/lite_llm_manager.py | 84 +++++- .../tests/unit/test_lite_llm_manager.py | 284 ++++++++++++++++-- 2 files changed, 331 insertions(+), 37 deletions(-) diff --git a/enterprise/storage/lite_llm_manager.py b/enterprise/storage/lite_llm_manager.py index 725b8147a3..b515b7a7d9 100644 --- a/enterprise/storage/lite_llm_manager.py +++ b/enterprise/storage/lite_llm_manager.py @@ -164,9 +164,33 @@ class LiteLlmManager: ) if create_user: - await LiteLlmManager._create_user( + user_created = await LiteLlmManager._create_user( client, keycloak_user_info.get('email'), keycloak_user_id ) + if not user_created: + logger.error( + 'create_entries_failed_user_creation', + extra={ + 'org_id': org_id, + 'user_id': keycloak_user_id, + }, + ) + return None + + # Verify user exists before proceeding with key generation + user_exists = await LiteLlmManager._user_exists( + client, keycloak_user_id + ) + if not user_exists: + logger.error( + 'create_entries_user_not_found_before_key_generation', + extra={ + 'org_id': org_id, + 'user_id': keycloak_user_id, + 'create_user_flag': create_user, + }, + ) + return None await LiteLlmManager._add_user_to_team( client, keycloak_user_id, org_id, team_budget @@ -655,15 +679,48 @@ class LiteLlmManager: ) response.raise_for_status() + @staticmethod + async def _user_exists( + client: httpx.AsyncClient, + user_id: str, + ) -> bool: + """Check if a user exists in LiteLLM. + + Returns True if the user exists, False otherwise. + """ + if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: + return False + try: + response = await client.get( + f'{LITE_LLM_API_URL}/user/info?user_id={user_id}', + ) + if response.is_success: + user_data = response.json() + # Check that user_info exists and has the user_id + user_info = user_data.get('user_info', {}) + return user_info.get('user_id') == user_id + return False + except Exception as e: + logger.warning( + 'litellm_user_exists_check_failed', + extra={'user_id': user_id, 'error': str(e)}, + ) + return False + @staticmethod async def _create_user( client: httpx.AsyncClient, email: str | None, keycloak_user_id: str, - ): + ) -> bool: + """Create a user in LiteLLM. + + Returns True if the user was created or already exists and is verified, + False if creation failed and user does not exist. + """ if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: logger.warning('LiteLLM API configuration not found') - return + return False response = await client.post( f'{LITE_LLM_API_URL}/user/new', json={ @@ -716,17 +773,33 @@ class LiteLlmManager: 'user_id': keycloak_user_id, }, ) - return + # Verify the user actually exists before returning success + user_exists = await LiteLlmManager._user_exists( + client, keycloak_user_id + ) + if not user_exists: + logger.error( + 'litellm_user_claimed_exists_but_not_found', + extra={ + 'user_id': keycloak_user_id, + 'status_code': response.status_code, + 'text': response.text, + }, + ) + return False + return True logger.error( 'error_creating_litellm_user', extra={ 'status_code': response.status_code, 'text': response.text, - 'user_id': [keycloak_user_id], + 'user_id': keycloak_user_id, 'email': None, }, ) + return False response.raise_for_status() + return True @staticmethod async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None: @@ -1450,6 +1523,7 @@ class LiteLlmManager: create_team = staticmethod(with_http_client(_create_team)) get_team = staticmethod(with_http_client(_get_team)) update_team = staticmethod(with_http_client(_update_team)) + user_exists = staticmethod(with_http_client(_user_exists)) create_user = staticmethod(with_http_client(_create_user)) get_user = staticmethod(with_http_client(_get_user)) update_user = staticmethod(with_http_client(_update_user)) diff --git a/enterprise/tests/unit/test_lite_llm_manager.py b/enterprise/tests/unit/test_lite_llm_manager.py index 0cfc9fe58b..3da159421d 100644 --- a/enterprise/tests/unit/test_lite_llm_manager.py +++ b/enterprise/tests/unit/test_lite_llm_manager.py @@ -239,6 +239,16 @@ class TestLiteLlmManager: mock_404_response = MagicMock() mock_404_response.status_code = 404 mock_404_response.is_success = False + mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message='Not Found', request=MagicMock(), response=mock_404_response + ) + + # Mock user exists check response + mock_user_exists_response = MagicMock() + mock_user_exists_response.is_success = True + mock_user_exists_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id'} + } mock_token_manager = MagicMock() mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( @@ -246,12 +256,8 @@ class TestLiteLlmManager: ) mock_client = AsyncMock() - mock_client.get.return_value = mock_404_response - mock_client.get.return_value.raise_for_status.side_effect = ( - httpx.HTTPStatusError( - message='Not Found', request=MagicMock(), response=mock_404_response - ) - ) + # First GET is for _get_team (404), second GET is for _user_exists (success) + mock_client.get.side_effect = [mock_404_response, mock_user_exists_response] mock_client.post.return_value = mock_response mock_client_class = MagicMock() @@ -274,8 +280,8 @@ class TestLiteLlmManager: assert result.llm_api_key.get_secret_value() == 'test-api-key' assert result.llm_base_url == 'http://test.com' - # Verify API calls were made (get_team + 4 posts) - assert mock_client.get.call_count == 1 # get_team + # Verify API calls were made (get_team + user_exists + 4 posts) + assert mock_client.get.call_count == 2 # get_team + user_exists assert ( mock_client.post.call_count == 4 ) # create_team, add_user_to_team, delete_key_by_alias, generate_key @@ -294,13 +300,21 @@ class TestLiteLlmManager: } mock_team_response.raise_for_status = MagicMock() + # Mock user exists check response + mock_user_exists_response = MagicMock() + mock_user_exists_response.is_success = True + mock_user_exists_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id'} + } + mock_token_manager = MagicMock() mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( return_value={'email': 'test@example.com'} ) mock_client = AsyncMock() - mock_client.get.return_value = mock_team_response + # First GET is for _get_team (success), second GET is for _user_exists (success) + mock_client.get.side_effect = [mock_team_response, mock_user_exists_response] mock_client.post.return_value = mock_response mock_client_class = MagicMock() @@ -320,8 +334,8 @@ class TestLiteLlmManager: assert result is not None # Verify _get_team was called first - mock_client.get.assert_called_once() - get_call_url = mock_client.get.call_args[0][0] + assert mock_client.get.call_count == 2 # get_team + user_exists + get_call_url = mock_client.get.call_args_list[0][0][0] assert 'team/info' in get_call_url assert 'test-org-id' in get_call_url @@ -343,19 +357,25 @@ class TestLiteLlmManager: mock_404_response = MagicMock() mock_404_response.status_code = 404 mock_404_response.is_success = False + mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message='Not Found', request=MagicMock(), response=mock_404_response + ) mock_token_manager = MagicMock() mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( return_value={'email': 'test@example.com'} ) + # Mock user exists check response + mock_user_exists_response = MagicMock() + mock_user_exists_response.is_success = True + mock_user_exists_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id'} + } + mock_client = AsyncMock() - mock_client.get.return_value = mock_404_response - mock_client.get.return_value.raise_for_status.side_effect = ( - httpx.HTTPStatusError( - message='Not Found', request=MagicMock(), response=mock_404_response - ) - ) + # First GET is for _get_team (404), second GET is for _user_exists (success) + mock_client.get.side_effect = [mock_404_response, mock_user_exists_response] mock_client.post.return_value = mock_response mock_client_class = MagicMock() @@ -393,6 +413,16 @@ class TestLiteLlmManager: mock_404_response = MagicMock() mock_404_response.status_code = 404 mock_404_response.is_success = False + mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message='Not Found', request=MagicMock(), response=mock_404_response + ) + + # Mock user exists check response + mock_user_exists_response = MagicMock() + mock_user_exists_response.is_success = True + mock_user_exists_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id'} + } mock_token_manager = MagicMock() mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( @@ -400,12 +430,8 @@ class TestLiteLlmManager: ) mock_client = AsyncMock() - mock_client.get.return_value = mock_404_response - mock_client.get.return_value.raise_for_status.side_effect = ( - httpx.HTTPStatusError( - message='Not Found', request=MagicMock(), response=mock_404_response - ) - ) + # First GET is for _get_team (404), second GET is for _user_exists (success) + mock_client.get.side_effect = [mock_404_response, mock_user_exists_response] mock_client.post.return_value = mock_response mock_client_class = MagicMock() @@ -833,15 +859,16 @@ class TestLiteLlmManager: @pytest.mark.asyncio async def test_create_user_success(self, mock_http_client, mock_response): - """Test successful _create_user operation.""" + """Test successful _create_user operation returns True.""" mock_http_client.post.return_value = mock_response with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'): with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): - await LiteLlmManager._create_user( + result = await LiteLlmManager._create_user( mock_http_client, 'test@example.com', 'test-user-id' ) + assert result is True mock_http_client.post.assert_called_once() call_args = mock_http_client.post.call_args assert 'http://test.com/user/new' in call_args[0] @@ -850,7 +877,7 @@ class TestLiteLlmManager: @pytest.mark.asyncio async def test_create_user_duplicate_email(self, mock_http_client, mock_response): - """Test _create_user with duplicate email handling.""" + """Test _create_user with duplicate email handling returns True.""" # First call fails with duplicate email error_response = MagicMock() error_response.is_success = False @@ -862,23 +889,81 @@ class TestLiteLlmManager: with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'): with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): - await LiteLlmManager._create_user( + result = await LiteLlmManager._create_user( mock_http_client, 'test@example.com', 'test-user-id' ) + assert result is True assert mock_http_client.post.call_count == 2 # Second call should have None email second_call_args = mock_http_client.post.call_args_list[1] assert second_call_args[1]['json']['user_email'] is None + @pytest.mark.asyncio + @patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com') + @patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key') + async def test_user_exists_returns_true(self, mock_http_client): + """Test _user_exists returns True when user exists in LiteLLM.""" + # Arrange + user_response = MagicMock() + user_response.is_success = True + user_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id', 'email': 'test@example.com'} + } + mock_http_client.get.return_value = user_response + + # Act + result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id') + + # Assert + assert result is True + mock_http_client.get.assert_called_once() + + @pytest.mark.asyncio + @patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com') + @patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key') + async def test_user_exists_returns_false_when_not_found(self, mock_http_client): + """Test _user_exists returns False when user not found.""" + # Arrange + user_response = MagicMock() + user_response.is_success = False + mock_http_client.get.return_value = user_response + + # Act + result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id') + + # Assert + assert result is False + + @pytest.mark.asyncio + @patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com') + @patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key') + async def test_user_exists_returns_false_on_mismatched_user_id( + self, mock_http_client + ): + """Test _user_exists returns False when returned user_id doesn't match.""" + # Arrange + user_response = MagicMock() + user_response.is_success = True + user_response.json.return_value = { + 'user_info': {'user_id': 'different-user-id'} + } + mock_http_client.get.return_value = user_response + + # Act + result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id') + + # Assert + assert result is False + @pytest.mark.asyncio @patch('storage.lite_llm_manager.logger') @patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com') @patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key') - async def test_create_user_already_exists_with_409_status_code( + async def test_create_user_already_exists_and_verified( self, mock_logger, mock_http_client ): - """Test _create_user handles 409 Conflict when user already exists.""" + """Test _create_user returns True when user already exists and is verified.""" # Arrange first_response = MagicMock() first_response.is_success = False @@ -890,14 +975,141 @@ class TestLiteLlmManager: second_response.status_code = 409 second_response.text = 'User with id test-user-id already exists' + user_exists_response = MagicMock() + user_exists_response.is_success = True + user_exists_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id'} + } + mock_http_client.post.side_effect = [first_response, second_response] + mock_http_client.get.return_value = user_exists_response # Act - await LiteLlmManager._create_user( + result = await LiteLlmManager._create_user( mock_http_client, 'test@example.com', 'test-user-id' ) # Assert + assert result is True + mock_logger.warning.assert_any_call( + 'litellm_user_already_exists', + extra={'user_id': 'test-user-id'}, + ) + + @pytest.mark.asyncio + @patch('storage.lite_llm_manager.logger') + @patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com') + @patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key') + async def test_create_user_already_exists_but_not_found_returns_false( + self, mock_logger, mock_http_client + ): + """Test _create_user returns False when LiteLLM claims user exists but verification fails.""" + # Arrange + first_response = MagicMock() + first_response.is_success = False + first_response.status_code = 400 + first_response.text = 'duplicate email' + + second_response = MagicMock() + second_response.is_success = False + second_response.status_code = 409 + second_response.text = 'User with id test-user-id already exists' + + user_not_exists_response = MagicMock() + user_not_exists_response.is_success = False + + mock_http_client.post.side_effect = [first_response, second_response] + mock_http_client.get.return_value = user_not_exists_response + + # Act + result = await LiteLlmManager._create_user( + mock_http_client, 'test@example.com', 'test-user-id' + ) + + # Assert + assert result is False + mock_logger.error.assert_any_call( + 'litellm_user_claimed_exists_but_not_found', + extra={ + 'user_id': 'test-user-id', + 'status_code': 409, + 'text': 'User with id test-user-id already exists', + }, + ) + + @pytest.mark.asyncio + @patch('storage.lite_llm_manager.logger') + @patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com') + @patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key') + async def test_create_user_failure_returns_false( + self, mock_logger, mock_http_client + ): + """Test _create_user returns False when creation fails with non-'already exists' error.""" + # Arrange + first_response = MagicMock() + first_response.is_success = False + first_response.status_code = 400 + first_response.text = 'duplicate email' + + second_response = MagicMock() + second_response.is_success = False + second_response.status_code = 500 + second_response.text = 'Internal server error' + + mock_http_client.post.side_effect = [first_response, second_response] + + # Act + result = await LiteLlmManager._create_user( + mock_http_client, 'test@example.com', 'test-user-id' + ) + + # Assert + assert result is False + mock_logger.error.assert_any_call( + 'error_creating_litellm_user', + extra={ + 'status_code': 500, + 'text': 'Internal server error', + 'user_id': 'test-user-id', + 'email': None, + }, + ) + + @pytest.mark.asyncio + @patch('storage.lite_llm_manager.logger') + @patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com') + @patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key') + async def test_create_user_already_exists_with_409_status_code( + self, mock_logger, mock_http_client + ): + """Test _create_user handles 409 Conflict when user already exists and verifies.""" + # Arrange + first_response = MagicMock() + first_response.is_success = False + first_response.status_code = 400 + first_response.text = 'duplicate email' + + second_response = MagicMock() + second_response.is_success = False + second_response.status_code = 409 + second_response.text = 'User with id test-user-id already exists' + + user_exists_response = MagicMock() + user_exists_response.is_success = True + user_exists_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id'} + } + + mock_http_client.post.side_effect = [first_response, second_response] + mock_http_client.get.return_value = user_exists_response + + # Act + result = await LiteLlmManager._create_user( + mock_http_client, 'test@example.com', 'test-user-id' + ) + + # Assert + assert result is True mock_logger.warning.assert_any_call( 'litellm_user_already_exists', extra={'user_id': 'test-user-id'}, @@ -910,7 +1122,7 @@ class TestLiteLlmManager: async def test_create_user_already_exists_with_400_status_code( self, mock_logger, mock_http_client ): - """Test _create_user handles 400 Bad Request when user already exists.""" + """Test _create_user handles 400 Bad Request when user already exists and verifies.""" # Arrange first_response = MagicMock() first_response.is_success = False @@ -922,14 +1134,22 @@ class TestLiteLlmManager: second_response.status_code = 400 second_response.text = 'User already exists' + user_exists_response = MagicMock() + user_exists_response.is_success = True + user_exists_response.json.return_value = { + 'user_info': {'user_id': 'test-user-id'} + } + mock_http_client.post.side_effect = [first_response, second_response] + mock_http_client.get.return_value = user_exists_response # Act - await LiteLlmManager._create_user( + result = await LiteLlmManager._create_user( mock_http_client, 'test@example.com', 'test-user-id' ) # Assert + assert result is True mock_logger.warning.assert_any_call( 'litellm_user_already_exists', extra={'user_id': 'test-user-id'}, From 8039807c3f4c8bd8eeea96c2fe6bd87399285317 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:18:29 +0700 Subject: [PATCH 10/28] fix(frontend): scope organization data queries by organization ID (org project) (#13459) --- .../analytics-consent-form-modal.test.tsx | 7 +- .../features/chat/messages.test.tsx | 6 +- .../features/home/repo-connector.test.tsx | 2 + .../settings/api-keys-manager.test.tsx | 7 +- .../features/sidebar/sidebar.test.tsx | 7 +- .../components/interactive-chat-box.test.tsx | 21 +- .../context/ws-client-provider.test.tsx | 2 + .../conversation-websocket-handler.test.tsx | 5 + .../hooks/mutation/use-save-settings.test.tsx | 7 +- .../organization-scoped-queries.test.tsx | 225 ++++++++++++++++++ frontend/__tests__/routes/accept-tos.test.tsx | 6 +- .../__tests__/routes/app-settings.test.tsx | 7 +- .../utils/check-hardcoded-strings.test.tsx | 21 +- .../settings/secrets-settings/secret-form.tsx | 10 +- .../hooks/mutation/use-add-git-providers.ts | 6 +- .../src/hooks/mutation/use-add-mcp-server.ts | 6 +- .../src/hooks/mutation/use-create-api-key.ts | 6 +- .../src/hooks/mutation/use-delete-api-key.ts | 6 +- .../hooks/mutation/use-delete-mcp-server.ts | 6 +- .../src/hooks/mutation/use-save-settings.ts | 6 +- .../hooks/mutation/use-switch-organization.ts | 5 +- .../hooks/mutation/use-update-mcp-server.ts | 6 +- frontend/src/hooks/query/use-api-keys.ts | 6 +- frontend/src/hooks/query/use-get-secrets.ts | 6 +- frontend/src/hooks/query/use-settings.ts | 13 +- frontend/src/i18n/declaration.ts | 8 + frontend/src/routes/secrets-settings.tsx | 6 +- frontend/src/routes/settings.tsx | 2 - frontend/src/routes/user-settings.tsx | 12 +- frontend/vitest.setup.ts | 9 + 30 files changed, 391 insertions(+), 51 deletions(-) create mode 100644 frontend/__tests__/hooks/query/organization-scoped-queries.test.tsx diff --git a/frontend/__tests__/components/features/analytics/analytics-consent-form-modal.test.tsx b/frontend/__tests__/components/features/analytics/analytics-consent-form-modal.test.tsx index eb7c39397c..d1e0fbf5dc 100644 --- a/frontend/__tests__/components/features/analytics/analytics-consent-form-modal.test.tsx +++ b/frontend/__tests__/components/features/analytics/analytics-consent-form-modal.test.tsx @@ -1,11 +1,16 @@ import userEvent from "@testing-library/user-event"; -import { describe, expect, it, vi } from "vitest"; +import { describe, expect, it, vi, beforeEach } from "vitest"; import { render, screen, waitFor } from "@testing-library/react"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal"; import SettingsService from "#/api/settings-service/settings-service.api"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; describe("AnalyticsConsentFormModal", () => { + beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); + }); + it("should call saveUserSettings with consent", async () => { const user = userEvent.setup(); const onCloseMock = vi.fn(); diff --git a/frontend/__tests__/components/features/chat/messages.test.tsx b/frontend/__tests__/components/features/chat/messages.test.tsx index 577f6db5a1..194ad1ce46 100644 --- a/frontend/__tests__/components/features/chat/messages.test.tsx +++ b/frontend/__tests__/components/features/chat/messages.test.tsx @@ -10,9 +10,12 @@ import { import { OpenHandsObservation } from "#/types/core/observations"; import ConversationService from "#/api/conversation-service/conversation-service.api"; import { Conversation } from "#/api/open-hands.types"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; -vi.mock("react-router", () => ({ +vi.mock("react-router", async (importOriginal) => ({ + ...(await importOriginal()), useParams: () => ({ conversationId: "123" }), + useRevalidator: () => ({ revalidate: vi.fn() }), })); let queryClient: QueryClient; @@ -47,6 +50,7 @@ const renderMessages = ({ describe("Messages", () => { beforeEach(() => { queryClient = new QueryClient(); + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); }); const assistantMessage: AssistantMessageAction = { diff --git a/frontend/__tests__/components/features/home/repo-connector.test.tsx b/frontend/__tests__/components/features/home/repo-connector.test.tsx index 17a43c75ed..4cba3850b4 100644 --- a/frontend/__tests__/components/features/home/repo-connector.test.tsx +++ b/frontend/__tests__/components/features/home/repo-connector.test.tsx @@ -10,6 +10,7 @@ import OptionService from "#/api/option-service/option-service.api"; import { GitRepository } from "#/types/git"; import { RepoConnector } from "#/components/features/home/repo-connector"; import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; const renderRepoConnector = () => { const mockRepoSelection = vi.fn(); @@ -65,6 +66,7 @@ const MOCK_RESPOSITORIES: GitRepository[] = [ ]; beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); const getSettingsSpy = vi.spyOn(SettingsService, "getSettings"); getSettingsSpy.mockResolvedValue({ ...MOCK_DEFAULT_USER_SETTINGS, diff --git a/frontend/__tests__/components/features/settings/api-keys-manager.test.tsx b/frontend/__tests__/components/features/settings/api-keys-manager.test.tsx index 6c3f9884a9..b2783c6363 100644 --- a/frontend/__tests__/components/features/settings/api-keys-manager.test.tsx +++ b/frontend/__tests__/components/features/settings/api-keys-manager.test.tsx @@ -1,7 +1,8 @@ import { render, screen } from "@testing-library/react"; -import { describe, expect, it, vi } from "vitest"; +import { describe, expect, it, vi, beforeEach } from "vitest"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { ApiKeysManager } from "#/components/features/settings/api-keys-manager"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; // Mock the react-i18next vi.mock("react-i18next", async () => { @@ -37,6 +38,10 @@ vi.mock("#/hooks/query/use-api-keys", () => ({ })); describe("ApiKeysManager", () => { + beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); + }); + const renderComponent = () => { const queryClient = new QueryClient(); return render( diff --git a/frontend/__tests__/components/features/sidebar/sidebar.test.tsx b/frontend/__tests__/components/features/sidebar/sidebar.test.tsx index b83abbfeae..cf6ce1ff9b 100644 --- a/frontend/__tests__/components/features/sidebar/sidebar.test.tsx +++ b/frontend/__tests__/components/features/sidebar/sidebar.test.tsx @@ -1,4 +1,4 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { renderWithProviders, createAxiosNotFoundErrorObject, @@ -10,6 +10,7 @@ import SettingsService from "#/api/settings-service/settings-service.api"; import OptionService from "#/api/option-service/option-service.api"; import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers"; import { WebClientConfig } from "#/api/option-service/option.types"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; // Helper to create mock config with sensible defaults const createMockConfig = ( @@ -76,6 +77,10 @@ describe("Sidebar", () => { const getSettingsSpy = vi.spyOn(SettingsService, "getSettings"); const getConfigSpy = vi.spyOn(OptionService, "getConfig"); + beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); + }); + afterEach(() => { vi.clearAllMocks(); }); diff --git a/frontend/__tests__/components/interactive-chat-box.test.tsx b/frontend/__tests__/components/interactive-chat-box.test.tsx index bafa673731..ecb6623806 100644 --- a/frontend/__tests__/components/interactive-chat-box.test.tsx +++ b/frontend/__tests__/components/interactive-chat-box.test.tsx @@ -1,26 +1,25 @@ import { screen } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; -import { afterEach, beforeAll, describe, expect, it, vi } from "vitest"; +import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import { MemoryRouter } from "react-router"; import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-box"; import { renderWithProviders } from "../../test-utils"; import { AgentState } from "#/types/agent-state"; import { useAgentState } from "#/hooks/use-agent-state"; import { useConversationStore } from "#/stores/conversation-store"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; vi.mock("#/hooks/use-agent-state", () => ({ useAgentState: vi.fn(), })); // Mock React Router hooks -vi.mock("react-router", async () => { - const actual = await vi.importActual("react-router"); - return { - ...actual, - useNavigate: () => vi.fn(), - useParams: () => ({ conversationId: "test-conversation-id" }), - }; -}); +vi.mock("react-router", async (importOriginal) => ({ + ...(await importOriginal()), + useNavigate: () => vi.fn(), + useParams: () => ({ conversationId: "test-conversation-id" }), + useRevalidator: () => ({ revalidate: vi.fn() }), +})); // Mock the useActiveConversation hook vi.mock("#/hooks/query/use-active-conversation", () => ({ @@ -52,6 +51,10 @@ vi.mock("#/hooks/use-conversation-name-context-menu", () => ({ describe("InteractiveChatBox", () => { const onSubmitMock = vi.fn(); + beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); + }); + const mockStores = (agentState: AgentState = AgentState.INIT) => { vi.mocked(useAgentState).mockReturnValue({ curAgentState: agentState, diff --git a/frontend/__tests__/context/ws-client-provider.test.tsx b/frontend/__tests__/context/ws-client-provider.test.tsx index 55a27732fc..3e2ac11f23 100644 --- a/frontend/__tests__/context/ws-client-provider.test.tsx +++ b/frontend/__tests__/context/ws-client-provider.test.tsx @@ -7,6 +7,7 @@ import { WsClientProvider, useWsClient, } from "#/context/ws-client-provider"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; describe("Propagate error message", () => { it("should do nothing when no message was passed from server", () => { @@ -56,6 +57,7 @@ function TestComponent() { describe("WsClientProvider", () => { beforeEach(() => { vi.clearAllMocks(); + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); vi.mock("#/hooks/query/use-active-conversation", () => ({ useActiveConversation: () => { return { data: { diff --git a/frontend/__tests__/conversation-websocket-handler.test.tsx b/frontend/__tests__/conversation-websocket-handler.test.tsx index 393d6f68f0..e3de4572db 100644 --- a/frontend/__tests__/conversation-websocket-handler.test.tsx +++ b/frontend/__tests__/conversation-websocket-handler.test.tsx @@ -40,6 +40,7 @@ import { import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup"; import { useEventStore } from "#/stores/use-event-store"; import { isV1Event } from "#/types/v1/type-guards"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; // Mock useUserConversation to return V1 conversation data vi.mock("#/hooks/query/use-user-conversation", () => ({ @@ -62,6 +63,10 @@ beforeAll(() => { mswServer.listen({ onUnhandledRequest: "bypass" }); }); +beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); +}); + afterEach(() => { mswServer.resetHandlers(); // Clean up any React components diff --git a/frontend/__tests__/hooks/mutation/use-save-settings.test.tsx b/frontend/__tests__/hooks/mutation/use-save-settings.test.tsx index d2a7c798c4..e3216beb3c 100644 --- a/frontend/__tests__/hooks/mutation/use-save-settings.test.tsx +++ b/frontend/__tests__/hooks/mutation/use-save-settings.test.tsx @@ -1,10 +1,15 @@ import { renderHook, waitFor } from "@testing-library/react"; -import { describe, expect, it, vi } from "vitest"; +import { describe, expect, it, vi, beforeEach } from "vitest"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import SettingsService from "#/api/settings-service/settings-service.api"; import { useSaveSettings } from "#/hooks/mutation/use-save-settings"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; describe("useSaveSettings", () => { + beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); + }); + it("should send an empty string for llm_api_key if an empty string is passed, otherwise undefined", async () => { const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings"); const { result } = renderHook(() => useSaveSettings(), { diff --git a/frontend/__tests__/hooks/query/organization-scoped-queries.test.tsx b/frontend/__tests__/hooks/query/organization-scoped-queries.test.tsx new file mode 100644 index 0000000000..a32ea3500a --- /dev/null +++ b/frontend/__tests__/hooks/query/organization-scoped-queries.test.tsx @@ -0,0 +1,225 @@ +import { renderHook, waitFor } from "@testing-library/react"; +import { describe, expect, it, vi, beforeEach } from "vitest"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import React from "react"; +import { useSettings } from "#/hooks/query/use-settings"; +import { useGetSecrets } from "#/hooks/query/use-get-secrets"; +import { useApiKeys } from "#/hooks/query/use-api-keys"; +import SettingsService from "#/api/settings-service/settings-service.api"; +import { SecretsService } from "#/api/secrets-service"; +import ApiKeysClient from "#/api/api-keys"; +import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; + +vi.mock("#/hooks/query/use-config", () => ({ + useConfig: () => ({ + data: { app_mode: "saas" }, + }), +})); + +vi.mock("#/hooks/query/use-is-authed", () => ({ + useIsAuthed: () => ({ + data: true, + }), +})); + +vi.mock("#/hooks/use-is-on-intermediate-page", () => ({ + useIsOnIntermediatePage: () => false, +})); + +describe("Organization-scoped query hooks", () => { + let queryClient: QueryClient; + + const createWrapper = () => { + return ({ children }: { children: React.ReactNode }) => ( + {children} + ); + }; + + beforeEach(() => { + queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }); + useSelectedOrganizationStore.setState({ organizationId: "org-1" }); + vi.clearAllMocks(); + }); + + describe("useSettings", () => { + it("should include organizationId in query key for proper cache isolation", async () => { + const getSettingsSpy = vi.spyOn(SettingsService, "getSettings"); + getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS); + + const { result } = renderHook(() => useSettings(), { + wrapper: createWrapper(), + }); + + await waitFor(() => expect(result.current.isFetched).toBe(true)); + + // Verify the query was cached with the org-specific key + const cachedData = queryClient.getQueryData(["settings", "org-1"]); + expect(cachedData).toBeDefined(); + + // Verify no data is cached under the old key without org ID + const oldKeyData = queryClient.getQueryData(["settings"]); + expect(oldKeyData).toBeUndefined(); + }); + + it("should refetch when organization changes", async () => { + const getSettingsSpy = vi.spyOn(SettingsService, "getSettings"); + getSettingsSpy.mockResolvedValue({ + ...MOCK_DEFAULT_USER_SETTINGS, + language: "en", + }); + + // First render with org-1 + const { result, rerender } = renderHook(() => useSettings(), { + wrapper: createWrapper(), + }); + + await waitFor(() => expect(result.current.isFetched).toBe(true)); + expect(getSettingsSpy).toHaveBeenCalledTimes(1); + + // Change organization + useSelectedOrganizationStore.setState({ organizationId: "org-2" }); + getSettingsSpy.mockResolvedValue({ + ...MOCK_DEFAULT_USER_SETTINGS, + language: "es", + }); + + // Rerender to pick up the new org ID + rerender(); + + await waitFor(() => { + // Should have fetched again for the new org + expect(getSettingsSpy).toHaveBeenCalledTimes(2); + }); + + // Verify both org caches exist independently + const org1Data = queryClient.getQueryData(["settings", "org-1"]); + const org2Data = queryClient.getQueryData(["settings", "org-2"]); + expect(org1Data).toBeDefined(); + expect(org2Data).toBeDefined(); + }); + }); + + describe("useGetSecrets", () => { + it("should include organizationId in query key for proper cache isolation", async () => { + const getSecretsSpy = vi.spyOn(SecretsService, "getSecrets"); + getSecretsSpy.mockResolvedValue([]); + + const { result } = renderHook(() => useGetSecrets(), { + wrapper: createWrapper(), + }); + + await waitFor(() => expect(result.current.isFetched).toBe(true)); + + // Verify the query was cached with the org-specific key + const cachedData = queryClient.getQueryData(["secrets", "org-1"]); + expect(cachedData).toBeDefined(); + + // Verify no data is cached under the old key without org ID + const oldKeyData = queryClient.getQueryData(["secrets"]); + expect(oldKeyData).toBeUndefined(); + }); + + it("should fetch different data when organization changes", async () => { + const getSecretsSpy = vi.spyOn(SecretsService, "getSecrets"); + + // Mock different secrets for different orgs + getSecretsSpy.mockResolvedValueOnce([ + { name: "SECRET_ORG_1", description: "Org 1 secret" }, + ]); + + const { result, rerender } = renderHook(() => useGetSecrets(), { + wrapper: createWrapper(), + }); + + await waitFor(() => expect(result.current.isFetched).toBe(true)); + expect(result.current.data).toHaveLength(1); + expect(result.current.data?.[0].name).toBe("SECRET_ORG_1"); + + // Change organization + useSelectedOrganizationStore.setState({ organizationId: "org-2" }); + getSecretsSpy.mockResolvedValueOnce([ + { name: "SECRET_ORG_2", description: "Org 2 secret" }, + ]); + + rerender(); + + await waitFor(() => { + expect(result.current.data?.[0]?.name).toBe("SECRET_ORG_2"); + }); + }); + }); + + describe("useApiKeys", () => { + it("should include organizationId in query key for proper cache isolation", async () => { + const getApiKeysSpy = vi.spyOn(ApiKeysClient, "getApiKeys"); + getApiKeysSpy.mockResolvedValue([]); + + const { result } = renderHook(() => useApiKeys(), { + wrapper: createWrapper(), + }); + + await waitFor(() => expect(result.current.isFetched).toBe(true)); + + // Verify the query was cached with the org-specific key + const cachedData = queryClient.getQueryData(["api-keys", "org-1"]); + expect(cachedData).toBeDefined(); + + // Verify no data is cached under the old key without org ID + const oldKeyData = queryClient.getQueryData(["api-keys"]); + expect(oldKeyData).toBeUndefined(); + }); + }); + + describe("Cache isolation between organizations", () => { + it("should maintain separate caches for each organization", async () => { + const getSettingsSpy = vi.spyOn(SettingsService, "getSettings"); + + // Simulate fetching for org-1 + getSettingsSpy.mockResolvedValueOnce({ + ...MOCK_DEFAULT_USER_SETTINGS, + language: "en", + }); + + useSelectedOrganizationStore.setState({ organizationId: "org-1" }); + const { rerender } = renderHook(() => useSettings(), { + wrapper: createWrapper(), + }); + + await waitFor(() => { + expect(queryClient.getQueryData(["settings", "org-1"])).toBeDefined(); + }); + + // Switch to org-2 + getSettingsSpy.mockResolvedValueOnce({ + ...MOCK_DEFAULT_USER_SETTINGS, + language: "fr", + }); + + useSelectedOrganizationStore.setState({ organizationId: "org-2" }); + rerender(); + + await waitFor(() => { + expect(queryClient.getQueryData(["settings", "org-2"])).toBeDefined(); + }); + + // Switch back to org-1 - should use cached data, not refetch + useSelectedOrganizationStore.setState({ organizationId: "org-1" }); + rerender(); + + // org-1 data should still be in cache + const org1Cache = queryClient.getQueryData(["settings", "org-1"]) as any; + expect(org1Cache?.language).toBe("en"); + + // org-2 data should also still be in cache + const org2Cache = queryClient.getQueryData(["settings", "org-2"]) as any; + expect(org2Cache?.language).toBe("fr"); + }); + }); +}); diff --git a/frontend/__tests__/routes/accept-tos.test.tsx b/frontend/__tests__/routes/accept-tos.test.tsx index 7b15081485..2e1e48a1c7 100644 --- a/frontend/__tests__/routes/accept-tos.test.tsx +++ b/frontend/__tests__/routes/accept-tos.test.tsx @@ -5,9 +5,11 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import AcceptTOS from "#/routes/accept-tos"; import * as CaptureConsent from "#/utils/handle-capture-consent"; import { openHands } from "#/api/open-hands-axios"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; // Mock the react-router hooks -vi.mock("react-router", () => ({ +vi.mock("react-router", async (importOriginal) => ({ + ...(await importOriginal()), useNavigate: () => vi.fn(), useSearchParams: () => [ { @@ -19,6 +21,7 @@ vi.mock("react-router", () => ({ }, }, ], + useRevalidator: () => ({ revalidate: vi.fn() }), })); // Mock the axios instance @@ -54,6 +57,7 @@ const createWrapper = () => { describe("AcceptTOS", () => { beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); vi.stubGlobal("location", { href: "" }); }); diff --git a/frontend/__tests__/routes/app-settings.test.tsx b/frontend/__tests__/routes/app-settings.test.tsx index 7b42844246..a40d21d8e6 100644 --- a/frontend/__tests__/routes/app-settings.test.tsx +++ b/frontend/__tests__/routes/app-settings.test.tsx @@ -1,5 +1,5 @@ import { render, screen, waitFor } from "@testing-library/react"; -import { afterEach, describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import userEvent from "@testing-library/user-event"; import AppSettingsScreen, { clientLoader } from "#/routes/app-settings"; @@ -8,6 +8,11 @@ import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers"; import { AvailableLanguages } from "#/i18n"; import * as CaptureConsent from "#/utils/handle-capture-consent"; import * as ToastHandlers from "#/utils/custom-toast-handlers"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; + +beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); +}); const renderAppSettingsScreen = () => render(, { diff --git a/frontend/__tests__/utils/check-hardcoded-strings.test.tsx b/frontend/__tests__/utils/check-hardcoded-strings.test.tsx index ff0de34962..7c0a4e592d 100644 --- a/frontend/__tests__/utils/check-hardcoded-strings.test.tsx +++ b/frontend/__tests__/utils/check-hardcoded-strings.test.tsx @@ -1,8 +1,13 @@ import { render, screen } from "@testing-library/react"; -import { test, expect, describe, vi } from "vitest"; +import { test, expect, describe, vi, beforeEach } from "vitest"; import { MemoryRouter } from "react-router"; import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-box"; import { renderWithProviders } from "../../test-utils"; +import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; + +beforeEach(() => { + useSelectedOrganizationStore.setState({ organizationId: "test-org-id" }); +}); // Mock the translation function vi.mock("react-i18next", async () => { @@ -29,14 +34,12 @@ vi.mock("#/hooks/query/use-active-conversation", () => ({ })); // Mock React Router hooks -vi.mock("react-router", async () => { - const actual = await vi.importActual("react-router"); - return { - ...actual, - useNavigate: () => vi.fn(), - useParams: () => ({ conversationId: "test-conversation-id" }), - }; -}); +vi.mock("react-router", async (importOriginal) => ({ + ...(await importOriginal()), + useNavigate: () => vi.fn(), + useParams: () => ({ conversationId: "test-conversation-id" }), + useRevalidator: () => ({ revalidate: vi.fn() }), +})); // Mock other hooks that might be used by the component vi.mock("#/hooks/use-user-providers", () => ({ diff --git a/frontend/src/components/features/settings/secrets-settings/secret-form.tsx b/frontend/src/components/features/settings/secrets-settings/secret-form.tsx index b67e105f41..9987faced2 100644 --- a/frontend/src/components/features/settings/secrets-settings/secret-form.tsx +++ b/frontend/src/components/features/settings/secrets-settings/secret-form.tsx @@ -10,6 +10,7 @@ import { BrandButton } from "../brand-button"; import { useGetSecrets } from "#/hooks/query/use-get-secrets"; import { GetSecretsResponse } from "#/api/secrets-service.types"; import { OptionalTag } from "../optional-tag"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; interface SecretFormProps { mode: "add" | "edit"; @@ -24,6 +25,7 @@ export function SecretForm({ }: SecretFormProps) { const queryClient = useQueryClient(); const { t } = useTranslation(); + const { organizationId } = useSelectedOrganizationId(); const { data: secrets } = useGetSecrets(); const { mutate: createSecret } = useCreateSecret(); @@ -49,7 +51,9 @@ export function SecretForm({ { onSettled: onCancel, onSuccess: async () => { - await queryClient.invalidateQueries({ queryKey: ["secrets"] }); + await queryClient.invalidateQueries({ + queryKey: ["secrets", organizationId], + }); }, }, ); @@ -61,7 +65,7 @@ export function SecretForm({ description?: string, ) => { queryClient.setQueryData( - ["secrets"], + ["secrets", organizationId], (oldSecrets) => { if (!oldSecrets) return []; return oldSecrets.map((secret) => { @@ -79,7 +83,7 @@ export function SecretForm({ }; const revertOptimisticUpdate = () => { - queryClient.invalidateQueries({ queryKey: ["secrets"] }); + queryClient.invalidateQueries({ queryKey: ["secrets", organizationId] }); }; const handleEditSecret = ( diff --git a/frontend/src/hooks/mutation/use-add-git-providers.ts b/frontend/src/hooks/mutation/use-add-git-providers.ts index b7788b88c4..a6b7d85f8d 100644 --- a/frontend/src/hooks/mutation/use-add-git-providers.ts +++ b/frontend/src/hooks/mutation/use-add-git-providers.ts @@ -2,10 +2,12 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import { SecretsService } from "#/api/secrets-service"; import { Provider, ProviderToken } from "#/types/settings"; import { useTracking } from "#/hooks/use-tracking"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; export const useAddGitProviders = () => { const queryClient = useQueryClient(); const { trackGitProviderConnected } = useTracking(); + const { organizationId } = useSelectedOrganizationId(); return useMutation({ mutationFn: ({ @@ -25,7 +27,9 @@ export const useAddGitProviders = () => { }); } - await queryClient.invalidateQueries({ queryKey: ["settings"] }); + await queryClient.invalidateQueries({ + queryKey: ["settings", organizationId], + }); }, meta: { disableToast: true, diff --git a/frontend/src/hooks/mutation/use-add-mcp-server.ts b/frontend/src/hooks/mutation/use-add-mcp-server.ts index c9aaf4e446..bb90890f0d 100644 --- a/frontend/src/hooks/mutation/use-add-mcp-server.ts +++ b/frontend/src/hooks/mutation/use-add-mcp-server.ts @@ -2,6 +2,7 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import { useSettings } from "#/hooks/query/use-settings"; import SettingsService from "#/api/settings-service/settings-service.api"; import { MCPSSEServer, MCPStdioServer, MCPSHTTPServer } from "#/types/settings"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; type MCPServerType = "sse" | "stdio" | "shttp"; @@ -19,6 +20,7 @@ interface MCPServerConfig { export function useAddMcpServer() { const queryClient = useQueryClient(); const { data: settings } = useSettings(); + const { organizationId } = useSelectedOrganizationId(); return useMutation({ mutationFn: async (server: MCPServerConfig): Promise => { @@ -64,7 +66,9 @@ export function useAddMcpServer() { }, onSuccess: () => { // Invalidate the settings query to trigger a refetch - queryClient.invalidateQueries({ queryKey: ["settings"] }); + queryClient.invalidateQueries({ + queryKey: ["settings", organizationId], + }); }, }); } diff --git a/frontend/src/hooks/mutation/use-create-api-key.ts b/frontend/src/hooks/mutation/use-create-api-key.ts index fd3c05c975..4ab31b53df 100644 --- a/frontend/src/hooks/mutation/use-create-api-key.ts +++ b/frontend/src/hooks/mutation/use-create-api-key.ts @@ -1,16 +1,20 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import ApiKeysClient, { CreateApiKeyResponse } from "#/api/api-keys"; import { API_KEYS_QUERY_KEY } from "#/hooks/query/use-api-keys"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; export function useCreateApiKey() { const queryClient = useQueryClient(); + const { organizationId } = useSelectedOrganizationId(); return useMutation({ mutationFn: async (name: string): Promise => ApiKeysClient.createApiKey(name), onSuccess: () => { // Invalidate the API keys query to trigger a refetch - queryClient.invalidateQueries({ queryKey: [API_KEYS_QUERY_KEY] }); + queryClient.invalidateQueries({ + queryKey: [API_KEYS_QUERY_KEY, organizationId], + }); }, }); } diff --git a/frontend/src/hooks/mutation/use-delete-api-key.ts b/frontend/src/hooks/mutation/use-delete-api-key.ts index 4f4b566fab..9932343ce6 100644 --- a/frontend/src/hooks/mutation/use-delete-api-key.ts +++ b/frontend/src/hooks/mutation/use-delete-api-key.ts @@ -1,9 +1,11 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import ApiKeysClient from "#/api/api-keys"; import { API_KEYS_QUERY_KEY } from "#/hooks/query/use-api-keys"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; export function useDeleteApiKey() { const queryClient = useQueryClient(); + const { organizationId } = useSelectedOrganizationId(); return useMutation({ mutationFn: async (id: string): Promise => { @@ -11,7 +13,9 @@ export function useDeleteApiKey() { }, onSuccess: () => { // Invalidate the API keys query to trigger a refetch - queryClient.invalidateQueries({ queryKey: [API_KEYS_QUERY_KEY] }); + queryClient.invalidateQueries({ + queryKey: [API_KEYS_QUERY_KEY, organizationId], + }); }, }); } diff --git a/frontend/src/hooks/mutation/use-delete-mcp-server.ts b/frontend/src/hooks/mutation/use-delete-mcp-server.ts index 43d1b2a7cc..03cdc7759d 100644 --- a/frontend/src/hooks/mutation/use-delete-mcp-server.ts +++ b/frontend/src/hooks/mutation/use-delete-mcp-server.ts @@ -2,10 +2,12 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import { useSettings } from "#/hooks/query/use-settings"; import SettingsService from "#/api/settings-service/settings-service.api"; import { MCPConfig } from "#/types/settings"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; export function useDeleteMcpServer() { const queryClient = useQueryClient(); const { data: settings } = useSettings(); + const { organizationId } = useSelectedOrganizationId(); return useMutation({ mutationFn: async (serverId: string): Promise => { @@ -32,7 +34,9 @@ export function useDeleteMcpServer() { }, onSuccess: () => { // Invalidate the settings query to trigger a refetch - queryClient.invalidateQueries({ queryKey: ["settings"] }); + queryClient.invalidateQueries({ + queryKey: ["settings", organizationId], + }); }, }); } diff --git a/frontend/src/hooks/mutation/use-save-settings.ts b/frontend/src/hooks/mutation/use-save-settings.ts index f335fd83ec..9ccfc04ca3 100644 --- a/frontend/src/hooks/mutation/use-save-settings.ts +++ b/frontend/src/hooks/mutation/use-save-settings.ts @@ -4,6 +4,7 @@ import { DEFAULT_SETTINGS } from "#/services/settings"; import SettingsService from "#/api/settings-service/settings-service.api"; import { Settings } from "#/types/settings"; import { useSettings } from "../query/use-settings"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; const saveSettingsMutationFn = async (settings: Partial) => { const settingsToSave: Partial = { @@ -30,6 +31,7 @@ export const useSaveSettings = () => { const posthog = usePostHog(); const queryClient = useQueryClient(); const { data: currentSettings } = useSettings(); + const { organizationId } = useSelectedOrganizationId(); return useMutation({ mutationFn: async (settings: Partial) => { @@ -56,7 +58,9 @@ export const useSaveSettings = () => { await saveSettingsMutationFn(newSettings); }, onSuccess: async () => { - await queryClient.invalidateQueries({ queryKey: ["settings"] }); + await queryClient.invalidateQueries({ + queryKey: ["settings", organizationId], + }); }, meta: { disableToast: true, diff --git a/frontend/src/hooks/mutation/use-switch-organization.ts b/frontend/src/hooks/mutation/use-switch-organization.ts index 45fadedaf4..32e0f7b189 100644 --- a/frontend/src/hooks/mutation/use-switch-organization.ts +++ b/frontend/src/hooks/mutation/use-switch-organization.ts @@ -17,10 +17,9 @@ export const useSwitchOrganization = () => { queryClient.invalidateQueries({ queryKey: ["organizations", orgId, "me"], }); - // Update local state + // Update local state - this triggers automatic refetch for all org-scoped queries + // since their query keys include organizationId (e.g., ["settings", orgId], ["secrets", orgId]) setOrganizationId(orgId); - // Invalidate settings for the new org context - queryClient.invalidateQueries({ queryKey: ["settings"] }); // Invalidate conversations to fetch data for the new org context queryClient.invalidateQueries({ queryKey: ["user", "conversations"] }); // Remove all individual conversation queries to clear any stale/null data diff --git a/frontend/src/hooks/mutation/use-update-mcp-server.ts b/frontend/src/hooks/mutation/use-update-mcp-server.ts index 558997b500..af2a9d173f 100644 --- a/frontend/src/hooks/mutation/use-update-mcp-server.ts +++ b/frontend/src/hooks/mutation/use-update-mcp-server.ts @@ -2,6 +2,7 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import { useSettings } from "#/hooks/query/use-settings"; import SettingsService from "#/api/settings-service/settings-service.api"; import { MCPSSEServer, MCPStdioServer, MCPSHTTPServer } from "#/types/settings"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; type MCPServerType = "sse" | "stdio" | "shttp"; @@ -19,6 +20,7 @@ interface MCPServerConfig { export function useUpdateMcpServer() { const queryClient = useQueryClient(); const { data: settings } = useSettings(); + const { organizationId } = useSelectedOrganizationId(); return useMutation({ mutationFn: async ({ @@ -66,7 +68,9 @@ export function useUpdateMcpServer() { }, onSuccess: () => { // Invalidate the settings query to trigger a refetch - queryClient.invalidateQueries({ queryKey: ["settings"] }); + queryClient.invalidateQueries({ + queryKey: ["settings", organizationId], + }); }, }); } diff --git a/frontend/src/hooks/query/use-api-keys.ts b/frontend/src/hooks/query/use-api-keys.ts index 954e22ad26..2ff496253f 100644 --- a/frontend/src/hooks/query/use-api-keys.ts +++ b/frontend/src/hooks/query/use-api-keys.ts @@ -1,15 +1,17 @@ import { useQuery } from "@tanstack/react-query"; import ApiKeysClient from "#/api/api-keys"; import { useConfig } from "./use-config"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; export const API_KEYS_QUERY_KEY = "api-keys"; export function useApiKeys() { const { data: config } = useConfig(); + const { organizationId } = useSelectedOrganizationId(); return useQuery({ - queryKey: [API_KEYS_QUERY_KEY], - enabled: config?.app_mode === "saas", + queryKey: [API_KEYS_QUERY_KEY, organizationId], + enabled: config?.app_mode === "saas" && !!organizationId, queryFn: async () => { const keys = await ApiKeysClient.getApiKeys(); return Array.isArray(keys) ? keys : []; diff --git a/frontend/src/hooks/query/use-get-secrets.ts b/frontend/src/hooks/query/use-get-secrets.ts index e89df3d149..9c402e1c39 100644 --- a/frontend/src/hooks/query/use-get-secrets.ts +++ b/frontend/src/hooks/query/use-get-secrets.ts @@ -2,16 +2,18 @@ import { useQuery } from "@tanstack/react-query"; import { SecretsService } from "#/api/secrets-service"; import { useConfig } from "./use-config"; import { useIsAuthed } from "#/hooks/query/use-is-authed"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; export const useGetSecrets = () => { const { data: config } = useConfig(); const { data: isAuthed } = useIsAuthed(); + const { organizationId } = useSelectedOrganizationId(); const isOss = config?.app_mode === "oss"; return useQuery({ - queryKey: ["secrets"], + queryKey: ["secrets", organizationId], queryFn: SecretsService.getSecrets, - enabled: isOss || isAuthed, // Enable regardless of providers + enabled: isOss || (isAuthed && !!organizationId), }); }; diff --git a/frontend/src/hooks/query/use-settings.ts b/frontend/src/hooks/query/use-settings.ts index 6c6d766b69..2c18569081 100644 --- a/frontend/src/hooks/query/use-settings.ts +++ b/frontend/src/hooks/query/use-settings.ts @@ -4,6 +4,8 @@ import { DEFAULT_SETTINGS } from "#/services/settings"; import { useIsOnIntermediatePage } from "#/hooks/use-is-on-intermediate-page"; import { Settings } from "#/types/settings"; import { useIsAuthed } from "./use-is-authed"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; +import { useConfig } from "./use-config"; const getSettingsQueryFn = async (): Promise => { const settings = await SettingsService.getSettings(); @@ -27,9 +29,13 @@ const getSettingsQueryFn = async (): Promise => { export const useSettings = () => { const isOnIntermediatePage = useIsOnIntermediatePage(); const { data: userIsAuthenticated } = useIsAuthed(); + const { organizationId } = useSelectedOrganizationId(); + const { data: config } = useConfig(); + + const isOss = config?.app_mode === "oss"; const query = useQuery({ - queryKey: ["settings"], + queryKey: ["settings", organizationId], queryFn: getSettingsQueryFn, // Only retry if the error is not a 404 because we // would want to show the modal immediately if the @@ -38,7 +44,10 @@ export const useSettings = () => { refetchOnWindowFocus: false, staleTime: 1000 * 60 * 5, // 5 minutes gcTime: 1000 * 60 * 15, // 15 minutes - enabled: !isOnIntermediatePage && !!userIsAuthenticated, + enabled: + !isOnIntermediatePage && + !!userIsAuthenticated && + (isOss || !!organizationId), meta: { disableToast: true, }, diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index a66552ff3c..fe6f248cfa 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -1084,6 +1084,14 @@ export enum I18nKey { CONVERSATION$NO_HISTORY_AVAILABLE = "CONVERSATION$NO_HISTORY_AVAILABLE", CONVERSATION$SHARED_CONVERSATION = "CONVERSATION$SHARED_CONVERSATION", CONVERSATION$LINK_COPIED = "CONVERSATION$LINK_COPIED", + ONBOARDING$STEP1_TITLE = "ONBOARDING$STEP1_TITLE", + ONBOARDING$STEP1_SUBTITLE = "ONBOARDING$STEP1_SUBTITLE", + ONBOARDING$SOFTWARE_ENGINEER = "ONBOARDING$SOFTWARE_ENGINEER", + ONBOARDING$ENGINEERING_MANAGER = "ONBOARDING$ENGINEERING_MANAGER", + ONBOARDING$CTO_FOUNDER = "ONBOARDING$CTO_FOUNDER", + ONBOARDING$PRODUCT_OPERATIONS = "ONBOARDING$PRODUCT_OPERATIONS", + ONBOARDING$STUDENT_HOBBYIST = "ONBOARDING$STUDENT_HOBBYIST", + ONBOARDING$OTHER = "ONBOARDING$OTHER", HOOKS_MODAL$TITLE = "HOOKS_MODAL$TITLE", HOOKS_MODAL$WARNING = "HOOKS_MODAL$WARNING", HOOKS_MODAL$MATCHER = "HOOKS_MODAL$MATCHER", diff --git a/frontend/src/routes/secrets-settings.tsx b/frontend/src/routes/secrets-settings.tsx index ec6a9c3a28..48cda5ecbb 100644 --- a/frontend/src/routes/secrets-settings.tsx +++ b/frontend/src/routes/secrets-settings.tsx @@ -13,12 +13,14 @@ import { ConfirmationModal } from "#/components/shared/modals/confirmation-modal import { GetSecretsResponse } from "#/api/secrets-service.types"; import { I18nKey } from "#/i18n/declaration"; import { createPermissionGuard } from "#/utils/org/permission-guard"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; export const clientLoader = createPermissionGuard("manage_secrets"); function SecretsSettingsScreen() { const queryClient = useQueryClient(); const { t } = useTranslation(); + const { organizationId } = useSelectedOrganizationId(); const { data: secrets, isLoading: isLoadingSecrets } = useGetSecrets(); const { mutate: deleteSecret } = useDeleteSecret(); @@ -34,7 +36,7 @@ function SecretsSettingsScreen() { const deleteSecretOptimistically = (secret: string) => { queryClient.setQueryData( - ["secrets"], + ["secrets", organizationId], (oldSecrets) => { if (!oldSecrets) return []; return oldSecrets.filter((s) => s.name !== secret); @@ -43,7 +45,7 @@ function SecretsSettingsScreen() { }; const revertOptimisticUpdate = () => { - queryClient.invalidateQueries({ queryKey: ["secrets"] }); + queryClient.invalidateQueries({ queryKey: ["secrets", organizationId] }); }; const handleDeleteSecret = (secret: string) => { diff --git a/frontend/src/routes/settings.tsx b/frontend/src/routes/settings.tsx index cc1c3563c6..9082bea730 100644 --- a/frontend/src/routes/settings.tsx +++ b/frontend/src/routes/settings.tsx @@ -30,7 +30,6 @@ const SAAS_ONLY_PATHS = [ export const clientLoader = async ({ request }: Route.ClientLoaderArgs) => { const url = new URL(request.url); const { pathname } = url; - console.log("clientLoader", { pathname }); // Step 1: Get config first (needed for all checks, no user data required) let config = queryClient.getQueryData(["web-client-config"]); @@ -51,7 +50,6 @@ export const clientLoader = async ({ request }: Route.ClientLoaderArgs) => { // This handles hide_llm_settings, hide_users_page, hide_billing_page, hide_integrations_page if (isSettingsPageHidden(pathname, featureFlags)) { const fallbackPath = getFirstAvailablePath(isSaas, featureFlags); - console.log("fallbackPath", fallbackPath); if (fallbackPath && fallbackPath !== pathname) { return redirect(fallbackPath); } diff --git a/frontend/src/routes/user-settings.tsx b/frontend/src/routes/user-settings.tsx index 3e40d104a8..6fd4372ecf 100644 --- a/frontend/src/routes/user-settings.tsx +++ b/frontend/src/routes/user-settings.tsx @@ -5,6 +5,7 @@ import { useSettings } from "#/hooks/query/use-settings"; import { openHands } from "#/api/open-hands-axios"; import { displaySuccessToast } from "#/utils/custom-toast-handlers"; import { useEmailVerification } from "#/hooks/use-email-verification"; +import { useSelectedOrganizationId } from "#/context/use-selected-organization"; // Email validation regex pattern const EMAIL_REGEX = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$/; @@ -113,6 +114,7 @@ function VerificationAlert() { function UserSettingsScreen() { const { t } = useTranslation(); const { data: settings, isLoading, refetch } = useSettings(); + const { organizationId } = useSelectedOrganizationId(); const [email, setEmail] = useState(""); const [originalEmail, setOriginalEmail] = useState(""); const [isSaving, setIsSaving] = useState(false); @@ -144,7 +146,9 @@ function UserSettingsScreen() { // Display toast notification instead of setting state displaySuccessToast(t("SETTINGS$EMAIL_VERIFIED_SUCCESSFULLY")); setTimeout(() => { - queryClient.invalidateQueries({ queryKey: ["settings"] }); + queryClient.invalidateQueries({ + queryKey: ["settings", organizationId], + }); }, 2000); } @@ -162,7 +166,7 @@ function UserSettingsScreen() { pollingIntervalRef.current = null; } }; - }, [settings?.email_verified, refetch, queryClient, t]); + }, [settings?.email_verified, refetch, queryClient, t, organizationId]); const handleEmailChange = (e: React.ChangeEvent) => { const newEmail = e.target.value; @@ -178,7 +182,9 @@ function UserSettingsScreen() { setOriginalEmail(email); // Display toast notification instead of setting state displaySuccessToast(t("SETTINGS$EMAIL_SAVED_SUCCESSFULLY")); - queryClient.invalidateQueries({ queryKey: ["settings"] }); + queryClient.invalidateQueries({ + queryKey: ["settings", organizationId], + }); } catch (error) { // eslint-disable-next-line no-console console.error(t("SETTINGS$FAILED_TO_SAVE_EMAIL"), error); diff --git a/frontend/vitest.setup.ts b/frontend/vitest.setup.ts index c43fa03553..b96506baba 100644 --- a/frontend/vitest.setup.ts +++ b/frontend/vitest.setup.ts @@ -36,6 +36,15 @@ vi.mock("#/hooks/use-is-on-intermediate-page", () => ({ useIsOnIntermediatePage: () => false, })); +// Mock useRevalidator from react-router to allow direct store manipulation +// in tests instead of mocking useSelectedOrganizationId hook +vi.mock("react-router", async (importOriginal) => ({ + ...(await importOriginal()), + useRevalidator: () => ({ + revalidate: vi.fn(), + }), +})); + // Import the Zustand mock to enable automatic store resets vi.mock("zustand"); From e02dbb89749c6a4874ea6ca5cdb78037699f7ecd Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:09:37 +0700 Subject: [PATCH 11/28] fix(backend): validate API key org_id during authorization to prevent cross-org access (org project) (#13468) --- enterprise/server/auth/authorization.py | 38 ++- enterprise/server/auth/saas_user_auth.py | 11 +- enterprise/storage/api_key_store.py | 11 +- .../storage/saas_conversation_validator.py | 4 +- enterprise/tests/unit/test_api_key_store.py | 48 +++- enterprise/tests/unit/test_authorization.py | 271 +++++++++++++++++- enterprise/tests/unit/test_saas_user_auth.py | 3 +- 7 files changed, 354 insertions(+), 32 deletions(-) diff --git a/enterprise/server/auth/authorization.py b/enterprise/server/auth/authorization.py index c8d72021c6..203f74f112 100644 --- a/enterprise/server/auth/authorization.py +++ b/enterprise/server/auth/authorization.py @@ -35,7 +35,7 @@ Usage: from enum import Enum from uuid import UUID -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, Request, status from storage.org_member_store import OrgMemberStore from storage.role import Role from storage.role_store import RoleStore @@ -214,6 +214,19 @@ def has_permission(user_role: Role, permission: Permission) -> bool: return permission in permissions +async def get_api_key_org_id_from_request(request: Request) -> UUID | None: + """Get the org_id bound to the API key used for authentication. + + Returns None if: + - Not authenticated via API key (cookie auth) + - API key is a legacy key without org binding + """ + user_auth = getattr(request.state, 'user_auth', None) + if user_auth and hasattr(user_auth, 'get_api_key_org_id'): + return user_auth.get_api_key_org_id() + return None + + def require_permission(permission: Permission): """ Factory function that creates a dependency to require a specific permission. @@ -221,8 +234,9 @@ def require_permission(permission: Permission): This creates a FastAPI dependency that: 1. Extracts org_id from the path parameter 2. Gets the authenticated user_id - 3. Checks if the user has the required permission in the organization - 4. Returns the user_id if authorized, raises HTTPException otherwise + 3. Validates API key org binding (if using API key auth) + 4. Checks if the user has the required permission in the organization + 5. Returns the user_id if authorized, raises HTTPException otherwise Usage: @router.get('/{org_id}/settings') @@ -240,6 +254,7 @@ def require_permission(permission: Permission): """ async def permission_checker( + request: Request, org_id: UUID | None = None, user_id: str | None = Depends(get_user_id), ) -> str: @@ -249,6 +264,23 @@ def require_permission(permission: Permission): detail='User not authenticated', ) + # Validate API key organization binding + api_key_org_id = await get_api_key_org_id_from_request(request) + if api_key_org_id is not None and org_id is not None: + if api_key_org_id != org_id: + logger.warning( + 'API key organization mismatch', + extra={ + 'user_id': user_id, + 'api_key_org_id': str(api_key_org_id), + 'target_org_id': str(org_id), + }, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail='API key is not authorized for this organization', + ) + user_role = await get_user_org_role(user_id, org_id) if not user_role: diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index 6c8aefea7a..15f50cc40b 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -61,10 +61,19 @@ class SaasUserAuth(UserAuth): accepted_tos: bool | None = None auth_type: AuthType = AuthType.COOKIE # API key context fields - populated when authenticated via API key - api_key_org_id: UUID | None = None + api_key_org_id: UUID | None = None # Org bound to the API key used for auth api_key_id: int | None = None api_key_name: str | None = None + def get_api_key_org_id(self) -> UUID | None: + """Get the organization ID bound to the API key used for authentication. + + Returns: + The org_id if authenticated via API key with org binding, None otherwise + (cookie auth or legacy API keys without org binding). + """ + return self.api_key_org_id + async def get_user_id(self) -> str | None: return self.user_id diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index ecbb375592..30e5b242e8 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -16,10 +16,10 @@ from openhands.core.logger import openhands_logger as logger @dataclass class ApiKeyValidationResult: - """Result of API key validation containing user and org context.""" + """Result of API key validation containing user and organization info.""" user_id: str - org_id: UUID | None + org_id: UUID | None # None for legacy API keys without org binding key_id: int key_name: str | None @@ -195,7 +195,12 @@ class ApiKeyStore: return api_key async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None: - """Validate an API key and return the associated user_id and org_id if valid.""" + """Validate an API key and return the associated user_id and org_id if valid. + + Returns: + ApiKeyValidationResult if the key is valid, None otherwise. + The org_id may be None for legacy API keys that weren't bound to an organization. + """ now = datetime.now(UTC) async with a_session_maker() as session: diff --git a/enterprise/storage/saas_conversation_validator.py b/enterprise/storage/saas_conversation_validator.py index 51a5302dfc..6493fdd602 100644 --- a/enterprise/storage/saas_conversation_validator.py +++ b/enterprise/storage/saas_conversation_validator.py @@ -15,13 +15,13 @@ class SaasConversationValidator(ConversationValidator): async def _validate_api_key(self, api_key: str) -> str | None: """ - Validate an API key and return the user_id and github_user_id if valid. + Validate an API key and return the user_id if valid. Args: api_key: The API key to validate Returns: - A tuple of (user_id, github_user_id) if the API key is valid, None otherwise + The user_id if the API key is valid, None otherwise """ try: token_manager = TokenManager() diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index baffe5893c..c57f63f2ae 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from sqlalchemy import select from storage.api_key import ApiKey -from storage.api_key_store import ApiKeyStore +from storage.api_key_store import ApiKeyStore, ApiKeyValidationResult @pytest.fixture @@ -110,8 +110,8 @@ async def test_create_api_key( @pytest.mark.asyncio async def test_validate_api_key_valid(api_key_store, async_session_maker): - """Test validating a valid API key.""" - # Setup - create an API key in the database + """Test validating a valid API key returns user_id and org_id.""" + # Arrange user_id = str(uuid.uuid4()) org_id = uuid.uuid4() api_key_value = 'test-api-key' @@ -128,11 +128,12 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker): await session.commit() key_id = key_record.id - # Execute - patch a_session_maker to use test's async session maker + # Act with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - result is now ApiKeyValidationResult + # Assert + assert isinstance(result, ApiKeyValidationResult) assert result is not None assert result.user_id == user_id assert result.org_id == org_id @@ -202,7 +203,7 @@ async def test_validate_api_key_valid_timezone_naive( api_key_store, async_session_maker ): """Test validating a valid API key with timezone-naive datetime from database.""" - # Setup - create a valid API key with timezone-naive datetime (future date) + # Arrange user_id = str(uuid.uuid4()) org_id = uuid.uuid4() api_key_value = 'test-valid-naive-key' @@ -219,13 +220,44 @@ async def test_validate_api_key_valid_timezone_naive( session.add(key_record) await session.commit() - # Execute - patch a_session_maker to use test's async session maker + # Act with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - result is now ApiKeyValidationResult + # Assert + assert isinstance(result, ApiKeyValidationResult) + assert result.user_id == user_id + assert result.org_id == org_id + + +@pytest.mark.asyncio +async def test_validate_api_key_legacy_without_org_id( + api_key_store, async_session_maker +): + """Test validating a legacy API key without org_id returns None for org_id.""" + # Arrange + user_id = str(uuid.uuid4()) + api_key_value = 'test-legacy-key-no-org' + + async with async_session_maker() as session: + key_record = ApiKey( + key=api_key_value, + user_id=user_id, + org_id=None, # Legacy key without org binding + name='Legacy Key', + ) + session.add(key_record) + await session.commit() + + # Act + with patch('storage.api_key_store.a_session_maker', async_session_maker): + result = await api_key_store.validate_api_key(api_key_value) + + # Assert + assert isinstance(result, ApiKeyValidationResult) assert result is not None assert result.user_id == user_id + assert result.org_id is None @pytest.mark.asyncio diff --git a/enterprise/tests/unit/test_authorization.py b/enterprise/tests/unit/test_authorization.py index c751e6454a..a4051b4824 100644 --- a/enterprise/tests/unit/test_authorization.py +++ b/enterprise/tests/unit/test_authorization.py @@ -13,6 +13,7 @@ from server.auth.authorization import ( ROLE_PERMISSIONS, Permission, RoleName, + get_api_key_org_id_from_request, get_role_permissions, get_user_org_role, has_permission, @@ -444,6 +445,15 @@ class TestGetUserOrgRole: # ============================================================================= +def _create_mock_request(api_key_org_id=None): + """Helper to create a mock request with optional api_key_org_id.""" + mock_request = MagicMock() + mock_user_auth = MagicMock() + mock_user_auth.get_api_key_org_id.return_value = api_key_org_id + mock_request.state.user_auth = mock_user_auth + return mock_request + + class TestRequirePermission: """Tests for require_permission dependency factory.""" @@ -456,6 +466,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -465,7 +476,9 @@ class TestRequirePermission: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -476,10 +489,11 @@ class TestRequirePermission: THEN: 401 Unauthorized is raised """ org_id = uuid4() + mock_request = _create_mock_request() permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=None) + await permission_checker(request=mock_request, org_id=org_id, user_id=None) assert exc_info.value.status_code == 401 assert 'not authenticated' in exc_info.value.detail.lower() @@ -493,6 +507,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() with patch( 'server.auth.authorization.get_user_org_role', @@ -500,7 +515,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 assert 'not a member' in exc_info.value.detail.lower() @@ -514,6 +531,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -524,7 +542,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 assert 'delete_organization' in exc_info.value.detail.lower() @@ -538,6 +558,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'owner' @@ -547,7 +568,9 @@ class TestRequirePermission: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -559,6 +582,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -569,7 +593,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 @@ -582,6 +608,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -595,7 +622,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) with pytest.raises(HTTPException): - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) mock_logger.warning.assert_called() call_args = mock_logger.warning.call_args @@ -611,6 +640,7 @@ class TestRequirePermission: THEN: User ID is returned """ user_id = str(uuid4()) + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -620,7 +650,9 @@ class TestRequirePermission: AsyncMock(return_value=mock_role), ) as mock_get_role: permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) - result = await permission_checker(org_id=None, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=None, user_id=user_id + ) assert result == user_id mock_get_role.assert_called_once_with(user_id, None) @@ -632,6 +664,7 @@ class TestRequirePermission: THEN: HTTPException with 403 status is raised """ user_id = str(uuid4()) + mock_request = _create_mock_request() with patch( 'server.auth.authorization.get_user_org_role', @@ -639,7 +672,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=None, user_id=user_id) + await permission_checker( + request=mock_request, org_id=None, user_id=user_id + ) assert exc_info.value.status_code == 403 assert 'not a member' in exc_info.value.detail @@ -662,6 +697,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -671,7 +707,9 @@ class TestPermissionScenarios: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.MANAGE_SECRETS) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -683,6 +721,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -695,7 +734,9 @@ class TestPermissionScenarios: Permission.INVITE_USER_TO_ORGANIZATION ) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 @@ -708,6 +749,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -719,7 +761,9 @@ class TestPermissionScenarios: permission_checker = require_permission( Permission.INVITE_USER_TO_ORGANIZATION ) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -731,6 +775,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -741,7 +786,9 @@ class TestPermissionScenarios: ): permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 @@ -754,6 +801,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'owner' @@ -763,5 +811,200 @@ class TestPermissionScenarios: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id + + +# ============================================================================= +# Tests for API key organization validation +# ============================================================================= + + +class TestApiKeyOrgValidation: + """Tests for API key organization binding validation in require_permission.""" + + @pytest.mark.asyncio + async def test_allows_access_when_api_key_org_matches_target_org(self): + """ + GIVEN: API key with org_id that matches the target org_id in the request + WHEN: Permission checker is called + THEN: User ID is returned (access allowed) + """ + # Arrange + user_id = str(uuid4()) + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=org_id) + + mock_role = MagicMock() + mock_role.name = 'admin' + + # Act & Assert + with patch( + 'server.auth.authorization.get_user_org_role', + AsyncMock(return_value=mock_role), + ): + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) + assert result == user_id + + @pytest.mark.asyncio + async def test_denies_access_when_api_key_org_mismatches_target_org(self): + """ + GIVEN: API key created for Org A, but user tries to access Org B + WHEN: Permission checker is called + THEN: 403 Forbidden is raised with org mismatch message + """ + # Arrange + user_id = str(uuid4()) + api_key_org_id = uuid4() # Org A - where API key was created + target_org_id = uuid4() # Org B - where user is trying to access + mock_request = _create_mock_request(api_key_org_id=api_key_org_id) + + # Act & Assert + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + with pytest.raises(HTTPException) as exc_info: + await permission_checker( + request=mock_request, org_id=target_org_id, user_id=user_id + ) + + assert exc_info.value.status_code == 403 + assert ( + 'API key is not authorized for this organization' in exc_info.value.detail + ) + + @pytest.mark.asyncio + async def test_allows_access_for_legacy_api_key_without_org_binding(self): + """ + GIVEN: Legacy API key without org_id binding (org_id is None) + WHEN: Permission checker is called + THEN: Falls through to normal permission check (backward compatible) + """ + # Arrange + user_id = str(uuid4()) + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=None) + + mock_role = MagicMock() + mock_role.name = 'admin' + + # Act & Assert + with patch( + 'server.auth.authorization.get_user_org_role', + AsyncMock(return_value=mock_role), + ): + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) + assert result == user_id + + @pytest.mark.asyncio + async def test_allows_access_for_cookie_auth_without_api_key_org_id(self): + """ + GIVEN: Cookie-based authentication (no api_key_org_id in user_auth) + WHEN: Permission checker is called + THEN: Falls through to normal permission check + """ + # Arrange + user_id = str(uuid4()) + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=None) + + mock_role = MagicMock() + mock_role.name = 'admin' + + # Act & Assert + with patch( + 'server.auth.authorization.get_user_org_role', + AsyncMock(return_value=mock_role), + ): + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) + assert result == user_id + + @pytest.mark.asyncio + async def test_logs_warning_on_api_key_org_mismatch(self): + """ + GIVEN: API key org_id doesn't match target org_id + WHEN: Permission checker is called + THEN: Warning is logged with org mismatch details + """ + # Arrange + user_id = str(uuid4()) + api_key_org_id = uuid4() + target_org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=api_key_org_id) + + # Act & Assert + with patch('server.auth.authorization.logger') as mock_logger: + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + with pytest.raises(HTTPException): + await permission_checker( + request=mock_request, org_id=target_org_id, user_id=user_id + ) + + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args + assert call_args[1]['extra']['user_id'] == user_id + assert call_args[1]['extra']['api_key_org_id'] == str(api_key_org_id) + assert call_args[1]['extra']['target_org_id'] == str(target_org_id) + + +class TestGetApiKeyOrgIdFromRequest: + """Tests for get_api_key_org_id_from_request helper function.""" + + @pytest.mark.asyncio + async def test_returns_org_id_when_user_auth_has_api_key_org_id(self): + """ + GIVEN: Request with user_auth that has api_key_org_id + WHEN: get_api_key_org_id_from_request is called + THEN: Returns the api_key_org_id + """ + # Arrange + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=org_id) + + # Act + result = await get_api_key_org_id_from_request(mock_request) + + # Assert + assert result == org_id + + @pytest.mark.asyncio + async def test_returns_none_when_user_auth_has_no_api_key_org_id(self): + """ + GIVEN: Request with user_auth that has no api_key_org_id (cookie auth) + WHEN: get_api_key_org_id_from_request is called + THEN: Returns None + """ + # Arrange + mock_request = _create_mock_request(api_key_org_id=None) + + # Act + result = await get_api_key_org_id_from_request(mock_request) + + # Assert + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_no_user_auth_in_request(self): + """ + GIVEN: Request without user_auth in state + WHEN: get_api_key_org_id_from_request is called + THEN: Returns None + """ + # Arrange + mock_request = MagicMock() + mock_request.state.user_auth = None + + # Act + result = await get_api_key_org_id_from_request(mock_request) + + # Assert + assert result is None diff --git a/enterprise/tests/unit/test_saas_user_auth.py b/enterprise/tests/unit/test_saas_user_auth.py index 2fb1b68445..726702f310 100644 --- a/enterprise/tests/unit/test_saas_user_auth.py +++ b/enterprise/tests/unit/test_saas_user_auth.py @@ -459,7 +459,8 @@ async def test_get_instance_no_auth(mock_request): @pytest.mark.asyncio async def test_saas_user_auth_from_bearer_success(): - """Test successful authentication from bearer token.""" + """Test successful authentication from bearer token sets user_id and api_key_org_id.""" + # Arrange mock_request = MagicMock() mock_request.headers = {'Authorization': 'Bearer test_api_key'} From 3a9f00aa3714cd67398aa8ac7ccfe1b966134073 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Thu, 19 Mar 2026 14:46:56 +0100 Subject: [PATCH 12/28] Keep VSCode accessible when agent errors (#13492) Co-authored-by: openhands --- .../hooks/use-runtime-is-ready.test.tsx | 64 ++++++++++++++++++ frontend/__tests__/routes/vscode-tab.test.tsx | 65 +++++++++++++++++++ .../vscode-tooltip-content.tsx | 5 +- .../src/hooks/query/use-unified-vscode-url.ts | 2 +- frontend/src/hooks/use-runtime-is-ready.ts | 20 ++++-- frontend/src/routes/vscode-tab.tsx | 8 +-- frontend/src/types/agent-state.tsx | 5 +- 7 files changed, 156 insertions(+), 13 deletions(-) create mode 100644 frontend/__tests__/hooks/use-runtime-is-ready.test.tsx create mode 100644 frontend/__tests__/routes/vscode-tab.test.tsx diff --git a/frontend/__tests__/hooks/use-runtime-is-ready.test.tsx b/frontend/__tests__/hooks/use-runtime-is-ready.test.tsx new file mode 100644 index 0000000000..86bc8d82f7 --- /dev/null +++ b/frontend/__tests__/hooks/use-runtime-is-ready.test.tsx @@ -0,0 +1,64 @@ +import { renderHook } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Conversation } from "#/api/open-hands.types"; +import { useRuntimeIsReady } from "#/hooks/use-runtime-is-ready"; +import { useAgentState } from "#/hooks/use-agent-state"; +import { useActiveConversation } from "#/hooks/query/use-active-conversation"; +import { AgentState } from "#/types/agent-state"; + +vi.mock("#/hooks/use-agent-state"); +vi.mock("#/hooks/query/use-active-conversation"); + +function asMockReturnValue(value: Partial): T { + return value as T; +} + +function makeConversation(): Conversation { + return { + conversation_id: "conv-123", + title: "Test Conversation", + selected_repository: null, + selected_branch: null, + git_provider: null, + last_updated_at: new Date().toISOString(), + created_at: new Date().toISOString(), + status: "RUNNING", + runtime_status: null, + url: null, + session_api_key: null, + }; +} + +describe("useRuntimeIsReady", () => { + beforeEach(() => { + vi.clearAllMocks(); + + vi.mocked(useActiveConversation).mockReturnValue( + asMockReturnValue>({ + data: makeConversation(), + }), + ); + }); + + it("treats agent errors as not ready by default", () => { + vi.mocked(useAgentState).mockReturnValue({ + curAgentState: AgentState.ERROR, + }); + + const { result } = renderHook(() => useRuntimeIsReady()); + + expect(result.current).toBe(false); + }); + + it("allows runtime-backed tabs to stay ready when the agent errors", () => { + vi.mocked(useAgentState).mockReturnValue({ + curAgentState: AgentState.ERROR, + }); + + const { result } = renderHook(() => + useRuntimeIsReady({ allowAgentError: true }), + ); + + expect(result.current).toBe(true); + }); +}); diff --git a/frontend/__tests__/routes/vscode-tab.test.tsx b/frontend/__tests__/routes/vscode-tab.test.tsx new file mode 100644 index 0000000000..8c84678603 --- /dev/null +++ b/frontend/__tests__/routes/vscode-tab.test.tsx @@ -0,0 +1,65 @@ +import { screen } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { renderWithProviders } from "test-utils"; +import VSCodeTab from "#/routes/vscode-tab"; +import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url"; +import { useAgentState } from "#/hooks/use-agent-state"; +import { AgentState } from "#/types/agent-state"; + +vi.mock("#/hooks/query/use-unified-vscode-url"); +vi.mock("#/hooks/use-agent-state"); +vi.mock("#/utils/feature-flags", () => ({ + VSCODE_IN_NEW_TAB: () => false, +})); + +function mockVSCodeUrlHook( + value: Partial>, +) { + vi.mocked(useUnifiedVSCodeUrl).mockReturnValue({ + data: { url: "http://localhost:3000/vscode", error: null }, + error: null, + isLoading: false, + isError: false, + isSuccess: true, + status: "success", + refetch: vi.fn(), + ...value, + } as ReturnType); +} + +describe("VSCodeTab", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("keeps VSCode accessible when the agent is in an error state", () => { + vi.mocked(useAgentState).mockReturnValue({ + curAgentState: AgentState.ERROR, + }); + mockVSCodeUrlHook({}); + + renderWithProviders(); + + expect( + screen.queryByText("DIFF_VIEWER$WAITING_FOR_RUNTIME"), + ).not.toBeInTheDocument(); + expect(screen.getByTitle("VSCODE$TITLE")).toHaveAttribute( + "src", + "http://localhost:3000/vscode", + ); + }); + + it("still waits while the runtime is starting", () => { + vi.mocked(useAgentState).mockReturnValue({ + curAgentState: AgentState.LOADING, + }); + mockVSCodeUrlHook({}); + + renderWithProviders(); + + expect( + screen.getByText("DIFF_VIEWER$WAITING_FOR_RUNTIME"), + ).toBeInTheDocument(); + expect(screen.queryByTitle("VSCODE$TITLE")).not.toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/features/conversation/conversation-tabs/vscode-tooltip-content.tsx b/frontend/src/components/features/conversation/conversation-tabs/vscode-tooltip-content.tsx index 07509ab19d..08e7879fad 100644 --- a/frontend/src/components/features/conversation/conversation-tabs/vscode-tooltip-content.tsx +++ b/frontend/src/components/features/conversation/conversation-tabs/vscode-tooltip-content.tsx @@ -1,14 +1,15 @@ import { FaExternalLinkAlt } from "react-icons/fa"; import { useTranslation } from "react-i18next"; import { I18nKey } from "#/i18n/declaration"; -import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state"; import { useAgentState } from "#/hooks/use-agent-state"; import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url"; +import { RUNTIME_STARTING_STATES } from "#/types/agent-state"; export function VSCodeTooltipContent() { const { curAgentState } = useAgentState(); const { t } = useTranslation(); const { data, refetch } = useUnifiedVSCodeUrl(); + const isRuntimeStarting = RUNTIME_STARTING_STATES.includes(curAgentState); const handleVSCodeClick = async (e: React.MouseEvent) => { e.preventDefault(); @@ -29,7 +30,7 @@ export function VSCodeTooltipContent() { return (
{t(I18nKey.COMMON$CODE)} - {!RUNTIME_INACTIVE_STATES.includes(curAgentState) ? ( + {!isRuntimeStarting ? ( { const { t } = useTranslation(); const { conversationId } = useConversationId(); const { data: conversation } = useActiveConversation(); - const runtimeIsReady = useRuntimeIsReady(); + const runtimeIsReady = useRuntimeIsReady({ allowAgentError: true }); const isV1Conversation = conversation?.conversation_version === "V1"; diff --git a/frontend/src/hooks/use-runtime-is-ready.ts b/frontend/src/hooks/use-runtime-is-ready.ts index 914b3624c4..e09af98872 100644 --- a/frontend/src/hooks/use-runtime-is-ready.ts +++ b/frontend/src/hooks/use-runtime-is-ready.ts @@ -1,18 +1,30 @@ -import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state"; -import { useActiveConversation } from "./query/use-active-conversation"; import { useAgentState } from "#/hooks/use-agent-state"; +import { + RUNTIME_INACTIVE_STATES, + RUNTIME_STARTING_STATES, +} from "#/types/agent-state"; +import { useActiveConversation } from "./query/use-active-conversation"; + +interface UseRuntimeIsReadyOptions { + allowAgentError?: boolean; +} /** * Hook to determine if the runtime is ready for operations * * @returns boolean indicating if the runtime is ready */ -export const useRuntimeIsReady = (): boolean => { +export const useRuntimeIsReady = ({ + allowAgentError = false, +}: UseRuntimeIsReadyOptions = {}): boolean => { const { data: conversation } = useActiveConversation(); const { curAgentState } = useAgentState(); + const inactiveStates = allowAgentError + ? RUNTIME_STARTING_STATES + : RUNTIME_INACTIVE_STATES; return ( conversation?.status === "RUNNING" && - !RUNTIME_INACTIVE_STATES.includes(curAgentState) + !inactiveStates.includes(curAgentState) ); }; diff --git a/frontend/src/routes/vscode-tab.tsx b/frontend/src/routes/vscode-tab.tsx index e1bb2e8fe4..fe60a52dac 100644 --- a/frontend/src/routes/vscode-tab.tsx +++ b/frontend/src/routes/vscode-tab.tsx @@ -1,17 +1,17 @@ import React, { useState, useEffect } from "react"; import { useTranslation } from "react-i18next"; import { I18nKey } from "#/i18n/declaration"; -import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state"; import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url"; +import { useAgentState } from "#/hooks/use-agent-state"; +import { RUNTIME_STARTING_STATES } from "#/types/agent-state"; import { VSCODE_IN_NEW_TAB } from "#/utils/feature-flags"; import { WaitingForRuntimeMessage } from "#/components/features/chat/waiting-for-runtime-message"; -import { useAgentState } from "#/hooks/use-agent-state"; function VSCodeTab() { const { t } = useTranslation(); const { data, isLoading, error } = useUnifiedVSCodeUrl(); const { curAgentState } = useAgentState(); - const isRuntimeInactive = RUNTIME_INACTIVE_STATES.includes(curAgentState); + const isRuntimeStarting = RUNTIME_STARTING_STATES.includes(curAgentState); const iframeRef = React.useRef(null); const [isCrossProtocol, setIsCrossProtocol] = useState(false); const [iframeError, setIframeError] = useState(null); @@ -39,7 +39,7 @@ function VSCodeTab() { } }; - if (isRuntimeInactive) { + if (isRuntimeStarting) { return ; } diff --git a/frontend/src/types/agent-state.tsx b/frontend/src/types/agent-state.tsx index 9309ef5e41..ab05ea89df 100644 --- a/frontend/src/types/agent-state.tsx +++ b/frontend/src/types/agent-state.tsx @@ -14,9 +14,10 @@ export enum AgentState { USER_REJECTED = "user_rejected", } +export const RUNTIME_STARTING_STATES = [AgentState.INIT, AgentState.LOADING]; + export const RUNTIME_INACTIVE_STATES = [ - AgentState.INIT, - AgentState.LOADING, + ...RUNTIME_STARTING_STATES, // Removed AgentState.STOPPED to allow tabs to remain visible when agent is stopped AgentState.ERROR, ]; From 0ec962e96be6281a4603e851536bcd7953613f67 Mon Sep 17 00:00:00 2001 From: MkDev11 <94194147+MkDev11@users.noreply.github.com> Date: Thu, 19 Mar 2026 07:13:58 -0700 Subject: [PATCH 13/28] feat: add /clear endpoint for V1 conversations (#12786) Co-authored-by: mkdev11 Co-authored-by: openhands Co-authored-by: tofarr Co-authored-by: hieptl --- .../components/interactive-chat-box.test.tsx | 30 ++ .../use-new-conversation-command.test.tsx | 299 ++++++++++++++++++ .../__tests__/hooks/use-websocket.test.ts | 4 +- .../v1-conversation-service.api.ts | 4 + .../features/chat/chat-interface.tsx | 32 +- .../chat/components/chat-input-container.tsx | 3 + .../chat/components/chat-input-field.tsx | 10 +- .../chat/components/chat-input-row.tsx | 3 + .../features/chat/custom-chat-input.tsx | 3 + .../features/chat/interactive-chat-box.tsx | 8 +- .../mutation/use-new-conversation-command.ts | 115 +++++++ .../query/use-unified-get-git-changes.ts | 1 + frontend/src/i18n/declaration.ts | 8 + frontend/src/i18n/translation.json | 136 ++++++++ frontend/src/utils/websocket-url.ts | 13 + .../app_conversation_info_service.py | 8 + .../app_conversation_service.py | 26 +- .../live_status_app_conversation_service.py | 13 +- .../sql_app_conversation_info_service.py | 8 + openhands/app_server/config.py | 21 ++ .../sandbox/docker_sandbox_service.py | 17 +- openhands/server/middleware.py | 21 +- .../server/routes/manage_conversations.py | 18 +- .../test_sql_app_conversation_info_service.py | 48 +++ .../server/data_models/test_conversation.py | 12 +- tests/unit/server/test_middleware.py | 79 +++-- 26 files changed, 884 insertions(+), 56 deletions(-) create mode 100644 frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx create mode 100644 frontend/src/hooks/mutation/use-new-conversation-command.ts diff --git a/frontend/__tests__/components/interactive-chat-box.test.tsx b/frontend/__tests__/components/interactive-chat-box.test.tsx index ecb6623806..884217facb 100644 --- a/frontend/__tests__/components/interactive-chat-box.test.tsx +++ b/frontend/__tests__/components/interactive-chat-box.test.tsx @@ -216,6 +216,36 @@ describe("InteractiveChatBox", () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); + it("should lock the text input field when disabled prop is true (isNewConversationPending)", () => { + mockStores(AgentState.INIT); + + renderInteractiveChatBox({ + onSubmit: onSubmitMock, + disabled: true, + }); + + const chatInput = screen.getByTestId("chat-input"); + // When disabled=true, the text field should not be editable + expect(chatInput).toHaveAttribute("contenteditable", "false"); + // Should show visual disabled state + expect(chatInput.className).toContain("cursor-not-allowed"); + expect(chatInput.className).toContain("opacity-50"); + }); + + it("should keep the text input field editable when disabled prop is false", () => { + mockStores(AgentState.INIT); + + renderInteractiveChatBox({ + onSubmit: onSubmitMock, + disabled: false, + }); + + const chatInput = screen.getByTestId("chat-input"); + expect(chatInput).toHaveAttribute("contenteditable", "true"); + expect(chatInput.className).not.toContain("cursor-not-allowed"); + expect(chatInput.className).not.toContain("opacity-50"); + }); + it("should handle image upload and message submission correctly", async () => { const user = userEvent.setup(); const onSubmit = vi.fn(); diff --git a/frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx b/frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx new file mode 100644 index 0000000000..07f110ff17 --- /dev/null +++ b/frontend/__tests__/hooks/mutation/use-new-conversation-command.test.tsx @@ -0,0 +1,299 @@ +import { renderHook, waitFor } from "@testing-library/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { describe, expect, it, vi, beforeEach } from "vitest"; +import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api"; +import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command"; + +const mockNavigate = vi.fn(); + +vi.mock("react-router", () => ({ + useNavigate: () => mockNavigate, + useParams: () => ({ conversationId: "conv-123" }), +})); + +vi.mock("react-i18next", () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})); + +const { mockToast } = vi.hoisted(() => { + const mockToast = Object.assign(vi.fn(), { + loading: vi.fn(), + dismiss: vi.fn(), + }); + return { mockToast }; +}); + +vi.mock("react-hot-toast", () => ({ + default: mockToast, +})); + +vi.mock("#/utils/custom-toast-handlers", () => ({ + displaySuccessToast: vi.fn(), + displayErrorToast: vi.fn(), + TOAST_OPTIONS: { position: "top-right" }, +})); + +const mockConversation = { + conversation_id: "conv-123", + sandbox_id: "sandbox-456", + title: "Test Conversation", + selected_repository: null, + selected_branch: null, + git_provider: null, + last_updated_at: new Date().toISOString(), + created_at: new Date().toISOString(), + status: "RUNNING" as const, + runtime_status: null, + url: null, + session_api_key: null, + conversation_version: "V1" as const, +}; + +vi.mock("#/hooks/query/use-active-conversation", () => ({ + useActiveConversation: () => ({ + data: mockConversation, + }), +})); + +function makeStartTask(overrides: Record = {}) { + return { + id: "task-789", + created_by_user_id: null, + status: "READY", + detail: null, + app_conversation_id: "new-conv-999", + sandbox_id: "sandbox-456", + agent_server_url: "http://agent-server.local", + request: { + sandbox_id: null, + initial_message: null, + processors: [], + llm_model: null, + selected_repository: null, + selected_branch: null, + git_provider: null, + suggested_task: null, + title: null, + trigger: null, + pr_number: [], + parent_conversation_id: null, + agent_type: "default", + }, + created_at: new Date().toISOString(), + updated_at: new Date().toISOString(), + ...overrides, + }; +} + +describe("useNewConversationCommand", () => { + let queryClient: QueryClient; + + beforeEach(() => { + vi.clearAllMocks(); + queryClient = new QueryClient({ + defaultOptions: { mutations: { retry: false } }, + }); + // Mock batchGetAppConversations to return V1 data with llm_model + vi.spyOn( + V1ConversationService, + "batchGetAppConversations", + ).mockResolvedValue([ + { + id: "conv-123", + title: "Test Conversation", + sandbox_id: "sandbox-456", + sandbox_status: "RUNNING", + execution_status: "IDLE", + conversation_url: null, + session_api_key: null, + selected_repository: null, + selected_branch: null, + git_provider: null, + trigger: null, + pr_number: [], + llm_model: "gpt-4o", + metrics: null, + created_at: new Date().toISOString(), + updated_at: new Date().toISOString(), + sub_conversation_ids: [], + public: false, + } as never, + ]); + }); + + const wrapper = ({ children }: { children: React.ReactNode }) => ( + {children} + ); + + it("calls createConversation with sandbox_id and navigates on success", async () => { + const readyTask = makeStartTask(); + const createSpy = vi + .spyOn(V1ConversationService, "createConversation") + .mockResolvedValue(readyTask as never); + const getStartTaskSpy = vi + .spyOn(V1ConversationService, "getStartTask") + .mockResolvedValue(readyTask as never); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + expect(createSpy).toHaveBeenCalledWith( + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + "sandbox-456", + "gpt-4o", + ); + expect(getStartTaskSpy).toHaveBeenCalledWith("task-789"); + expect(mockNavigate).toHaveBeenCalledWith( + "/conversations/new-conv-999", + ); + }); + }); + + it("polls getStartTask until status is READY", async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }); + + const workingTask = makeStartTask({ + status: "WORKING", + app_conversation_id: null, + }); + const readyTask = makeStartTask({ status: "READY" }); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + workingTask as never, + ); + const getStartTaskSpy = vi + .spyOn(V1ConversationService, "getStartTask") + .mockResolvedValueOnce(workingTask as never) + .mockResolvedValueOnce(readyTask as never); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + const mutatePromise = result.current.mutateAsync(); + + await vi.advanceTimersByTimeAsync(2000); + await mutatePromise; + + await waitFor(() => { + expect(getStartTaskSpy).toHaveBeenCalledTimes(2); + expect(mockNavigate).toHaveBeenCalledWith( + "/conversations/new-conv-999", + ); + }); + + vi.useRealTimers(); + }); + + it("throws when task status is ERROR", async () => { + const errorTask = makeStartTask({ + status: "ERROR", + detail: "Sandbox crashed", + app_conversation_id: null, + }); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + errorTask as never, + ); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + errorTask as never, + ); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await expect(result.current.mutateAsync()).rejects.toThrow( + "Sandbox crashed", + ); + }); + + it("invalidates conversation list queries on success", async () => { + const readyTask = makeStartTask(); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + readyTask as never, + ); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + readyTask as never, + ); + + const invalidateSpy = vi.spyOn(queryClient, "invalidateQueries"); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + expect(invalidateSpy).toHaveBeenCalledWith({ + queryKey: ["user", "conversations"], + }); + expect(invalidateSpy).toHaveBeenCalledWith({ + queryKey: ["v1-batch-get-app-conversations"], + }); + }); + }); + + it("creates a standalone conversation (not a sub-conversation) so it appears in the list", async () => { + const readyTask = makeStartTask(); + const createSpy = vi + .spyOn(V1ConversationService, "createConversation") + .mockResolvedValue(readyTask as never); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + readyTask as never, + ); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + // parent_conversation_id should be undefined so the new conversation + // is NOT a sub-conversation and will appear in the conversation list. + expect(createSpy).toHaveBeenCalledWith( + undefined, // selectedRepository (null from mock) + undefined, // git_provider (null from mock) + undefined, // initialUserMsg + undefined, // selected_branch (null from mock) + undefined, // conversationInstructions + undefined, // suggestedTask + undefined, // trigger + undefined, // parent_conversation_id is NOT set + undefined, // agent_type + "sandbox-456", // sandbox_id IS set to reuse the sandbox + "gpt-4o", // llm_model IS inherited from the original conversation + ); + }); + }); + + it("shows a loading toast immediately and dismisses it on success", async () => { + const readyTask = makeStartTask(); + + vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue( + readyTask as never, + ); + vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue( + readyTask as never, + ); + + const { result } = renderHook(() => useNewConversationCommand(), { wrapper }); + + await result.current.mutateAsync(); + + await waitFor(() => { + expect(mockToast.loading).toHaveBeenCalledWith( + "CONVERSATION$CLEARING", + expect.objectContaining({ id: "clear-conversation" }), + ); + expect(mockToast.dismiss).toHaveBeenCalledWith("clear-conversation"); + }); + }); +}); diff --git a/frontend/__tests__/hooks/use-websocket.test.ts b/frontend/__tests__/hooks/use-websocket.test.ts index 7d42507a87..d00db6f856 100644 --- a/frontend/__tests__/hooks/use-websocket.test.ts +++ b/frontend/__tests__/hooks/use-websocket.test.ts @@ -205,7 +205,9 @@ describe("useWebSocket", () => { expect(result.current.isConnected).toBe(true); }); - expect(onCloseSpy).not.toHaveBeenCalled(); + // Reset spy after connection is established to ignore any spurious + // close events fired by the MSW mock during the handshake. + onCloseSpy.mockClear(); // Unmount to trigger close unmount(); diff --git a/frontend/src/api/conversation-service/v1-conversation-service.api.ts b/frontend/src/api/conversation-service/v1-conversation-service.api.ts index a0e99abe0f..bcdad50077 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.api.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.api.ts @@ -68,6 +68,8 @@ class V1ConversationService { trigger?: ConversationTrigger, parent_conversation_id?: string, agent_type?: "default" | "plan", + sandbox_id?: string, + llm_model?: string, ): Promise { const body: V1AppConversationStartRequest = { selected_repository: selectedRepository, @@ -78,6 +80,8 @@ class V1ConversationService { trigger, parent_conversation_id: parent_conversation_id || null, agent_type, + sandbox_id: sandbox_id || null, + llm_model: llm_model || null, }; // suggested_task implies the backend will construct the initial_message diff --git a/frontend/src/components/features/chat/chat-interface.tsx b/frontend/src/components/features/chat/chat-interface.tsx index 43218149ae..3b11a2fbc4 100644 --- a/frontend/src/components/features/chat/chat-interface.tsx +++ b/frontend/src/components/features/chat/chat-interface.tsx @@ -38,6 +38,8 @@ import { useTaskPolling } from "#/hooks/query/use-task-polling"; import { useConversationWebSocket } from "#/contexts/conversation-websocket-context"; import ChatStatusIndicator from "./chat-status-indicator"; import { getStatusColor, getStatusText } from "#/utils/utils"; +import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command"; +import { I18nKey } from "#/i18n/declaration"; function getEntryPoint( hasRepository: boolean | null, @@ -80,6 +82,10 @@ export function ChatInterface() { setHitBottom, } = useScrollToBottom(scrollRef); const { data: config } = useConfig(); + const { + mutate: newConversationCommand, + isPending: isNewConversationPending, + } = useNewConversationCommand(); const { curAgentState } = useAgentState(); const { handleBuildPlanClick } = useHandleBuildPlanClick(); @@ -146,6 +152,27 @@ export function ChatInterface() { originalImages: File[], originalFiles: File[], ) => { + // Handle /new command for V1 conversations + if (content.trim() === "/new") { + if (!isV1Conversation) { + displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_V1_ONLY)); + return; + } + if (!params.conversationId) { + displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_NO_ID)); + return; + } + if (totalEvents === 0) { + displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_EMPTY)); + return; + } + if (isNewConversationPending) { + return; + } + newConversationCommand(); + return; + } + // Create mutable copies of the arrays const images = [...originalImages]; const files = [...originalFiles]; @@ -338,7 +365,10 @@ export function ChatInterface() { /> )} - +
{config?.app_mode !== "saas" && !isV1Conversation && ( diff --git a/frontend/src/components/features/chat/components/chat-input-container.tsx b/frontend/src/components/features/chat/components/chat-input-container.tsx index ef67069de5..ebb3924458 100644 --- a/frontend/src/components/features/chat/components/chat-input-container.tsx +++ b/frontend/src/components/features/chat/components/chat-input-container.tsx @@ -12,6 +12,7 @@ interface ChatInputContainerProps { chatContainerRef: React.RefObject; isDragOver: boolean; disabled: boolean; + isNewConversationPending?: boolean; showButton: boolean; buttonClassName: string; chatInputRef: React.RefObject; @@ -36,6 +37,7 @@ export function ChatInputContainer({ chatContainerRef, isDragOver, disabled, + isNewConversationPending = false, showButton, buttonClassName, chatInputRef, @@ -89,6 +91,7 @@ export function ChatInputContainer({ ; + disabled?: boolean; onInput: () => void; onPaste: (e: React.ClipboardEvent) => void; onKeyDown: (e: React.KeyboardEvent) => void; @@ -14,6 +16,7 @@ interface ChatInputFieldProps { export function ChatInputField({ chatInputRef, + disabled = false, onInput, onPaste, onKeyDown, @@ -36,8 +39,11 @@ export function ChatInputField({
; disabled: boolean; + isNewConversationPending?: boolean; showButton: boolean; buttonClassName: string; handleFileIconClick: (isDisabled: boolean) => void; @@ -21,6 +22,7 @@ interface ChatInputRowProps { export function ChatInputRow({ chatInputRef, disabled, + isNewConversationPending = false, showButton, buttonClassName, handleFileIconClick, @@ -41,6 +43,7 @@ export function ChatInputRow({ void; @@ -25,6 +26,7 @@ export interface CustomChatInputProps { export function CustomChatInput({ disabled = false, + isNewConversationPending = false, showButton = true, conversationStatus = null, onSubmit, @@ -147,6 +149,7 @@ export function CustomChatInput({ chatContainerRef={chatContainerRef} isDragOver={isDragOver} disabled={isDisabled} + isNewConversationPending={isNewConversationPending} showButton={showButton} buttonClassName={buttonClassName} chatInputRef={chatInputRef} diff --git a/frontend/src/components/features/chat/interactive-chat-box.tsx b/frontend/src/components/features/chat/interactive-chat-box.tsx index 74818d1d6c..cf46336887 100644 --- a/frontend/src/components/features/chat/interactive-chat-box.tsx +++ b/frontend/src/components/features/chat/interactive-chat-box.tsx @@ -13,9 +13,13 @@ import { isTaskPolling } from "#/utils/utils"; interface InteractiveChatBoxProps { onSubmit: (message: string, images: File[], files: File[]) => void; + disabled?: boolean; } -export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) { +export function InteractiveChatBox({ + onSubmit, + disabled = false, +}: InteractiveChatBoxProps) { const { images, files, @@ -145,6 +149,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) { // Allow users to submit messages during LOADING state - they will be // queued server-side and delivered when the conversation becomes ready const isDisabled = + disabled || curAgentState === AgentState.AWAITING_USER_CONFIRMATION || isTaskPolling(subConversationTaskStatus); @@ -152,6 +157,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
{ + const queryClient = useQueryClient(); + const navigate = useNavigate(); + const { t } = useTranslation(); + const { data: conversation } = useActiveConversation(); + + const mutation = useMutation({ + mutationFn: async () => { + if (!conversation?.conversation_id || !conversation.sandbox_id) { + throw new Error("No active conversation or sandbox"); + } + + // Fetch V1 conversation data to get llm_model (not available in legacy type) + const v1Conversations = + await V1ConversationService.batchGetAppConversations([ + conversation.conversation_id, + ]); + const llmModel = v1Conversations?.[0]?.llm_model; + + // Start a new conversation reusing the existing sandbox directly. + // We pass sandbox_id instead of parent_conversation_id so that the + // new conversation is NOT marked as a sub-conversation and will + // appear in the conversation list. + const startTask = await V1ConversationService.createConversation( + conversation.selected_repository ?? undefined, // selectedRepository + conversation.git_provider ?? undefined, // git_provider + undefined, // initialUserMsg + conversation.selected_branch ?? undefined, // selected_branch + undefined, // conversationInstructions + undefined, // suggestedTask + undefined, // trigger + undefined, // parent_conversation_id + undefined, // agent_type + conversation.sandbox_id ?? undefined, // sandbox_id - reuse the same sandbox + llmModel ?? undefined, // llm_model - preserve the LLM model + ); + + // Poll for the task to complete and get the new conversation ID + let task = await V1ConversationService.getStartTask(startTask.id); + const maxAttempts = 60; // 60 seconds timeout + let attempts = 0; + + /* eslint-disable no-await-in-loop */ + while ( + task && + !["READY", "ERROR"].includes(task.status) && + attempts < maxAttempts + ) { + // eslint-disable-next-line no-await-in-loop + await new Promise((resolve) => { + setTimeout(resolve, 1000); + }); + task = await V1ConversationService.getStartTask(startTask.id); + attempts += 1; + } + + if (!task || task.status !== "READY" || !task.app_conversation_id) { + throw new Error( + task?.detail || "Failed to create new conversation in sandbox", + ); + } + + return { + newConversationId: task.app_conversation_id, + oldConversationId: conversation.conversation_id, + }; + }, + onMutate: () => { + toast.loading(t(I18nKey.CONVERSATION$CLEARING), { + ...TOAST_OPTIONS, + id: "clear-conversation", + }); + }, + onSuccess: (data) => { + toast.dismiss("clear-conversation"); + displaySuccessToast(t(I18nKey.CONVERSATION$CLEAR_SUCCESS)); + navigate(`/conversations/${data.newConversationId}`); + + // Refresh the sidebar to show the new conversation. + queryClient.invalidateQueries({ + queryKey: ["user", "conversations"], + }); + queryClient.invalidateQueries({ + queryKey: ["v1-batch-get-app-conversations"], + }); + }, + onError: (error) => { + toast.dismiss("clear-conversation"); + let clearError = t(I18nKey.CONVERSATION$CLEAR_UNKNOWN_ERROR); + if (error instanceof Error) { + clearError = error.message; + } else if (typeof error === "string") { + clearError = error; + } + displayErrorToast( + t(I18nKey.CONVERSATION$CLEAR_FAILED, { error: clearError }), + ); + }, + }); + + return mutation; +}; diff --git a/frontend/src/hooks/query/use-unified-get-git-changes.ts b/frontend/src/hooks/query/use-unified-get-git-changes.ts index 801b1a067a..a1de3852f9 100644 --- a/frontend/src/hooks/query/use-unified-get-git-changes.ts +++ b/frontend/src/hooks/query/use-unified-get-git-changes.ts @@ -57,6 +57,7 @@ export const useUnifiedGetGitChanges = () => { retry: false, staleTime: 1000 * 60 * 5, // 5 minutes gcTime: 1000 * 60 * 15, // 15 minutes + refetchOnMount: "always", // Always refetch when mounting (e.g. navigating between conversations that share a sandbox) enabled: runtimeIsReady && !!conversationId, meta: { disableToast: true, diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index fe6f248cfa..9b355ae432 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -1151,6 +1151,14 @@ export enum I18nKey { ONBOARDING$NEXT_BUTTON = "ONBOARDING$NEXT_BUTTON", ONBOARDING$BACK_BUTTON = "ONBOARDING$BACK_BUTTON", ONBOARDING$FINISH_BUTTON = "ONBOARDING$FINISH_BUTTON", + CONVERSATION$CLEAR_V1_ONLY = "CONVERSATION$CLEAR_V1_ONLY", + CONVERSATION$CLEAR_EMPTY = "CONVERSATION$CLEAR_EMPTY", + CONVERSATION$CLEAR_NO_ID = "CONVERSATION$CLEAR_NO_ID", + CONVERSATION$CLEAR_NO_NEW_ID = "CONVERSATION$CLEAR_NO_NEW_ID", + CONVERSATION$CLEAR_UNKNOWN_ERROR = "CONVERSATION$CLEAR_UNKNOWN_ERROR", + CONVERSATION$CLEAR_FAILED = "CONVERSATION$CLEAR_FAILED", + CONVERSATION$CLEAR_SUCCESS = "CONVERSATION$CLEAR_SUCCESS", + CONVERSATION$CLEARING = "CONVERSATION$CLEARING", CTA$ENTERPRISE = "CTA$ENTERPRISE", CTA$ENTERPRISE_DEPLOY = "CTA$ENTERPRISE_DEPLOY", CTA$FEATURE_ON_PREMISES = "CTA$FEATURE_ON_PREMISES", diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index f43c33b0d2..57b89cd193 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -19569,6 +19569,142 @@ "uk": "Завершити", "ca": "Finalitza" }, + "CONVERSATION$CLEAR_V1_ONLY": { + "en": "The /new command is only available for V1 conversations", + "ja": "/newコマンドはV1会話でのみ使用できます", + "zh-CN": "/new 命令仅适用于 V1 对话", + "zh-TW": "/new 指令僅適用於 V1 對話", + "ko-KR": "/new 명령은 V1 대화에서만 사용할 수 있습니다", + "no": "/new-kommandoen er kun tilgjengelig for V1-samtaler", + "it": "Il comando /new è disponibile solo per le conversazioni V1", + "pt": "O comando /new está disponível apenas para conversas V1", + "es": "El comando /new solo está disponible para conversaciones V1", + "ar": "أمر /new متاح فقط لمحادثات V1", + "fr": "La commande /new n'est disponible que pour les conversations V1", + "tr": "/new komutu yalnızca V1 konuşmalarında kullanılabilir", + "de": "Der /new-Befehl ist nur für V1-Konversationen verfügbar", + "uk": "Команда /new доступна лише для розмов V1", + "ca": "L'ordre /new només està disponible per a converses V1" + }, + "CONVERSATION$CLEAR_EMPTY": { + "en": "Nothing to clear. This conversation has no messages yet.", + "ja": "クリアするものがありません。この会話にはまだメッセージがありません。", + "zh-CN": "没有可清除的内容。此对话尚无消息。", + "zh-TW": "沒有可清除的內容。此對話尚無訊息。", + "ko-KR": "지울 내용이 없습니다. 이 대화에는 아직 메시지가 없습니다.", + "no": "Ingenting å tømme. Denne samtalen har ingen meldinger ennå.", + "it": "Niente da cancellare. Questa conversazione non ha ancora messaggi.", + "pt": "Nada para limpar. Esta conversa ainda não tem mensagens.", + "es": "Nada que borrar. Esta conversación aún no tiene mensajes.", + "ar": "لا يوجد شيء لمسحه. لا تحتوي هذه المحادثة على رسائل بعد.", + "fr": "Rien à effacer. Cette conversation n'a pas encore de messages.", + "tr": "Temizlenecek bir şey yok. Bu konuşmada henüz mesaj yok.", + "de": "Nichts zu löschen. Diese Konversation hat noch keine Nachrichten.", + "uk": "Нічого очищувати. Ця розмова ще не має повідомлень.", + "ca": "No hi ha res a esborrar. Aquesta conversa encara no té missatges." + }, + "CONVERSATION$CLEAR_NO_ID": { + "en": "No conversation ID found", + "ja": "会話IDが見つかりません", + "zh-CN": "未找到对话 ID", + "zh-TW": "找不到對話 ID", + "ko-KR": "대화 ID를 찾을 수 없습니다", + "no": "Ingen samtale-ID funnet", + "it": "Nessun ID conversazione trovato", + "pt": "Nenhum ID de conversa encontrado", + "es": "No se encontró el ID de conversación", + "ar": "لم يتم العثور على معرف المحادثة", + "fr": "Aucun identifiant de conversation trouvé", + "tr": "Konuşma kimliği bulunamadı", + "de": "Keine Konversations-ID gefunden", + "uk": "Ідентифікатор розмови не знайдено", + "ca": "No s'ha trobat l'identificador de la conversa" + }, + "CONVERSATION$CLEAR_NO_NEW_ID": { + "en": "Server did not return a new conversation ID", + "ja": "サーバーが新しい会話IDを返しませんでした", + "zh-CN": "服务器未返回新的对话 ID", + "zh-TW": "伺服器未返回新的對話 ID", + "ko-KR": "서버가 새 대화 ID를 반환하지 않았습니다", + "no": "Serveren returnerte ikke en ny samtale-ID", + "it": "Il server non ha restituito un nuovo ID conversazione", + "pt": "O servidor não retornou um novo ID de conversa", + "es": "El servidor no devolvió un nuevo ID de conversación", + "ar": "لم يقم الخادم بإرجاع معرف محادثة جديد", + "fr": "Le serveur n'a pas renvoyé un nouvel identifiant de conversation", + "tr": "Sunucu yeni bir konuşma kimliği döndürmedi", + "de": "Der Server hat keine neue Konversations-ID zurückgegeben", + "uk": "Сервер не повернув новий ідентифікатор розмови", + "ca": "El servidor no ha retornat un nou identificador de conversa" + }, + "CONVERSATION$CLEAR_UNKNOWN_ERROR": { + "en": "Unknown error", + "ja": "不明なエラー", + "zh-CN": "未知错误", + "zh-TW": "未知錯誤", + "ko-KR": "알 수 없는 오류", + "no": "Ukjent feil", + "it": "Errore sconosciuto", + "pt": "Erro desconhecido", + "es": "Error desconocido", + "ar": "خطأ غير معروف", + "fr": "Erreur inconnue", + "tr": "Bilinmeyen hata", + "de": "Unbekannter Fehler", + "uk": "Невідома помилка", + "ca": "Error desconegut" + }, + "CONVERSATION$CLEAR_FAILED": { + "en": "Failed to start new conversation: {{error}}", + "ja": "新しい会話の開始に失敗しました: {{error}}", + "zh-CN": "启动新对话失败: {{error}}", + "zh-TW": "啟動新對話失敗: {{error}}", + "ko-KR": "새 대화 시작 실패: {{error}}", + "no": "Kunne ikke starte ny samtale: {{error}}", + "it": "Impossibile avviare una nuova conversazione: {{error}}", + "pt": "Falha ao iniciar nova conversa: {{error}}", + "es": "Error al iniciar nueva conversación: {{error}}", + "ar": "فشل في بدء محادثة جديدة: {{error}}", + "fr": "Échec du démarrage d'une nouvelle conversation : {{error}}", + "tr": "Yeni konuşma başlatılamadı: {{error}}", + "de": "Neue Konversation konnte nicht gestartet werden: {{error}}", + "uk": "Не вдалося розпочати нову розмову: {{error}}", + "ca": "No s'ha pogut iniciar una nova conversa: {{error}}" + }, + "CONVERSATION$CLEAR_SUCCESS": { + "en": "Starting a new conversation in the same sandbox. These conversations share the same runtime.", + "ja": "同じサンドボックスで新しい会話を開始します。これらの会話は同じランタイムを共有します。", + "zh-CN": "正在同一沙箱中开始新对话。这些对话共享同一运行时。", + "zh-TW": "正在同一沙盒中開始新對話。這些對話共享同一執行環境。", + "ko-KR": "같은 샌드박스에서 새 대화를 시작합니다. 이 대화들은 같은 런타임을 공유합니다.", + "no": "Starter ny samtale i samme sandbox. Disse samtalene deler samme kjøretid.", + "it": "Avvio nuova conversazione nello stesso sandbox. Queste conversazioni condividono lo stesso runtime.", + "pt": "Iniciando nova conversa no mesmo sandbox. Essas conversas compartilham o mesmo runtime.", + "es": "Iniciando nueva conversación en el mismo sandbox. Estas conversaciones comparten el mismo runtime.", + "ar": "بدء محادثة جديدة في نفس صندوق الحماية. هذه المحادثات تشارك نفس وقت التشغيل.", + "fr": "Démarrage d'une nouvelle conversation dans le même bac à sable. Ces conversations partagent le même environnement d'exécution.", + "tr": "Aynı korumalı alanda yeni konuşma başlatılıyor. Bu konuşmalar aynı çalışma ortamını paylaşır.", + "de": "Starte neue Konversation in derselben Sandbox. Diese Konversationen teilen dieselbe Laufzeitumgebung.", + "uk": "Починаю нову розмову в тому самому захищеному середовищі. Ці розмови використовують одне середовище виконання.", + "ca": "S'està iniciant una nova conversa al mateix entorn aïllat. Aquestes converses comparteixen el mateix entorn d'execució." + }, + "CONVERSATION$CLEARING": { + "en": "Creating new conversation...", + "ja": "新しい会話を作成中...", + "zh-CN": "正在创建新对话...", + "zh-TW": "正在建立新對話...", + "ko-KR": "새 대화를 만드는 중...", + "no": "Oppretter ny samtale...", + "it": "Creazione nuova conversazione...", + "pt": "Criando nova conversa...", + "es": "Creando nueva conversación...", + "ar": "جارٍ إنشاء محادثة جديدة...", + "fr": "Création d'une nouvelle conversation...", + "tr": "Yeni konuşma oluşturuluyor...", + "de": "Neue Konversation wird erstellt...", + "uk": "Створення нової розмови...", + "ca": "S'està creant una nova conversa..." + }, "CTA$ENTERPRISE": { "en": "Enterprise", "ja": "エンタープライズ", diff --git a/frontend/src/utils/websocket-url.ts b/frontend/src/utils/websocket-url.ts index 0e72c24dc8..787032b2c9 100644 --- a/frontend/src/utils/websocket-url.ts +++ b/frontend/src/utils/websocket-url.ts @@ -9,6 +9,19 @@ export function extractBaseHost( if (conversationUrl && !conversationUrl.startsWith("/")) { try { const url = new URL(conversationUrl); + // If the conversation URL points to localhost but we're accessing from external, + // use the browser's hostname with the conversation URL's port + const urlHostname = url.hostname; + const browserHostname = + window.location.hostname ?? window.location.host?.split(":")[0]; + if ( + browserHostname && + (urlHostname === "localhost" || urlHostname === "127.0.0.1") && + browserHostname !== "localhost" && + browserHostname !== "127.0.0.1" + ) { + return `${browserHostname}:${url.port}`; + } return url.host; // e.g., "localhost:3000" } catch { return window.location.host; diff --git a/openhands/app_server/app_conversation/app_conversation_info_service.py b/openhands/app_server/app_conversation/app_conversation_info_service.py index bb83ab5801..e14f1dbf6e 100644 --- a/openhands/app_server/app_conversation/app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/app_conversation_info_service.py @@ -84,6 +84,14 @@ class AppConversationInfoService(ABC): List of sub-conversation IDs """ + @abstractmethod + async def count_conversations_by_sandbox_id(self, sandbox_id: str) -> int: + """Count V1 conversations that reference the given sandbox. + + Used to decide whether a sandbox can be safely deleted when a + conversation is removed (only delete if count is 0). + """ + # Mutators @abstractmethod diff --git a/openhands/app_server/app_conversation/app_conversation_service.py b/openhands/app_server/app_conversation/app_conversation_service.py index 1f955cac9c..6be1d32ddf 100644 --- a/openhands/app_server/app_conversation/app_conversation_service.py +++ b/openhands/app_server/app_conversation/app_conversation_service.py @@ -77,8 +77,20 @@ class AppConversationService(ABC): id, starting a conversation, attaching a callback, and then running the conversation. - Yields an instance of AppConversationStartTask as updates occur, which can be used to determine - the progress of the task. + This method returns an async iterator that yields the same + AppConversationStartTask repeatedly as status updates occur. Callers + should iterate until the task reaches a terminal status:: + + async for task in service.start_app_conversation(request): + if task.status in ( + AppConversationStartTaskStatus.READY, + AppConversationStartTaskStatus.ERROR, + ): + break + + Status progression: WORKING → WAITING_FOR_SANDBOX → PREPARING_REPOSITORY + → RUNNING_SETUP_SCRIPT → SETTING_UP_GIT_HOOKS → SETTING_UP_SKILLS + → STARTING_CONVERSATION → READY (or ERROR at any point). """ # This is an abstract method - concrete implementations should provide real values from openhands.app_server.app_conversation.app_conversation_models import ( @@ -111,15 +123,21 @@ class AppConversationService(ABC): """ @abstractmethod - async def delete_app_conversation(self, conversation_id: UUID) -> bool: + async def delete_app_conversation( + self, conversation_id: UUID, skip_agent_server_delete: bool = False + ) -> bool: """Delete a V1 conversation and all its associated data. Args: conversation_id: The UUID of the conversation to delete. + skip_agent_server_delete: If True, skip the agent server DELETE call. + This should be set when the sandbox is shared with other + conversations (e.g. created via /new) to avoid destabilizing + the shared runtime. This method should: 1. Delete the conversation from the database - 2. Call the agent server to delete the conversation + 2. Call the agent server to delete the conversation (unless skipped) 3. Clean up any related data Returns True if the conversation was deleted successfully, False otherwise. diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 703899ec83..b85e1de48f 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -1740,13 +1740,19 @@ class LiveStatusAppConversationService(AppConversationServiceBase): conversations = await self._build_app_conversations([info]) return conversations[0] - async def delete_app_conversation(self, conversation_id: UUID) -> bool: + async def delete_app_conversation( + self, conversation_id: UUID, skip_agent_server_delete: bool = False + ) -> bool: """Delete a V1 conversation and all its associated data. This method will also cascade delete all sub-conversations of the parent. Args: conversation_id: The UUID of the conversation to delete. + skip_agent_server_delete: If True, skip the agent server DELETE call. + This should be set when the sandbox is shared with other + conversations (e.g. created via /new) to avoid destabilizing + the shared runtime. """ # Check if we have the required SQL implementation for transactional deletion if not isinstance( @@ -1772,8 +1778,9 @@ class LiveStatusAppConversationService(AppConversationServiceBase): await self._delete_sub_conversations(conversation_id) # Now delete the parent conversation - # Delete from agent server if sandbox is running - await self._delete_from_agent_server(app_conversation) + # Delete from agent server if sandbox is running (skip if sandbox is shared) + if not skip_agent_server_delete: + await self._delete_from_agent_server(app_conversation) # Delete from database using the conversation info from app_conversation # AppConversation extends AppConversationInfo, so we can use it directly diff --git a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py index c7c9e1935e..80b77957ba 100644 --- a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py @@ -278,6 +278,14 @@ class SQLAppConversationInfoService(AppConversationInfoService): rows = result_set.scalars().all() return [UUID(row.conversation_id) for row in rows] + async def count_conversations_by_sandbox_id(self, sandbox_id: str) -> int: + query = await self._secure_select() + query = query.where(StoredConversationMetadata.sandbox_id == sandbox_id) + count_query = select(func.count()).select_from(query.subquery()) + result = await self.db_session.execute(count_query) + count = result.scalar() + return count or 0 + async def get_app_conversation_info( self, conversation_id: UUID ) -> AppConversationInfo | None: diff --git a/openhands/app_server/config.py b/openhands/app_server/config.py index 4b7f78e389..96168143ed 100644 --- a/openhands/app_server/config.py +++ b/openhands/app_server/config.py @@ -87,6 +87,19 @@ def get_default_web_url() -> str | None: return f'https://{web_host}' +def get_default_permitted_cors_origins() -> list[str]: + """Get permitted CORS origins, falling back to legacy PERMITTED_CORS_ORIGINS env var. + + The preferred configuration is via OH_PERMITTED_CORS_ORIGINS_0, _1, etc. + (handled by the pydantic from_env parser). This fallback supports the legacy + comma-separated PERMITTED_CORS_ORIGINS environment variable. + """ + legacy = os.getenv('PERMITTED_CORS_ORIGINS', '') + if legacy: + return [o.strip() for o in legacy.split(',') if o.strip()] + return [] + + def get_openhands_provider_base_url() -> str | None: """Return the base URL for the OpenHands provider, if configured.""" return os.getenv('OPENHANDS_PROVIDER_BASE_URL') or None @@ -106,6 +119,14 @@ class AppServerConfig(OpenHandsModel): default_factory=get_default_web_url, description='The URL where OpenHands is running (e.g., http://localhost:3000)', ) + permitted_cors_origins: list[str] = Field( + default_factory=get_default_permitted_cors_origins, + description=( + 'Additional permitted CORS origins for both the app server and agent ' + 'server containers. Configure via OH_PERMITTED_CORS_ORIGINS_0, _1, etc. ' + 'Falls back to legacy PERMITTED_CORS_ORIGINS env var.' + ), + ) openhands_provider_base_url: str | None = Field( default_factory=get_openhands_provider_base_url, description='Base URL for the OpenHands provider', diff --git a/openhands/app_server/sandbox/docker_sandbox_service.py b/openhands/app_server/sandbox/docker_sandbox_service.py index f5a302fa73..cccd873cb6 100644 --- a/openhands/app_server/sandbox/docker_sandbox_service.py +++ b/openhands/app_server/sandbox/docker_sandbox_service.py @@ -27,7 +27,6 @@ from openhands.app_server.sandbox.sandbox_models import ( SandboxStatus, ) from openhands.app_server.sandbox.sandbox_service import ( - ALLOW_CORS_ORIGINS_VARIABLE, SESSION_API_KEY_VARIABLE, WEBHOOK_CALLBACK_VARIABLE, SandboxService, @@ -91,6 +90,7 @@ class DockerSandboxService(SandboxService): httpx_client: httpx.AsyncClient max_num_sandboxes: int web_url: str | None = None + permitted_cors_origins: list[str] = field(default_factory=list) extra_hosts: dict[str, str] = field(default_factory=dict) docker_client: docker.DockerClient = field(default_factory=get_docker_client) startup_grace_seconds: int = STARTUP_GRACE_SECONDS @@ -386,8 +386,18 @@ class DockerSandboxService(SandboxService): # Set CORS origins for remote browser access when web_url is configured. # This allows the agent-server container to accept requests from the # frontend when running OpenHands on a remote machine. + # Each origin gets its own indexed env var (OH_ALLOW_CORS_ORIGINS_0, _1, etc.) + cors_origins: list[str] = [] if self.web_url: - env_vars[ALLOW_CORS_ORIGINS_VARIABLE] = self.web_url + cors_origins.append(self.web_url) + cors_origins.extend(self.permitted_cors_origins) + # Deduplicate while preserving order + seen: set[str] = set() + for origin in cors_origins: + if origin not in seen: + seen.add(origin) + idx = len(seen) - 1 + env_vars[f'OH_ALLOW_CORS_ORIGINS_{idx}'] = origin # Prepare port mappings and add port environment variables # When using host network, container ports are directly accessible on the host @@ -621,7 +631,7 @@ class DockerSandboxServiceInjector(SandboxServiceInjector): get_sandbox_spec_service, ) - # Get web_url from global config for CORS support + # Get web_url and permitted_cors_origins from global config config = get_global_config() web_url = config.web_url @@ -640,6 +650,7 @@ class DockerSandboxServiceInjector(SandboxServiceInjector): httpx_client=httpx_client, max_num_sandboxes=self.max_num_sandboxes, web_url=web_url, + permitted_cors_origins=config.permitted_cors_origins, extra_hosts=self.extra_hosts, startup_grace_seconds=self.startup_grace_seconds, use_host_network=self.use_host_network, diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 902a881df8..b1e9c5649e 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -7,7 +7,7 @@ # Tag: Legacy-V0 # This module belongs to the old V0 web server. The V1 application server lives under openhands/app_server/. import asyncio -import os +import logging from collections import defaultdict from datetime import datetime, timedelta from urllib.parse import urlparse @@ -20,6 +20,8 @@ from starlette.requests import Request as StarletteRequest from starlette.responses import Response from starlette.types import ASGIApp +from openhands.app_server.config import get_global_config + class LocalhostCORSMiddleware(CORSMiddleware): """Custom CORS middleware that allows any request from localhost/127.0.0.1 domains, @@ -27,13 +29,8 @@ class LocalhostCORSMiddleware(CORSMiddleware): """ def __init__(self, app: ASGIApp) -> None: - allow_origins_str = os.getenv('PERMITTED_CORS_ORIGINS') - if allow_origins_str: - allow_origins = tuple( - origin.strip() for origin in allow_origins_str.split(',') - ) - else: - allow_origins = () + config = get_global_config() + allow_origins = tuple(config.permitted_cors_origins) super().__init__( app, allow_origins=allow_origins, @@ -51,6 +48,14 @@ class LocalhostCORSMiddleware(CORSMiddleware): if hostname in ['localhost', '127.0.0.1']: return True + # Allow any origin when no specific origins are configured (development mode) + # WARNING: This disables CORS protection. Use explicit CORS origins in production. + logging.getLogger(__name__).warning( + f'No CORS origins configured, allowing origin: {origin}. ' + 'Set OH_PERMITTED_CORS_ORIGINS for production environments.' + ) + return True + # For missing origin or other origins, use the parent class's logic result: bool = super().is_allowed_origin(origin) return result diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index fa73aa4d52..5789e9784c 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -603,16 +603,28 @@ async def _try_delete_v1_conversation( ) ) if app_conversation_info: + # Check if the sandbox is shared with other conversations + # (e.g. multiple conversations can share a sandbox via /new). + # If shared, skip the agent server DELETE call to avoid + # destabilizing the runtime for the remaining conversations. + sandbox_id = app_conversation_info.sandbox_id + sandbox_is_shared = False + if sandbox_id: + conversation_count = await app_conversation_info_service.count_conversations_by_sandbox_id( + sandbox_id + ) + sandbox_is_shared = conversation_count > 1 + # This is a V1 conversation, delete it using the app conversation service - # Pass the conversation ID for secure deletion result = await app_conversation_service.delete_app_conversation( - app_conversation_info.id + app_conversation_info.id, + skip_agent_server_delete=sandbox_is_shared, ) # Manually commit so that the conversation will vanish from the list await db_session.commit() - # Delete the sandbox in the background + # Delete the sandbox in the background (checks remaining conversations first) asyncio.create_task( _finalize_delete_and_close_connections( sandbox_service, diff --git a/tests/unit/app_server/test_sql_app_conversation_info_service.py b/tests/unit/app_server/test_sql_app_conversation_info_service.py index a491fa93af..48e9693641 100644 --- a/tests/unit/app_server/test_sql_app_conversation_info_service.py +++ b/tests/unit/app_server/test_sql_app_conversation_info_service.py @@ -286,6 +286,54 @@ class TestSQLAppConversationInfoService: results = await service.batch_get_app_conversation_info([]) assert results == [] + @pytest.mark.asyncio + async def test_count_conversations_by_sandbox_id( + self, + service: SQLAppConversationInfoService, + ): + """Test count by sandbox_id: only delete sandbox when no conversation uses it.""" + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + shared_sandbox = 'shared_sandbox_1' + other_sandbox = 'other_sandbox' + for i in range(3): + info = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id=shared_sandbox, + selected_repository='https://github.com/test/repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + title=f'Conversation {i}', + trigger=ConversationTrigger.GUI, + pr_number=[], + llm_model='gpt-4', + metrics=None, + created_at=base_time, + updated_at=base_time, + ) + await service.save_app_conversation_info(info) + for i in range(2): + info = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id=other_sandbox, + selected_repository='https://github.com/test/repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + title=f'Other {i}', + trigger=ConversationTrigger.GUI, + pr_number=[], + llm_model='gpt-4', + metrics=None, + created_at=base_time, + updated_at=base_time, + ) + await service.save_app_conversation_info(info) + + assert await service.count_conversations_by_sandbox_id(shared_sandbox) == 3 + assert await service.count_conversations_by_sandbox_id(other_sandbox) == 2 + assert await service.count_conversations_by_sandbox_id('no_such_sandbox') == 0 + @pytest.mark.asyncio async def test_search_conversation_info_no_filters( self, diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index 99dbdfaacc..3c84afd0c6 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -1038,6 +1038,9 @@ async def test_delete_v1_conversation_success(): return_value=mock_app_conversation_info ) mock_service.delete_app_conversation = AsyncMock(return_value=True) + mock_info_service.count_conversations_by_sandbox_id = AsyncMock( + return_value=1 + ) # Call delete_conversation with V1 conversation ID result = await delete_conversation( @@ -1059,7 +1062,8 @@ async def test_delete_v1_conversation_success(): # Verify that delete_app_conversation was called with the conversation ID mock_service.delete_app_conversation.assert_called_once_with( - conversation_uuid + conversation_uuid, + skip_agent_server_delete=False, ) @@ -1357,6 +1361,9 @@ async def test_delete_v1_conversation_with_agent_server(): return_value=mock_app_conversation_info ) mock_service.delete_app_conversation = AsyncMock(return_value=True) + mock_info_service.count_conversations_by_sandbox_id = AsyncMock( + return_value=1 + ) # Call delete_conversation with V1 conversation ID result = await delete_conversation( @@ -1378,7 +1385,8 @@ async def test_delete_v1_conversation_with_agent_server(): # Verify that delete_app_conversation was called with the conversation ID mock_service.delete_app_conversation.assert_called_once_with( - conversation_uuid + conversation_uuid, + skip_agent_server_delete=False, ) diff --git a/tests/unit/server/test_middleware.py b/tests/unit/server/test_middleware.py index 2bdf2275fc..cdc922b5e9 100644 --- a/tests/unit/server/test_middleware.py +++ b/tests/unit/server/test_middleware.py @@ -1,5 +1,4 @@ -import os -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI @@ -21,34 +20,46 @@ def app(): return app -def test_localhost_cors_middleware_init_with_env_var(): - """Test that the middleware correctly parses PERMITTED_CORS_ORIGINS environment variable.""" - with patch.dict( - os.environ, {'PERMITTED_CORS_ORIGINS': 'https://example.com,https://test.com'} +def test_localhost_cors_middleware_init_with_config(): + """Test that the middleware correctly reads permitted_cors_origins from global config.""" + mock_config = MagicMock() + mock_config.permitted_cors_origins = [ + 'https://example.com', + 'https://test.com', + ] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config ): app = FastAPI() middleware = LocalhostCORSMiddleware(app) - # Check that the origins were correctly parsed from the environment variable + # Check that the origins were correctly read from the config assert 'https://example.com' in middleware.allow_origins assert 'https://test.com' in middleware.allow_origins assert len(middleware.allow_origins) == 2 -def test_localhost_cors_middleware_init_without_env_var(): - """Test that the middleware works correctly without PERMITTED_CORS_ORIGINS environment variable.""" - with patch.dict(os.environ, {}, clear=True): +def test_localhost_cors_middleware_init_without_config(): + """Test that the middleware works correctly without permitted_cors_origins configured.""" + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app = FastAPI() middleware = LocalhostCORSMiddleware(app) - # Check that allow_origins is empty when no environment variable is set + # Check that allow_origins is empty when no origins are configured assert middleware.allow_origins == () def test_localhost_cors_middleware_is_allowed_origin_localhost(app): """Test that localhost origins are allowed regardless of port when no specific origins are configured.""" - # Test without setting PERMITTED_CORS_ORIGINS to trigger localhost behavior - with patch.dict(os.environ, {}, clear=True): + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app.add_middleware(LocalhostCORSMiddleware) client = TestClient(app) @@ -76,8 +87,11 @@ def test_localhost_cors_middleware_is_allowed_origin_localhost(app): def test_localhost_cors_middleware_is_allowed_origin_non_localhost(app): """Test that non-localhost origins follow the standard CORS rules.""" - # Set up the middleware with specific allowed origins - with patch.dict(os.environ, {'PERMITTED_CORS_ORIGINS': 'https://example.com'}): + mock_config = MagicMock() + mock_config.permitted_cors_origins = ['https://example.com'] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app.add_middleware(LocalhostCORSMiddleware) client = TestClient(app) @@ -95,7 +109,11 @@ def test_localhost_cors_middleware_is_allowed_origin_non_localhost(app): def test_localhost_cors_middleware_missing_origin(app): """Test behavior when Origin header is missing.""" - with patch.dict(os.environ, {}, clear=True): + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): app.add_middleware(LocalhostCORSMiddleware) client = TestClient(app) @@ -113,17 +131,22 @@ def test_localhost_cors_middleware_inheritance(): def test_localhost_cors_middleware_cors_parameters(): """Test that CORS parameters are set correctly in the middleware.""" - # We need to inspect the initialization parameters rather than attributes - # since CORSMiddleware doesn't expose these as attributes - with patch('fastapi.middleware.cors.CORSMiddleware.__init__') as mock_init: - mock_init.return_value = None - app = FastAPI() - LocalhostCORSMiddleware(app) + mock_config = MagicMock() + mock_config.permitted_cors_origins = [] + with patch( + 'openhands.server.middleware.get_global_config', return_value=mock_config + ): + # We need to inspect the initialization parameters rather than attributes + # since CORSMiddleware doesn't expose these as attributes + with patch('fastapi.middleware.cors.CORSMiddleware.__init__') as mock_init: + mock_init.return_value = None + app = FastAPI() + LocalhostCORSMiddleware(app) - # Check that the parent class was initialized with the correct parameters - mock_init.assert_called_once() - _, kwargs = mock_init.call_args + # Check that the parent class was initialized with the correct parameters + mock_init.assert_called_once() + _, kwargs = mock_init.call_args - assert kwargs['allow_credentials'] is True - assert kwargs['allow_methods'] == ['*'] - assert kwargs['allow_headers'] == ['*'] + assert kwargs['allow_credentials'] is True + assert kwargs['allow_methods'] == ['*'] + assert kwargs['allow_headers'] == ['*'] From 2d1e9fa35b58f22bbafc266fe61f045e0d781cbd Mon Sep 17 00:00:00 2001 From: aivong-openhands Date: Thu, 19 Mar 2026 10:05:30 -0500 Subject: [PATCH 14/28] Fix CVE-2026-33123: Update pypdf to 6.9.1 (#13473) Co-authored-by: OpenHands CVE Fix Bot --- enterprise/poetry.lock | 6 +++--- poetry.lock | 8 ++++---- pyproject.toml | 4 ++-- uv.lock | 8 ++++---- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/enterprise/poetry.lock b/enterprise/poetry.lock index 1bb48f24c6..39ef61101d 100644 --- a/enterprise/poetry.lock +++ b/enterprise/poetry.lock @@ -11587,14 +11587,14 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pypdf" -version = "6.8.0" +version = "6.9.1" description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7"}, - {file = "pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b"}, + {file = "pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f"}, + {file = "pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d"}, ] [package.extras] diff --git a/poetry.lock b/poetry.lock index bccd0eea80..9644ef383c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11564,14 +11564,14 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pypdf" -version = "6.8.0" +version = "6.9.1" description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7"}, - {file = "pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b"}, + {file = "pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f"}, + {file = "pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d"}, ] [package.extras] @@ -14833,4 +14833,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api [metadata] lock-version = "2.1" python-versions = "^3.12,<3.14" -content-hash = "1a8151b36fb64667d1a2e83f38060841de15bd0284f18e8f58c6ee95095e933e" +content-hash = "1d1661870075ed85d87818cc3f3bd30bf23dcd00d1604be57f616f60b583c758" diff --git a/pyproject.toml b/pyproject.toml index 87609dbf9b..9595af0fab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ dependencies = [ "pygithub>=2.5", "pyjwt>=2.12", "pylatexenc", - "pypdf>=6.7.2", + "pypdf>=6.9.1", "python-docx", "python-dotenv", "python-frontmatter>=1.1", @@ -224,7 +224,7 @@ python-docx = "*" bashlex = "^0.18" # Explicitly pinned packages for latest versions -pypdf = "^6.7.2" +pypdf = "^6.9.1" pillow = "^12.1.1" starlette = "^0.49.1" urllib3 = "^2.6.3" diff --git a/uv.lock b/uv.lock index 67c7965698..269ff03c0f 100644 --- a/uv.lock +++ b/uv.lock @@ -3846,7 +3846,7 @@ requires-dist = [ { name = "pygithub", specifier = ">=2.5" }, { name = "pyjwt", specifier = ">=2.12" }, { name = "pylatexenc" }, - { name = "pypdf", specifier = ">=6.7.2" }, + { name = "pypdf", specifier = ">=6.9.1" }, { name = "python-docx" }, { name = "python-dotenv" }, { name = "python-frontmatter", specifier = ">=1.1" }, @@ -7385,11 +7385,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.8.0" +version = "6.9.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" }, + { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" }, ] [[package]] From 2224127ac305f353d2013655e71b591c33291056 Mon Sep 17 00:00:00 2001 From: chuckbutkus Date: Thu, 19 Mar 2026 11:14:48 -0400 Subject: [PATCH 15/28] Fix when budgets are None (#13482) Co-authored-by: openhands --- enterprise/storage/lite_llm_manager.py | 60 ++++-- .../tests/unit/test_lite_llm_manager.py | 192 ++++++++++++++++++ 2 files changed, 231 insertions(+), 21 deletions(-) diff --git a/enterprise/storage/lite_llm_manager.py b/enterprise/storage/lite_llm_manager.py index b515b7a7d9..d4e1aefd2c 100644 --- a/enterprise/storage/lite_llm_manager.py +++ b/enterprise/storage/lite_llm_manager.py @@ -589,20 +589,26 @@ class LiteLlmManager: if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: logger.warning('LiteLLM API configuration not found') return + + json_data: dict[str, Any] = { + 'team_id': team_id, + 'team_alias': team_alias, + 'models': [], + 'spend': 0, + 'metadata': { + 'version': ORG_SETTINGS_VERSION, + 'model': get_default_litellm_model(), + }, + } + + if max_budget is not None: + json_data['max_budget'] = max_budget + response = await client.post( f'{LITE_LLM_API_URL}/team/new', - json={ - 'team_id': team_id, - 'team_alias': team_alias, - 'models': [], - 'max_budget': max_budget, # None disables budget enforcement - 'spend': 0, - 'metadata': { - 'version': ORG_SETTINGS_VERSION, - 'model': get_default_litellm_model(), - }, - }, + json=json_data, ) + # Team failed to create in litellm - this is an unforseen error state... if not response.is_success: if ( @@ -1040,14 +1046,20 @@ class LiteLlmManager: if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: logger.warning('LiteLLM API configuration not found') return + + json_data: dict[str, Any] = { + 'team_id': team_id, + 'member': {'user_id': keycloak_user_id, 'role': 'user'}, + } + + if max_budget is not None: + json_data['max_budget_in_team'] = max_budget + response = await client.post( f'{LITE_LLM_API_URL}/team/member_add', - json={ - 'team_id': team_id, - 'member': {'user_id': keycloak_user_id, 'role': 'user'}, - 'max_budget_in_team': max_budget, # None disables budget enforcement - }, + json=json_data, ) + # Failed to add user to team - this is an unforseen error state... if not response.is_success: if ( @@ -1129,14 +1141,20 @@ class LiteLlmManager: if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: logger.warning('LiteLLM API configuration not found') return + + json_data: dict[str, Any] = { + 'team_id': team_id, + 'user_id': keycloak_user_id, + } + + if max_budget is not None: + json_data['max_budget_in_team'] = max_budget + response = await client.post( f'{LITE_LLM_API_URL}/team/member_update', - json={ - 'team_id': team_id, - 'user_id': keycloak_user_id, - 'max_budget_in_team': max_budget, # None disables budget enforcement - }, + json=json_data, ) + # Failed to update user in team - this is an unforseen error state... if not response.is_success: logger.error( diff --git a/enterprise/tests/unit/test_lite_llm_manager.py b/enterprise/tests/unit/test_lite_llm_manager.py index 3da159421d..ffd964b77f 100644 --- a/enterprise/tests/unit/test_lite_llm_manager.py +++ b/enterprise/tests/unit/test_lite_llm_manager.py @@ -2384,3 +2384,195 @@ class TestVerifyExistingKey: openhands_type=True, ) assert result is False + + +class TestBudgetPayloadHandling: + """Test cases for budget field handling in API payloads. + + These tests verify that when max_budget is None, the budget field is NOT + included in the JSON payload (which tells LiteLLM to disable budget + enforcement), and when max_budget has a value, it IS included. + """ + + @pytest.mark.asyncio + async def test_create_team_excludes_max_budget_when_none(self): + """Test that _create_team does NOT include max_budget when it is None.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = MagicMock() + mock_response.is_success = True + mock_response.status_code = 200 + mock_client.post.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + await LiteLlmManager._create_team( + mock_client, + team_alias='test-team', + team_id='test-team-id', + max_budget=None, # None = no budget limit + ) + + # Verify the call was made + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify URL + assert call_args[0][0] == 'http://test.com/team/new' + + # Verify that max_budget is NOT in the JSON payload + json_payload = call_args[1]['json'] + assert 'max_budget' not in json_payload, ( + 'max_budget should NOT be in payload when None ' + '(omitting it tells LiteLLM to disable budget enforcement)' + ) + + @pytest.mark.asyncio + async def test_create_team_includes_max_budget_when_set(self): + """Test that _create_team includes max_budget when it has a value.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = MagicMock() + mock_response.is_success = True + mock_response.status_code = 200 + mock_client.post.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + await LiteLlmManager._create_team( + mock_client, + team_alias='test-team', + team_id='test-team-id', + max_budget=100.0, # Explicit budget limit + ) + + # Verify the call was made + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify that max_budget IS in the JSON payload with the correct value + json_payload = call_args[1]['json'] + assert ( + 'max_budget' in json_payload + ), 'max_budget should be in payload when set to a value' + assert json_payload['max_budget'] == 100.0 + + @pytest.mark.asyncio + async def test_add_user_to_team_excludes_max_budget_when_none(self): + """Test that _add_user_to_team does NOT include max_budget_in_team when None.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = MagicMock() + mock_response.is_success = True + mock_response.status_code = 200 + mock_client.post.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + await LiteLlmManager._add_user_to_team( + mock_client, + keycloak_user_id='test-user-id', + team_id='test-team-id', + max_budget=None, # None = no budget limit + ) + + # Verify the call was made + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify URL + assert call_args[0][0] == 'http://test.com/team/member_add' + + # Verify that max_budget_in_team is NOT in the JSON payload + json_payload = call_args[1]['json'] + assert 'max_budget_in_team' not in json_payload, ( + 'max_budget_in_team should NOT be in payload when None ' + '(omitting it tells LiteLLM to disable budget enforcement)' + ) + + @pytest.mark.asyncio + async def test_add_user_to_team_includes_max_budget_when_set(self): + """Test that _add_user_to_team includes max_budget_in_team when set.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = MagicMock() + mock_response.is_success = True + mock_response.status_code = 200 + mock_client.post.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + await LiteLlmManager._add_user_to_team( + mock_client, + keycloak_user_id='test-user-id', + team_id='test-team-id', + max_budget=50.0, # Explicit budget limit + ) + + # Verify the call was made + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify that max_budget_in_team IS in the JSON payload + json_payload = call_args[1]['json'] + assert ( + 'max_budget_in_team' in json_payload + ), 'max_budget_in_team should be in payload when set to a value' + assert json_payload['max_budget_in_team'] == 50.0 + + @pytest.mark.asyncio + async def test_update_user_in_team_excludes_max_budget_when_none(self): + """Test that _update_user_in_team does NOT include max_budget_in_team when None.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = MagicMock() + mock_response.is_success = True + mock_response.status_code = 200 + mock_client.post.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + await LiteLlmManager._update_user_in_team( + mock_client, + keycloak_user_id='test-user-id', + team_id='test-team-id', + max_budget=None, # None = no budget limit + ) + + # Verify the call was made + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify URL + assert call_args[0][0] == 'http://test.com/team/member_update' + + # Verify that max_budget_in_team is NOT in the JSON payload + json_payload = call_args[1]['json'] + assert 'max_budget_in_team' not in json_payload, ( + 'max_budget_in_team should NOT be in payload when None ' + '(omitting it tells LiteLLM to disable budget enforcement)' + ) + + @pytest.mark.asyncio + async def test_update_user_in_team_includes_max_budget_when_set(self): + """Test that _update_user_in_team includes max_budget_in_team when set.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = MagicMock() + mock_response.is_success = True + mock_response.status_code = 200 + mock_client.post.return_value = mock_response + + with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'): + with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): + await LiteLlmManager._update_user_in_team( + mock_client, + keycloak_user_id='test-user-id', + team_id='test-team-id', + max_budget=75.0, # Explicit budget limit + ) + + # Verify the call was made + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify that max_budget_in_team IS in the JSON payload + json_payload = call_args[1]['json'] + assert ( + 'max_budget_in_team' in json_payload + ), 'max_budget_in_team should be in payload when set to a value' + assert json_payload['max_budget_in_team'] == 75.0 From 120fd7516a4e4eb9b4aa808bfaf53eb3072558e3 Mon Sep 17 00:00:00 2001 From: Chris Bagwell Date: Thu, 19 Mar 2026 10:33:01 -0500 Subject: [PATCH 16/28] Fix: Prevent auto-logout on 401 errors in oss mode (#13466) --- .../features/user/user-context-menu.test.tsx | 40 +++++++++++++++++-- .../features/user/user-context-menu.tsx | 17 ++++---- frontend/src/hooks/query/use-git-user.ts | 7 ++-- 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/frontend/__tests__/components/features/user/user-context-menu.test.tsx b/frontend/__tests__/components/features/user/user-context-menu.test.tsx index 635f66e645..07895f547c 100644 --- a/frontend/__tests__/components/features/user/user-context-menu.test.tsx +++ b/frontend/__tests__/components/features/user/user-context-menu.test.tsx @@ -156,11 +156,19 @@ describe("UserContextMenu", () => { useSelectedOrganizationStore.setState({ organizationId: null }); }); - it("should render the default context items for a user", () => { + it("should render the default context items for a user", async () => { + vi.spyOn(OptionService, "getConfig").mockResolvedValue( + createMockWebClientConfig({ app_mode: "saas" }), + ); + renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn }); screen.getByTestId("org-selector"); - screen.getByText("ACCOUNT_SETTINGS$LOGOUT"); + + // Wait for config to load so logout button appears + await waitFor(() => { + expect(screen.getByText("ACCOUNT_SETTINGS$LOGOUT")).toBeInTheDocument(); + }); expect( screen.queryByText("ORG$INVITE_ORG_MEMBERS"), @@ -304,6 +312,20 @@ describe("UserContextMenu", () => { screen.queryByText("Organization Members"), ).not.toBeInTheDocument(); }); + + it("should not display logout button in OSS mode", async () => { + renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn }); + + // Wait for the config to load + await waitFor(() => { + expect(screen.getByText("SETTINGS$NAV_LLM")).toBeInTheDocument(); + }); + + // Verify logout button is NOT rendered in OSS mode + expect( + screen.queryByText("ACCOUNT_SETTINGS$LOGOUT"), + ).not.toBeInTheDocument(); + }); }); describe("HIDE_LLM_SETTINGS feature flag", () => { @@ -382,10 +404,15 @@ describe("UserContextMenu", () => { }); it("should call the logout handler when Logout is clicked", async () => { + vi.spyOn(OptionService, "getConfig").mockResolvedValue( + createMockWebClientConfig({ app_mode: "saas" }), + ); + const logoutSpy = vi.spyOn(AuthService, "logout"); renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn }); - const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT"); + // Wait for config to load so logout button appears + const logoutButton = await screen.findByText("ACCOUNT_SETTINGS$LOGOUT"); await userEvent.click(logoutButton); expect(logoutSpy).toHaveBeenCalledOnce(); @@ -488,6 +515,10 @@ describe("UserContextMenu", () => { }); it("should call the onClose handler after each action", async () => { + vi.spyOn(OptionService, "getConfig").mockResolvedValue( + createMockWebClientConfig({ app_mode: "saas" }), + ); + // Mock a team org so org management buttons are visible vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({ items: [MOCK_TEAM_ORG_ACME], @@ -497,7 +528,8 @@ describe("UserContextMenu", () => { const onCloseMock = vi.fn(); renderUserContextMenu({ type: "owner", onClose: onCloseMock, onOpenInviteModal: vi.fn }); - const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT"); + // Wait for config to load so logout button appears + const logoutButton = await screen.findByText("ACCOUNT_SETTINGS$LOGOUT"); await userEvent.click(logoutButton); expect(onCloseMock).toHaveBeenCalledTimes(1); diff --git a/frontend/src/components/features/user/user-context-menu.tsx b/frontend/src/components/features/user/user-context-menu.tsx index b9094cc6d3..424dc7c0ec 100644 --- a/frontend/src/components/features/user/user-context-menu.tsx +++ b/frontend/src/components/features/user/user-context-menu.tsx @@ -156,13 +156,16 @@ export function UserContextMenu({ {t(I18nKey.SIDEBAR$DOCS)} - - - {t(I18nKey.ACCOUNT_SETTINGS$LOGOUT)} - + {/* Only show logout in saas mode - oss mode has no session to invalidate */} + {isSaasMode && ( + + + {t(I18nKey.ACCOUNT_SETTINGS$LOGOUT)} + + )}
diff --git a/frontend/src/hooks/query/use-git-user.ts b/frontend/src/hooks/query/use-git-user.ts index 971999f25c..a239b2d18a 100644 --- a/frontend/src/hooks/query/use-git-user.ts +++ b/frontend/src/hooks/query/use-git-user.ts @@ -35,13 +35,14 @@ export const useGitUser = () => { } }, [user.data]); - // If we get a 401 here, it means that the integration tokens need to be + // In saas mode, a 401 means that the integration tokens need to be // refreshed. Since this happens at login, we log out. + // In oss mode, skip auto-logout since there's no token refresh mechanism React.useEffect(() => { - if (user?.error?.response?.status === 401) { + if (user?.error?.response?.status === 401 && config?.app_mode === "saas") { logout.mutate(); } - }, [user.status]); + }, [user.status, config?.app_mode]); return user; }; From 04330898b6fca7dbe642e8221076783b94b1754d Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Fri, 20 Mar 2026 00:12:38 +0700 Subject: [PATCH 17/28] refactor(frontend): add delay before closing user context menu (#13491) --- .../components/user-actions.test.tsx | 139 ++++++++++++++++-- .../features/sidebar/user-actions.tsx | 30 +++- 2 files changed, 151 insertions(+), 18 deletions(-) diff --git a/frontend/__tests__/components/user-actions.test.tsx b/frontend/__tests__/components/user-actions.test.tsx index 936586168d..4a8a42d1be 100644 --- a/frontend/__tests__/components/user-actions.test.tsx +++ b/frontend/__tests__/components/user-actions.test.tsx @@ -1,13 +1,16 @@ -import { render, screen, waitFor } from "@testing-library/react"; +import { render, screen, waitFor, fireEvent, act } from "@testing-library/react"; import { describe, expect, it, vi, afterEach, beforeEach, test } from "vitest"; import userEvent from "@testing-library/user-event"; import { QueryClientProvider, QueryClient } from "@tanstack/react-query"; -import { MemoryRouter } from "react-router"; +import { MemoryRouter, createRoutesStub } from "react-router"; import { ReactElement } from "react"; +import { http, HttpResponse } from "msw"; import { UserActions } from "#/components/features/sidebar/user-actions"; import { organizationService } from "#/api/organization-service/organization-service.api"; import { MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME } from "#/mocks/org-handlers"; import { useSelectedOrganizationStore } from "#/stores/selected-organization-store"; +import { server } from "#/mocks/node"; +import { createMockWebClientConfig } from "#/mocks/settings-handlers"; import { renderWithProviders } from "../../test-utils"; vi.mock("react-router", async (importActual) => ({ @@ -59,6 +62,20 @@ const renderUserActions = (props = { hasAvatar: true }) => { ); }; +// RouterStub and render helper for menu close delay tests +const RouterStubForMenuCloseDelay = createRoutesStub([ + { + path: "/", + Component: () => ( + + ), + }, +]); + +const renderUserActionsForMenuCloseDelay = () => { + return renderWithProviders(); +}; + // Create mocks for all the hooks we need const useIsAuthedMock = vi .fn() @@ -347,7 +364,7 @@ describe("UserActions", () => { expect(contextMenu).toBeVisible(); }); - it("should have pointer-events-none on hover bridge pseudo-element to allow menu item clicks", async () => { + it("should use state-based visibility for hover behavior instead of CSS pseudo-element", async () => { renderUserActions(); const userActions = screen.getByTestId("user-actions"); @@ -356,19 +373,17 @@ describe("UserActions", () => { const contextMenu = screen.getByTestId("user-context-menu"); const hoverBridgeContainer = contextMenu.parentElement; - // The hover bridge uses a ::before pseudo-element for diagonal mouse movement - // This pseudo-element MUST have pointer-events-none to allow clicks through to menu items - // The class should include "before:pointer-events-none" to prevent the hover bridge from blocking clicks - expect(hoverBridgeContainer?.className).toContain( - "before:pointer-events-none", - ); + // The component uses state-based visibility with a 500ms delay for diagonal mouse movement + // When visible, the container should have opacity-100 and pointer-events-auto + expect(hoverBridgeContainer?.className).toContain("opacity-100"); + expect(hoverBridgeContainer?.className).toContain("pointer-events-auto"); }); describe("Org selector dropdown state reset when context menu hides", () => { // These tests verify that the org selector dropdown resets its internal // state (search text, open/closed) when the context menu hides and - // reappears. Without this, stale state persists because the context - // menu is hidden via CSS (opacity/pointer-events) rather than unmounted. + // reappears. The component uses a 500ms delay before hiding (to support + // diagonal mouse movement). beforeEach(() => { vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({ @@ -400,8 +415,22 @@ describe("UserActions", () => { await user.type(input, "search text"); expect(input).toHaveValue("search text"); - // Unhover to hide context menu, then hover again + // Unhover to trigger hide timeout, then wait for the 500ms delay to complete await user.unhover(userActions); + + // Wait for the 500ms hide delay to complete and menu to actually hide + await waitFor( + () => { + // The menu resets when it actually hides (after 500ms delay) + // After hiding, hovering again should show a fresh menu + }, + { timeout: 600 }, + ); + + // Wait a bit more for the timeout to fire + await new Promise((resolve) => setTimeout(resolve, 550)); + + // Now hover again to show the menu await user.hover(userActions); // Org selector should be reset — showing selected org name, not search text @@ -434,8 +463,13 @@ describe("UserActions", () => { await user.type(input, "Acme"); expect(input).toHaveValue("Acme"); - // Unhover to hide context menu, then hover again + // Unhover to trigger hide timeout await user.unhover(userActions); + + // Wait for the 500ms hide delay to complete + await new Promise((resolve) => setTimeout(resolve, 550)); + + // Now hover again to show the menu await user.hover(userActions); // Wait for fresh component with org data @@ -454,4 +488,83 @@ describe("UserActions", () => { expect(screen.queryAllByRole("option")).toHaveLength(0); }); }); + + describe("menu close delay", () => { + beforeEach(() => { + vi.useFakeTimers(); + useSelectedOrganizationStore.setState({ organizationId: "1" }); + + // Mock config to return SaaS mode so useShouldShowUserFeatures returns true + server.use( + http.get("/api/v1/web-client/config", () => + HttpResponse.json(createMockWebClientConfig({ app_mode: "saas" })), + ), + ); + }); + + afterEach(() => { + vi.useRealTimers(); + server.resetHandlers(); + }); + + it("should keep menu visible when mouse leaves and re-enters within 500ms", async () => { + // Arrange - render and wait for queries to settle + renderUserActionsForMenuCloseDelay(); + await act(async () => { + await vi.runAllTimersAsync(); + }); + + const userActions = screen.getByTestId("user-actions"); + + // Act - open menu + await act(async () => { + fireEvent.mouseEnter(userActions); + }); + + // Assert - menu is visible + expect(screen.getByTestId("user-context-menu")).toBeInTheDocument(); + + // Act - leave and re-enter within 500ms + await act(async () => { + fireEvent.mouseLeave(userActions); + await vi.advanceTimersByTimeAsync(200); + fireEvent.mouseEnter(userActions); + }); + + // Assert - menu should still be visible after waiting (pending close was cancelled) + await act(async () => { + await vi.advanceTimersByTimeAsync(500); + }); + expect(screen.getByTestId("user-context-menu")).toBeInTheDocument(); + }); + + it("should not close menu before 500ms delay when mouse leaves", async () => { + // Arrange - render and wait for queries to settle + renderUserActionsForMenuCloseDelay(); + await act(async () => { + await vi.runAllTimersAsync(); + }); + + const userActions = screen.getByTestId("user-actions"); + + // Act - open menu + await act(async () => { + fireEvent.mouseEnter(userActions); + }); + + // Assert - menu is visible + expect(screen.getByTestId("user-context-menu")).toBeInTheDocument(); + + // Act - leave without re-entering, but check before timeout expires + await act(async () => { + fireEvent.mouseLeave(userActions); + await vi.advanceTimersByTimeAsync(400); // Before the 500ms delay + }); + + // Assert - menu should still be visible (delay hasn't expired yet) + // Note: The menu is always in DOM but with opacity-0 when closed. + // This test verifies the state hasn't changed yet (delay is working). + expect(screen.getByTestId("user-context-menu")).toBeInTheDocument(); + }); + }); }); diff --git a/frontend/src/components/features/sidebar/user-actions.tsx b/frontend/src/components/features/sidebar/user-actions.tsx index 3620663789..2c715e4c2c 100644 --- a/frontend/src/components/features/sidebar/user-actions.tsx +++ b/frontend/src/components/features/sidebar/user-actions.tsx @@ -22,20 +22,43 @@ export function UserActions({ user, isLoading }: UserActionsProps) { const [menuResetCount, setMenuResetCount] = React.useState(0); const [inviteMemberModalIsOpen, setInviteMemberModalIsOpen] = React.useState(false); + const hideTimeoutRef = React.useRef(null); // Use the shared hook to determine if user actions should be shown const shouldShowUserActions = useShouldShowUserFeatures(); + // Clean up timeout on unmount + React.useEffect( + () => () => { + if (hideTimeoutRef.current) { + clearTimeout(hideTimeoutRef.current); + } + }, + [], + ); + const showAccountMenu = () => { + // Cancel any pending hide to allow diagonal mouse movement to menu + if (hideTimeoutRef.current) { + clearTimeout(hideTimeoutRef.current); + hideTimeoutRef.current = null; + } setAccountContextMenuIsVisible(true); }; const hideAccountMenu = () => { - setAccountContextMenuIsVisible(false); - setMenuResetCount((c) => c + 1); + // Delay hiding to allow diagonal mouse movement to menu + hideTimeoutRef.current = window.setTimeout(() => { + setAccountContextMenuIsVisible(false); + setMenuResetCount((c) => c + 1); + }, 500); }; const closeAccountMenu = () => { + if (hideTimeoutRef.current) { + clearTimeout(hideTimeoutRef.current); + hideTimeoutRef.current = null; + } if (accountContextMenuIsVisible) { setAccountContextMenuIsVisible(false); setMenuResetCount((c) => c + 1); @@ -61,9 +84,6 @@ export function UserActions({ user, isLoading }: UserActionsProps) { className={cn( "opacity-0 pointer-events-none group-hover:opacity-100 group-hover:pointer-events-auto", accountContextMenuIsVisible && "opacity-100 pointer-events-auto", - // Invisible hover bridge: extends hover zone to create a "safe corridor" - // for diagonal mouse movement to the menu (only active when menu is visible) - "group-hover:before:content-[''] group-hover:before:block group-hover:before:absolute group-hover:before:inset-[-320px] group-hover:before:z-50 before:pointer-events-none", )} > Date: Fri, 20 Mar 2026 00:12:48 +0700 Subject: [PATCH 18/28] refactor(frontend): extract AddCreditsModal into separate component file (#13490) --- .../features/org/add-credits-modal.test.tsx | 351 ++++++++++++++++++ frontend/__tests__/routes/manage-org.test.tsx | 299 --------------- .../features/org/add-credits-modal.tsx | 103 +++++ frontend/src/routes/manage-org.tsx | 101 +---- 4 files changed, 455 insertions(+), 399 deletions(-) create mode 100644 frontend/__tests__/components/features/org/add-credits-modal.test.tsx create mode 100644 frontend/src/components/features/org/add-credits-modal.tsx diff --git a/frontend/__tests__/components/features/org/add-credits-modal.test.tsx b/frontend/__tests__/components/features/org/add-credits-modal.test.tsx new file mode 100644 index 0000000000..1c049aedcd --- /dev/null +++ b/frontend/__tests__/components/features/org/add-credits-modal.test.tsx @@ -0,0 +1,351 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "test-utils"; +import { AddCreditsModal } from "#/components/features/org/add-credits-modal"; +import BillingService from "#/api/billing-service/billing-service.api"; + +vi.mock("react-i18next", async (importOriginal) => ({ + ...(await importOriginal()), + useTranslation: () => ({ + t: (key: string) => key, + i18n: { + changeLanguage: vi.fn(), + }, + }), +})); + +describe("AddCreditsModal", () => { + const onCloseMock = vi.fn(); + + const renderModal = () => { + const user = userEvent.setup(); + renderWithProviders(); + return { user }; + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("Rendering", () => { + it("should render the form with correct elements", () => { + renderModal(); + + expect(screen.getByTestId("add-credits-form")).toBeInTheDocument(); + expect(screen.getByTestId("amount-input")).toBeInTheDocument(); + expect(screen.getByRole("button", { name: /ORG\$NEXT/i })).toBeInTheDocument(); + }); + + it("should display the title", () => { + renderModal(); + + expect(screen.getByText("ORG$ADD_CREDITS")).toBeInTheDocument(); + }); + }); + + describe("Button State Management", () => { + it("should enable submit button initially when modal opens", () => { + renderModal(); + + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + expect(nextButton).not.toBeDisabled(); + }); + + it("should enable submit button when input contains invalid value", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "-50"); + + expect(nextButton).not.toBeDisabled(); + }); + + it("should enable submit button when input contains valid value", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "100"); + + expect(nextButton).not.toBeDisabled(); + }); + + it("should enable submit button after validation error is shown", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "9"); + await user.click(nextButton); + + await waitFor(() => { + expect(screen.getByTestId("amount-error")).toBeInTheDocument(); + }); + + expect(nextButton).not.toBeDisabled(); + }); + }); + + describe("Input Attributes & Placeholder", () => { + it("should have min attribute set to 10", () => { + renderModal(); + + const amountInput = screen.getByTestId("amount-input"); + expect(amountInput).toHaveAttribute("min", "10"); + }); + + it("should have max attribute set to 25000", () => { + renderModal(); + + const amountInput = screen.getByTestId("amount-input"); + expect(amountInput).toHaveAttribute("max", "25000"); + }); + + it("should have step attribute set to 1", () => { + renderModal(); + + const amountInput = screen.getByTestId("amount-input"); + expect(amountInput).toHaveAttribute("step", "1"); + }); + }); + + describe("Error Message Display", () => { + it("should not display error message initially when modal opens", () => { + renderModal(); + + const errorMessage = screen.queryByTestId("amount-error"); + expect(errorMessage).not.toBeInTheDocument(); + }); + + it("should display error message after submitting amount above maximum", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "25001"); + await user.click(nextButton); + + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MAXIMUM_AMOUNT"); + }); + }); + + it("should display error message after submitting decimal value", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "50.5"); + await user.click(nextButton); + + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER"); + }); + }); + + it("should display error message after submitting amount below minimum", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "9"); + await user.click(nextButton); + + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT"); + }); + }); + + it("should display error message after submitting negative amount", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "-50"); + await user.click(nextButton); + + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_NEGATIVE_AMOUNT"); + }); + }); + + it("should replace error message when submitting different invalid value", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "9"); + await user.click(nextButton); + + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT"); + }); + + await user.clear(amountInput); + await user.type(amountInput, "25001"); + await user.click(nextButton); + + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MAXIMUM_AMOUNT"); + }); + }); + }); + + describe("Form Submission Behavior", () => { + it("should prevent submission when amount is invalid", async () => { + const createCheckoutSessionSpy = vi.spyOn( + BillingService, + "createCheckoutSession", + ); + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "9"); + await user.click(nextButton); + + expect(createCheckoutSessionSpy).not.toHaveBeenCalled(); + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT"); + }); + }); + + it("should call createCheckoutSession with correct amount when valid", async () => { + const createCheckoutSessionSpy = vi.spyOn( + BillingService, + "createCheckoutSession", + ); + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "1000"); + await user.click(nextButton); + + expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000); + const errorMessage = screen.queryByTestId("amount-error"); + expect(errorMessage).not.toBeInTheDocument(); + }); + + it("should not call createCheckoutSession when validation fails", async () => { + const createCheckoutSessionSpy = vi.spyOn( + BillingService, + "createCheckoutSession", + ); + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "-50"); + await user.click(nextButton); + + expect(createCheckoutSessionSpy).not.toHaveBeenCalled(); + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_NEGATIVE_AMOUNT"); + }); + }); + + it("should close modal on successful submission", async () => { + vi.spyOn(BillingService, "createCheckoutSession").mockResolvedValue( + "https://checkout.stripe.com/test-session", + ); + + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "1000"); + await user.click(nextButton); + + await waitFor(() => { + expect(onCloseMock).toHaveBeenCalled(); + }); + }); + + it("should allow API call when validation passes and clear any previous errors", async () => { + const createCheckoutSessionSpy = vi.spyOn( + BillingService, + "createCheckoutSession", + ); + + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + // First submit invalid value + await user.type(amountInput, "9"); + await user.click(nextButton); + + await waitFor(() => { + expect(screen.getByTestId("amount-error")).toBeInTheDocument(); + }); + + // Then submit valid value + await user.clear(amountInput); + await user.type(amountInput, "100"); + await user.click(nextButton); + + expect(createCheckoutSessionSpy).toHaveBeenCalledWith(100); + const errorMessage = screen.queryByTestId("amount-error"); + expect(errorMessage).not.toBeInTheDocument(); + }); + }); + + describe("Edge Cases", () => { + it("should handle zero value correctly", async () => { + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + await user.type(amountInput, "0"); + await user.click(nextButton); + + await waitFor(() => { + const errorMessage = screen.getByTestId("amount-error"); + expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT"); + }); + }); + + it("should handle whitespace-only input correctly", async () => { + const createCheckoutSessionSpy = vi.spyOn( + BillingService, + "createCheckoutSession", + ); + const { user } = renderModal(); + const amountInput = screen.getByTestId("amount-input"); + const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i }); + + // Number inputs typically don't accept spaces, but test the behavior + await user.type(amountInput, " "); + await user.click(nextButton); + + // Should not call API (empty/invalid input) + expect(createCheckoutSessionSpy).not.toHaveBeenCalled(); + }); + }); + + describe("Modal Interaction", () => { + it("should call onClose when cancel button is clicked", async () => { + const { user } = renderModal(); + + const cancelButton = screen.getByRole("button", { name: /close/i }); + await user.click(cancelButton); + + expect(onCloseMock).toHaveBeenCalledOnce(); + }); + }); +}); diff --git a/frontend/__tests__/routes/manage-org.test.tsx b/frontend/__tests__/routes/manage-org.test.tsx index 390b10fc43..8f5cc137be 100644 --- a/frontend/__tests__/routes/manage-org.test.tsx +++ b/frontend/__tests__/routes/manage-org.test.tsx @@ -283,305 +283,6 @@ describe("Manage Org Route", () => { expect(createCheckoutSessionSpy).not.toHaveBeenCalled(); }); - describe("AddCreditsModal", () => { - const openAddCreditsModal = async () => { - const user = userEvent.setup(); - renderManageOrg(); - await screen.findByTestId("manage-org-screen"); - - await selectOrganization({ orgIndex: 0 }); // user is owner in org 1 - - const addCreditsButton = await waitFor(() => screen.getByText(/add/i)); - await user.click(addCreditsButton); - - const addCreditsForm = screen.getByTestId("add-credits-form"); - expect(addCreditsForm).toBeInTheDocument(); - - return { user, addCreditsForm }; - }; - - describe("Button State Management", () => { - it("should enable submit button initially when modal opens", async () => { - await openAddCreditsModal(); - - const nextButton = screen.getByRole("button", { name: /next/i }); - expect(nextButton).not.toBeDisabled(); - }); - - it("should enable submit button when input contains invalid value", async () => { - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "-50"); - - expect(nextButton).not.toBeDisabled(); - }); - - it("should enable submit button when input contains valid value", async () => { - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "100"); - - expect(nextButton).not.toBeDisabled(); - }); - - it("should enable submit button after validation error is shown", async () => { - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "9"); - await user.click(nextButton); - - await waitFor(() => { - expect(screen.getByTestId("amount-error")).toBeInTheDocument(); - }); - - expect(nextButton).not.toBeDisabled(); - }); - }); - - describe("Input Attributes & Placeholder", () => { - it("should have min attribute set to 10", async () => { - await openAddCreditsModal(); - - const amountInput = screen.getByTestId("amount-input"); - expect(amountInput).toHaveAttribute("min", "10"); - }); - - it("should have max attribute set to 25000", async () => { - await openAddCreditsModal(); - - const amountInput = screen.getByTestId("amount-input"); - expect(amountInput).toHaveAttribute("max", "25000"); - }); - - it("should have step attribute set to 1", async () => { - await openAddCreditsModal(); - - const amountInput = screen.getByTestId("amount-input"); - expect(amountInput).toHaveAttribute("step", "1"); - }); - }); - - describe("Error Message Display", () => { - it("should not display error message initially when modal opens", async () => { - await openAddCreditsModal(); - - const errorMessage = screen.queryByTestId("amount-error"); - expect(errorMessage).not.toBeInTheDocument(); - }); - - it("should display error message after submitting amount above maximum", async () => { - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "25001"); - await user.click(nextButton); - - await waitFor(() => { - const errorMessage = screen.getByTestId("amount-error"); - expect(errorMessage).toHaveTextContent( - "PAYMENT$ERROR_MAXIMUM_AMOUNT", - ); - }); - }); - - it("should display error message after submitting decimal value", async () => { - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "50.5"); - await user.click(nextButton); - - await waitFor(() => { - const errorMessage = screen.getByTestId("amount-error"); - expect(errorMessage).toHaveTextContent( - "PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER", - ); - }); - }); - - it("should replace error message when submitting different invalid value", async () => { - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "9"); - await user.click(nextButton); - - await waitFor(() => { - const errorMessage = screen.getByTestId("amount-error"); - expect(errorMessage).toHaveTextContent( - "PAYMENT$ERROR_MINIMUM_AMOUNT", - ); - }); - - await user.clear(amountInput); - await user.type(amountInput, "25001"); - await user.click(nextButton); - - await waitFor(() => { - const errorMessage = screen.getByTestId("amount-error"); - expect(errorMessage).toHaveTextContent( - "PAYMENT$ERROR_MAXIMUM_AMOUNT", - ); - }); - }); - }); - - describe("Form Submission Behavior", () => { - it("should prevent submission when amount is invalid", async () => { - const createCheckoutSessionSpy = vi.spyOn( - BillingService, - "createCheckoutSession", - ); - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "9"); - await user.click(nextButton); - - expect(createCheckoutSessionSpy).not.toHaveBeenCalled(); - await waitFor(() => { - const errorMessage = screen.getByTestId("amount-error"); - expect(errorMessage).toHaveTextContent( - "PAYMENT$ERROR_MINIMUM_AMOUNT", - ); - }); - }); - - it("should call createCheckoutSession with correct amount when valid", async () => { - const createCheckoutSessionSpy = vi.spyOn( - BillingService, - "createCheckoutSession", - ); - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "1000"); - await user.click(nextButton); - - expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000); - const errorMessage = screen.queryByTestId("amount-error"); - expect(errorMessage).not.toBeInTheDocument(); - }); - - it("should not call createCheckoutSession when validation fails", async () => { - const createCheckoutSessionSpy = vi.spyOn( - BillingService, - "createCheckoutSession", - ); - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "-50"); - await user.click(nextButton); - - // Verify mutation was not called - expect(createCheckoutSessionSpy).not.toHaveBeenCalled(); - await waitFor(() => { - const errorMessage = screen.getByTestId("amount-error"); - expect(errorMessage).toHaveTextContent( - "PAYMENT$ERROR_NEGATIVE_AMOUNT", - ); - }); - }); - - it("should close modal on successful submission", async () => { - const createCheckoutSessionSpy = vi - .spyOn(BillingService, "createCheckoutSession") - .mockResolvedValue("https://checkout.stripe.com/test-session"); - - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "1000"); - await user.click(nextButton); - - expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000); - - await waitFor(() => { - expect( - screen.queryByTestId("add-credits-form"), - ).not.toBeInTheDocument(); - }); - }); - - it("should allow API call when validation passes and clear any previous errors", async () => { - const createCheckoutSessionSpy = vi.spyOn( - BillingService, - "createCheckoutSession", - ); - - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - // First submit invalid value - await user.type(amountInput, "9"); - await user.click(nextButton); - - await waitFor(() => { - expect(screen.getByTestId("amount-error")).toBeInTheDocument(); - }); - - // Then submit valid value - await user.clear(amountInput); - await user.type(amountInput, "100"); - await user.click(nextButton); - - expect(createCheckoutSessionSpy).toHaveBeenCalledWith(100); - const errorMessage = screen.queryByTestId("amount-error"); - expect(errorMessage).not.toBeInTheDocument(); - }); - }); - - describe("Edge Cases", () => { - it("should handle zero value correctly", async () => { - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - await user.type(amountInput, "0"); - await user.click(nextButton); - - await waitFor(() => { - const errorMessage = screen.getByTestId("amount-error"); - expect(errorMessage).toHaveTextContent( - "PAYMENT$ERROR_MINIMUM_AMOUNT", - ); - }); - }); - - it("should handle whitespace-only input correctly", async () => { - const createCheckoutSessionSpy = vi.spyOn( - BillingService, - "createCheckoutSession", - ); - const { user } = await openAddCreditsModal(); - const amountInput = screen.getByTestId("amount-input"); - const nextButton = screen.getByRole("button", { name: /next/i }); - - // Number inputs typically don't accept spaces, but test the behavior - await user.type(amountInput, " "); - await user.click(nextButton); - - // Should not call API (empty/invalid input) - expect(createCheckoutSessionSpy).not.toHaveBeenCalled(); - }); - }); - }); - it("should show add credits option for ADMIN role", async () => { renderManageOrg(); await screen.findByTestId("manage-org-screen"); diff --git a/frontend/src/components/features/org/add-credits-modal.tsx b/frontend/src/components/features/org/add-credits-modal.tsx new file mode 100644 index 0000000000..78ef6519b2 --- /dev/null +++ b/frontend/src/components/features/org/add-credits-modal.tsx @@ -0,0 +1,103 @@ +import React from "react"; +import { useTranslation } from "react-i18next"; +import { useCreateStripeCheckoutSession } from "#/hooks/mutation/stripe/use-create-stripe-checkout-session"; +import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop"; +import { ModalButtonGroup } from "#/components/shared/modals/modal-button-group"; +import { SettingsInput } from "#/components/features/settings/settings-input"; +import { I18nKey } from "#/i18n/declaration"; +import { amountIsValid } from "#/utils/amount-is-valid"; + +interface AddCreditsModalProps { + onClose: () => void; +} + +export function AddCreditsModal({ onClose }: AddCreditsModalProps) { + const { t } = useTranslation(); + const { mutate: addBalance } = useCreateStripeCheckoutSession(); + + const [inputValue, setInputValue] = React.useState(""); + const [errorMessage, setErrorMessage] = React.useState(null); + + const getErrorMessage = (value: string): string | null => { + if (!value.trim()) return null; + + const numValue = parseInt(value, 10); + if (Number.isNaN(numValue)) { + return t(I18nKey.PAYMENT$ERROR_INVALID_NUMBER); + } + if (numValue < 0) { + return t(I18nKey.PAYMENT$ERROR_NEGATIVE_AMOUNT); + } + if (numValue < 10) { + return t(I18nKey.PAYMENT$ERROR_MINIMUM_AMOUNT); + } + if (numValue > 25000) { + return t(I18nKey.PAYMENT$ERROR_MAXIMUM_AMOUNT); + } + if (numValue !== parseFloat(value)) { + return t(I18nKey.PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER); + } + return null; + }; + + const formAction = (formData: FormData) => { + const amount = formData.get("amount")?.toString(); + + if (amount?.trim()) { + if (!amountIsValid(amount)) { + const error = getErrorMessage(amount); + setErrorMessage(error || "Invalid amount"); + return; + } + + const intValue = parseInt(amount, 10); + + addBalance({ amount: intValue }, { onSuccess: onClose }); + + setErrorMessage(null); + } + }; + + const handleAmountInputChange = (value: string) => { + setInputValue(value); + setErrorMessage(null); + }; + + return ( + +
+

{t(I18nKey.ORG$ADD_CREDITS)}

+
+ handleAmountInputChange(value)} + className="w-full" + /> + {errorMessage && ( +

+ {errorMessage} +

+ )} +
+ + + +
+ ); +} diff --git a/frontend/src/routes/manage-org.tsx b/frontend/src/routes/manage-org.tsx index cff5429344..cc14274923 100644 --- a/frontend/src/routes/manage-org.tsx +++ b/frontend/src/routes/manage-org.tsx @@ -1,14 +1,9 @@ import React from "react"; import { useTranslation } from "react-i18next"; -import { useCreateStripeCheckoutSession } from "#/hooks/mutation/stripe/use-create-stripe-checkout-session"; import { useOrganization } from "#/hooks/query/use-organization"; -import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop"; -import { ModalButtonGroup } from "#/components/shared/modals/modal-button-group"; -import { SettingsInput } from "#/components/features/settings/settings-input"; import { useMe } from "#/hooks/query/use-me"; import { useConfig } from "#/hooks/query/use-config"; import { I18nKey } from "#/i18n/declaration"; -import { amountIsValid } from "#/utils/amount-is-valid"; import { CreditsChip } from "#/ui/credits-chip"; import { InteractiveChip } from "#/ui/interactive-chip"; import { usePermission } from "#/hooks/organizations/use-permissions"; @@ -16,104 +11,10 @@ import { createPermissionGuard } from "#/utils/org/permission-guard"; import { isBillingHidden } from "#/utils/org/billing-visibility"; import { DeleteOrgConfirmationModal } from "#/components/features/org/delete-org-confirmation-modal"; import { ChangeOrgNameModal } from "#/components/features/org/change-org-name-modal"; +import { AddCreditsModal } from "#/components/features/org/add-credits-modal"; import { useBalance } from "#/hooks/query/use-balance"; import { cn } from "#/utils/utils"; -interface AddCreditsModalProps { - onClose: () => void; -} - -function AddCreditsModal({ onClose }: AddCreditsModalProps) { - const { t } = useTranslation(); - const { mutate: addBalance } = useCreateStripeCheckoutSession(); - - const [inputValue, setInputValue] = React.useState(""); - const [errorMessage, setErrorMessage] = React.useState(null); - - const getErrorMessage = (value: string): string | null => { - if (!value.trim()) return null; - - const numValue = parseInt(value, 10); - if (Number.isNaN(numValue)) { - return t(I18nKey.PAYMENT$ERROR_INVALID_NUMBER); - } - if (numValue < 0) { - return t(I18nKey.PAYMENT$ERROR_NEGATIVE_AMOUNT); - } - if (numValue < 10) { - return t(I18nKey.PAYMENT$ERROR_MINIMUM_AMOUNT); - } - if (numValue > 25000) { - return t(I18nKey.PAYMENT$ERROR_MAXIMUM_AMOUNT); - } - if (numValue !== parseFloat(value)) { - return t(I18nKey.PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER); - } - return null; - }; - - const formAction = (formData: FormData) => { - const amount = formData.get("amount")?.toString(); - - if (amount?.trim()) { - if (!amountIsValid(amount)) { - const error = getErrorMessage(amount); - setErrorMessage(error || "Invalid amount"); - return; - } - - const intValue = parseInt(amount, 10); - - addBalance({ amount: intValue }, { onSuccess: onClose }); - - setErrorMessage(null); - } - }; - - const handleAmountInputChange = (value: string) => { - setInputValue(value); - setErrorMessage(null); - }; - - return ( - -
-

{t(I18nKey.ORG$ADD_CREDITS)}

-
- handleAmountInputChange(value)} - className="w-full" - /> - {errorMessage && ( -

- {errorMessage} -

- )} -
- - - -
- ); -} - export const clientLoader = createPermissionGuard("view_billing"); function ManageOrg() { From 38648bddb3afbbef1fc93c5631724118cbc69191 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Fri, 20 Mar 2026 00:13:02 +0700 Subject: [PATCH 19/28] fix(frontend): use correct git path based on sandbox grouping strategy (#13488) --- frontend/__tests__/utils/get-git-path.test.ts | 105 +++++++++++++++--- .../query/use-unified-get-git-changes.ts | 13 ++- .../src/hooks/query/use-unified-git-diff.ts | 21 +++- frontend/src/utils/get-git-path.ts | 20 +++- 4 files changed, 132 insertions(+), 27 deletions(-) diff --git a/frontend/__tests__/utils/get-git-path.test.ts b/frontend/__tests__/utils/get-git-path.test.ts index 2adfc232d4..a1f3512862 100644 --- a/frontend/__tests__/utils/get-git-path.test.ts +++ b/frontend/__tests__/utils/get-git-path.test.ts @@ -4,27 +4,96 @@ import { getGitPath } from "#/utils/get-git-path"; describe("getGitPath", () => { const conversationId = "abc123"; - it("should return /workspace/project/{conversationId} when no repository is selected", () => { - expect(getGitPath(conversationId, null)).toBe(`/workspace/project/${conversationId}`); - expect(getGitPath(conversationId, undefined)).toBe(`/workspace/project/${conversationId}`); + describe("without sandbox grouping (NO_GROUPING)", () => { + it("should return /workspace/project when no repository is selected", () => { + expect(getGitPath(conversationId, null, false)).toBe("/workspace/project"); + expect(getGitPath(conversationId, undefined, false)).toBe( + "/workspace/project", + ); + }); + + it("should handle standard owner/repo format (GitHub)", () => { + expect(getGitPath(conversationId, "OpenHands/OpenHands", false)).toBe( + "/workspace/project/OpenHands", + ); + expect(getGitPath(conversationId, "facebook/react", false)).toBe( + "/workspace/project/react", + ); + }); + + it("should handle nested group paths (GitLab)", () => { + expect( + getGitPath(conversationId, "modernhealth/frontend-guild/pan", false), + ).toBe("/workspace/project/pan"); + expect(getGitPath(conversationId, "group/subgroup/repo", false)).toBe( + "/workspace/project/repo", + ); + expect(getGitPath(conversationId, "a/b/c/d/repo", false)).toBe( + "/workspace/project/repo", + ); + }); + + it("should handle single segment paths", () => { + expect(getGitPath(conversationId, "repo", false)).toBe( + "/workspace/project/repo", + ); + }); + + it("should handle empty string", () => { + expect(getGitPath(conversationId, "", false)).toBe("/workspace/project"); + }); }); - it("should handle standard owner/repo format (GitHub)", () => { - expect(getGitPath(conversationId, "OpenHands/OpenHands")).toBe(`/workspace/project/${conversationId}/OpenHands`); - expect(getGitPath(conversationId, "facebook/react")).toBe(`/workspace/project/${conversationId}/react`); + describe("with sandbox grouping enabled", () => { + it("should return /workspace/project/{conversationId} when no repository is selected", () => { + expect(getGitPath(conversationId, null, true)).toBe( + `/workspace/project/${conversationId}`, + ); + expect(getGitPath(conversationId, undefined, true)).toBe( + `/workspace/project/${conversationId}`, + ); + }); + + it("should handle standard owner/repo format (GitHub)", () => { + expect(getGitPath(conversationId, "OpenHands/OpenHands", true)).toBe( + `/workspace/project/${conversationId}/OpenHands`, + ); + expect(getGitPath(conversationId, "facebook/react", true)).toBe( + `/workspace/project/${conversationId}/react`, + ); + }); + + it("should handle nested group paths (GitLab)", () => { + expect( + getGitPath(conversationId, "modernhealth/frontend-guild/pan", true), + ).toBe(`/workspace/project/${conversationId}/pan`); + expect(getGitPath(conversationId, "group/subgroup/repo", true)).toBe( + `/workspace/project/${conversationId}/repo`, + ); + expect(getGitPath(conversationId, "a/b/c/d/repo", true)).toBe( + `/workspace/project/${conversationId}/repo`, + ); + }); + + it("should handle single segment paths", () => { + expect(getGitPath(conversationId, "repo", true)).toBe( + `/workspace/project/${conversationId}/repo`, + ); + }); + + it("should handle empty string", () => { + expect(getGitPath(conversationId, "", true)).toBe( + `/workspace/project/${conversationId}`, + ); + }); }); - it("should handle nested group paths (GitLab)", () => { - expect(getGitPath(conversationId, "modernhealth/frontend-guild/pan")).toBe(`/workspace/project/${conversationId}/pan`); - expect(getGitPath(conversationId, "group/subgroup/repo")).toBe(`/workspace/project/${conversationId}/repo`); - expect(getGitPath(conversationId, "a/b/c/d/repo")).toBe(`/workspace/project/${conversationId}/repo`); - }); - - it("should handle single segment paths", () => { - expect(getGitPath(conversationId, "repo")).toBe(`/workspace/project/${conversationId}/repo`); - }); - - it("should handle empty string", () => { - expect(getGitPath(conversationId, "")).toBe(`/workspace/project/${conversationId}`); + describe("default behavior (useSandboxGrouping defaults to false)", () => { + it("should default to no sandbox grouping", () => { + expect(getGitPath(conversationId, null)).toBe("/workspace/project"); + expect(getGitPath(conversationId, "owner/repo")).toBe( + "/workspace/project/repo", + ); + }); }); }); diff --git a/frontend/src/hooks/query/use-unified-get-git-changes.ts b/frontend/src/hooks/query/use-unified-get-git-changes.ts index a1de3852f9..616665a07f 100644 --- a/frontend/src/hooks/query/use-unified-get-git-changes.ts +++ b/frontend/src/hooks/query/use-unified-get-git-changes.ts @@ -5,6 +5,7 @@ import V1GitService from "#/api/git-service/v1-git-service.api"; import { useConversationId } from "#/hooks/use-conversation-id"; import { useActiveConversation } from "#/hooks/query/use-active-conversation"; import { useRuntimeIsReady } from "#/hooks/use-runtime-is-ready"; +import { useSettings } from "#/hooks/query/use-settings"; import { getGitPath } from "#/utils/get-git-path"; import type { GitChange } from "#/api/open-hands.types"; @@ -16,6 +17,7 @@ import type { GitChange } from "#/api/open-hands.types"; export const useUnifiedGetGitChanges = () => { const { conversationId } = useConversationId(); const { data: conversation } = useActiveConversation(); + const { data: settings } = useSettings(); const [orderedChanges, setOrderedChanges] = React.useState([]); const previousDataRef = React.useRef(null); const runtimeIsReady = useRuntimeIsReady(); @@ -25,10 +27,15 @@ export const useUnifiedGetGitChanges = () => { const sessionApiKey = conversation?.session_api_key; const selectedRepository = conversation?.selected_repository; - // Calculate git path based on selected repository + // Sandbox grouping is enabled when strategy is not NO_GROUPING + const useSandboxGrouping = + settings?.sandbox_grouping_strategy !== "NO_GROUPING" && + settings?.sandbox_grouping_strategy !== undefined; + + // Calculate git path based on selected repository and sandbox grouping strategy const gitPath = React.useMemo( - () => getGitPath(conversationId, selectedRepository), - [selectedRepository], + () => getGitPath(conversationId, selectedRepository, useSandboxGrouping), + [conversationId, selectedRepository, useSandboxGrouping], ); const result = useQuery({ diff --git a/frontend/src/hooks/query/use-unified-git-diff.ts b/frontend/src/hooks/query/use-unified-git-diff.ts index 26bca16fce..8705a70c76 100644 --- a/frontend/src/hooks/query/use-unified-git-diff.ts +++ b/frontend/src/hooks/query/use-unified-git-diff.ts @@ -4,6 +4,7 @@ import GitService from "#/api/git-service/git-service.api"; import V1GitService from "#/api/git-service/v1-git-service.api"; import { useConversationId } from "#/hooks/use-conversation-id"; import { useActiveConversation } from "#/hooks/query/use-active-conversation"; +import { useSettings } from "#/hooks/query/use-settings"; import { getGitPath } from "#/utils/get-git-path"; import type { GitChangeStatus } from "#/api/open-hands.types"; @@ -21,20 +22,36 @@ type UseUnifiedGitDiffConfig = { export const useUnifiedGitDiff = (config: UseUnifiedGitDiffConfig) => { const { conversationId } = useConversationId(); const { data: conversation } = useActiveConversation(); + const { data: settings } = useSettings(); const isV1Conversation = conversation?.conversation_version === "V1"; const conversationUrl = conversation?.url; const sessionApiKey = conversation?.session_api_key; const selectedRepository = conversation?.selected_repository; + // Sandbox grouping is enabled when strategy is not NO_GROUPING + const useSandboxGrouping = + settings?.sandbox_grouping_strategy !== "NO_GROUPING" && + settings?.sandbox_grouping_strategy !== undefined; + // For V1, we need to convert the relative file path to an absolute path // The diff endpoint expects: /workspace/project/RepoName/relative/path const absoluteFilePath = React.useMemo(() => { if (!isV1Conversation) return config.filePath; - const gitPath = getGitPath(conversationId, selectedRepository); + const gitPath = getGitPath( + conversationId, + selectedRepository, + useSandboxGrouping, + ); return `${gitPath}/${config.filePath}`; - }, [isV1Conversation, selectedRepository, config.filePath]); + }, [ + isV1Conversation, + conversationId, + selectedRepository, + useSandboxGrouping, + config.filePath, + ]); return useQuery({ queryKey: [ diff --git a/frontend/src/utils/get-git-path.ts b/frontend/src/utils/get-git-path.ts index 39292b819f..e55b0bb989 100644 --- a/frontend/src/utils/get-git-path.ts +++ b/frontend/src/utils/get-git-path.ts @@ -1,17 +1,29 @@ /** * Get the git repository path for a conversation - * If a repository is selected, returns /workspace/project/{repo-name} - * Otherwise, returns /workspace/project * + * When sandbox grouping is enabled (strategy != NO_GROUPING), each conversation + * gets its own subdirectory: /workspace/project/{conversationId}[/{repoName}] + * + * When sandbox grouping is disabled (NO_GROUPING), the path is simply: + * /workspace/project[/{repoName}] + * + * @param conversationId The conversation ID * @param selectedRepository The selected repository (e.g., "OpenHands/OpenHands", "owner/repo", or "group/subgroup/repo") + * @param useSandboxGrouping Whether sandbox grouping is enabled (strategy != NO_GROUPING) * @returns The git path to use */ export function getGitPath( conversationId: string, selectedRepository: string | null | undefined, + useSandboxGrouping: boolean = false, ): string { + // Base path depends on sandbox grouping strategy + const basePath = useSandboxGrouping + ? `/workspace/project/${conversationId}` + : "/workspace/project"; + if (!selectedRepository) { - return `/workspace/project/${conversationId}`; + return basePath; } // Extract the repository name from the path @@ -19,5 +31,5 @@ export function getGitPath( const parts = selectedRepository.split("/"); const repoName = parts[parts.length - 1]; - return `/workspace/project/${conversationId}/${repoName}`; + return `${basePath}/${repoName}`; } From 49a98885aba218379f8876d30006828aa576debb Mon Sep 17 00:00:00 2001 From: aivong-openhands Date: Thu, 19 Mar 2026 14:33:23 -0500 Subject: [PATCH 20/28] chore: Update OpenSSL in Debian images for security patches (#13401) Co-authored-by: openhands --- openhands/runtime/utils/runtime_templates/Dockerfile.j2 | 3 +++ 1 file changed, 3 insertions(+) diff --git a/openhands/runtime/utils/runtime_templates/Dockerfile.j2 b/openhands/runtime/utils/runtime_templates/Dockerfile.j2 index a02229995f..78ee532fe7 100644 --- a/openhands/runtime/utils/runtime_templates/Dockerfile.j2 +++ b/openhands/runtime/utils/runtime_templates/Dockerfile.j2 @@ -46,6 +46,9 @@ RUN apt-get update && \ (apt-get install -y --no-install-recommends libgl1 || apt-get install -y --no-install-recommends libgl1-mesa-glx) && \ # Install Docker dependencies apt-get install -y --no-install-recommends apt-transport-https ca-certificates curl gnupg lsb-release && \ + # Security upgrade: patch OpenSSL CVEs (CVE-2025-15467, CVE-2025-69419, CVE-2025-69421, et al.) + (apt-get install -y --no-install-recommends --only-upgrade \ + openssl openssl-provider-legacy libssl3t64 || true) && \ # Security upgrade: patch ImageMagick CVEs (CVE-2026-25897, CVE-2026-25968, CVE-2026-26284, et al.) (apt-get install -y --no-install-recommends --only-upgrade \ imagemagick imagemagick-7-common imagemagick-7.q16 \ From 0137201903c05d748da99a36e0f68b512d24ed6b Mon Sep 17 00:00:00 2001 From: aivong-openhands Date: Thu, 19 Mar 2026 14:36:22 -0500 Subject: [PATCH 21/28] fix: remove vulnerable VSCode extensions in build_from_scratch path (#13399) Co-authored-by: openhands Co-authored-by: Ray Myers --- openhands/runtime/utils/runtime_templates/Dockerfile.j2 | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/openhands/runtime/utils/runtime_templates/Dockerfile.j2 b/openhands/runtime/utils/runtime_templates/Dockerfile.j2 index 78ee532fe7..9bf06c54b2 100644 --- a/openhands/runtime/utils/runtime_templates/Dockerfile.j2 +++ b/openhands/runtime/utils/runtime_templates/Dockerfile.j2 @@ -360,6 +360,14 @@ RUN chmod a+rwx /openhands/code/openhands/__init__.py && \ chown -R openhands:openhands /openhands/code +# ================================================================ +# Install VSCode extensions for build_from_scratch +# (must be after setup_vscode_server and source file copy) +# ================================================================ +{% if build_from_scratch %} +{{ install_vscode_extensions() }} +{% endif %} + # ================================================================ # END: Build from versioned image # ================================================================ From f706a217d0e658c088d22a7192067e506ddd79a5 Mon Sep 17 00:00:00 2001 From: Joe Laverty Date: Thu, 19 Mar 2026 16:24:07 -0400 Subject: [PATCH 22/28] fix: Use commit SHA instead of mutable branch tag for enterprise base (#13498) --- .github/workflows/ghcr-build.yml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ghcr-build.yml b/.github/workflows/ghcr-build.yml index 86ba722cec..bd0718e651 100644 --- a/.github/workflows/ghcr-build.yml +++ b/.github/workflows/ghcr-build.yml @@ -219,11 +219,9 @@ jobs: - name: Determine app image tag shell: bash run: | - # Duplicated with build.sh - sanitized_ref_name=$(echo "$GITHUB_REF_NAME" | sed 's/[^a-zA-Z0-9.-]\+/-/g') - OPENHANDS_BUILD_VERSION=$sanitized_ref_name - sanitized_ref_name=$(echo "$sanitized_ref_name" | tr '[:upper:]' '[:lower:]') # lower case is required in tagging - echo "OPENHANDS_DOCKER_TAG=${sanitized_ref_name}" >> $GITHUB_ENV + # Use the commit SHA to pin the exact app image built by ghcr_build_app, + # rather than a mutable branch tag like "main" which can serve stale cached layers. + echo "OPENHANDS_DOCKER_TAG=${RELEVANT_SHA}" >> $GITHUB_ENV - name: Build and push Docker image uses: useblacksmith/build-push-action@v1 with: From a8f6a353416e391271f17c3be84f044fcb840c6a Mon Sep 17 00:00:00 2001 From: aivong-openhands Date: Thu, 19 Mar 2026 16:21:24 -0500 Subject: [PATCH 23/28] fix: patch GLib CVE-2025-14087 in runtime Docker images (#13403) Co-authored-by: openhands --- openhands/runtime/utils/runtime_templates/Dockerfile.j2 | 3 +++ 1 file changed, 3 insertions(+) diff --git a/openhands/runtime/utils/runtime_templates/Dockerfile.j2 b/openhands/runtime/utils/runtime_templates/Dockerfile.j2 index 9bf06c54b2..69eb841a9f 100644 --- a/openhands/runtime/utils/runtime_templates/Dockerfile.j2 +++ b/openhands/runtime/utils/runtime_templates/Dockerfile.j2 @@ -46,6 +46,9 @@ RUN apt-get update && \ (apt-get install -y --no-install-recommends libgl1 || apt-get install -y --no-install-recommends libgl1-mesa-glx) && \ # Install Docker dependencies apt-get install -y --no-install-recommends apt-transport-https ca-certificates curl gnupg lsb-release && \ + # Security upgrade: patch GLib CVE-2025-14087 (buffer underflow in GVariant parser) + (apt-get install -y --no-install-recommends --only-upgrade \ + libglib2.0-0t64 libglib2.0-bin libglib2.0-dev libglib2.0-dev-bin || true) && \ # Security upgrade: patch OpenSSL CVEs (CVE-2025-15467, CVE-2025-69419, CVE-2025-69421, et al.) (apt-get install -y --no-install-recommends --only-upgrade \ openssl openssl-provider-legacy libssl3t64 || true) && \ From e4515b21eba030cec031cbd48f3e5051327b57b3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:28:15 -0400 Subject: [PATCH 24/28] chore(deps): bump socket.io-parser from 4.2.5 to 4.2.6 in /frontend in the security-all group across 1 directory (#13474) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- frontend/package-lock.json | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index cd20c8aa82..811131f3cb 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -15325,10 +15325,9 @@ } }, "node_modules/socket.io-parser": { - "version": "4.2.5", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.5.tgz", - "integrity": "sha512-bPMmpy/5WWKHea5Y/jYAP6k74A+hvmRCQaJuJB6I/ML5JZq/KfNieUVo/3Mh7SAqn7TyFdIo6wqYHInG1MU1bQ==", - "license": "MIT", + "version": "4.2.6", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.6.tgz", + "integrity": "sha512-asJqbVBDsBCJx0pTqw3WfesSY0iRX+2xzWEWzrpcH7L6fLzrhyF8WPI8UaeM4YCuDfpwA/cgsdugMsmtz8EJeg==", "dependencies": { "@socket.io/component-emitter": "~3.1.0", "debug": "~4.4.1" From f75141af3e94528c18536b6d7697fe0ed652f17c Mon Sep 17 00:00:00 2001 From: chuckbutkus Date: Thu, 19 Mar 2026 19:34:12 -0400 Subject: [PATCH 25/28] fix: prevent secrets deletion across organizations when storing secrets (#13500) Co-authored-by: openhands --- enterprise/storage/saas_secrets_store.py | 13 +-- .../tests/unit/test_saas_secrets_store.py | 79 +++++++++++++++++++ 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/enterprise/storage/saas_secrets_store.py b/enterprise/storage/saas_secrets_store.py index aede6df419..f4fb310556 100644 --- a/enterprise/storage/saas_secrets_store.py +++ b/enterprise/storage/saas_secrets_store.py @@ -59,12 +59,15 @@ class SaasSecretsStore(SecretsStore): async with a_session_maker() as session: # Incoming secrets are always the most updated ones - # Delete all existing records and override with incoming ones - await session.execute( - delete(StoredCustomSecrets).filter( - StoredCustomSecrets.keycloak_user_id == self.user_id - ) + # Delete existing records for this user AND organization only + delete_query = delete(StoredCustomSecrets).filter( + StoredCustomSecrets.keycloak_user_id == self.user_id ) + if org_id is not None: + delete_query = delete_query.filter(StoredCustomSecrets.org_id == org_id) + else: + delete_query = delete_query.filter(StoredCustomSecrets.org_id.is_(None)) + await session.execute(delete_query) # Prepare the new secrets data kwargs = item.model_dump(context={'expose_secrets': True}) diff --git a/enterprise/tests/unit/test_saas_secrets_store.py b/enterprise/tests/unit/test_saas_secrets_store.py index f9a560d11c..740507a973 100644 --- a/enterprise/tests/unit/test_saas_secrets_store.py +++ b/enterprise/tests/unit/test_saas_secrets_store.py @@ -246,3 +246,82 @@ class TestSaasSecretsStore: assert isinstance(store, SaasSecretsStore) assert store.user_id == 'test-user-id' assert store.config == mock_config + + @pytest.mark.asyncio + @patch( + 'storage.saas_secrets_store.UserStore.get_user_by_id', + new_callable=AsyncMock, + ) + async def test_secrets_isolation_between_organizations( + self, mock_get_user, secrets_store, mock_user + ): + """Test that secrets from one organization are not deleted when storing + secrets in another organization. This reproduces a bug where switching + organizations and creating a secret would delete all secrets from the + user's personal workspace.""" + org1_id = UUID('a1111111-1111-1111-1111-111111111111') + org2_id = UUID('b2222222-2222-2222-2222-222222222222') + + # Store secrets in org1 (personal workspace) + mock_user.current_org_id = org1_id + mock_get_user.return_value = mock_user + org1_secrets = Secrets( + custom_secrets=MappingProxyType( + { + 'personal_secret': CustomSecret.from_value( + { + 'secret': 'personal_secret_value', + 'description': 'My personal secret', + } + ), + } + ) + ) + await secrets_store.store(org1_secrets) + + # Verify org1 secrets are stored + loaded_org1 = await secrets_store.load() + assert loaded_org1 is not None + assert 'personal_secret' in loaded_org1.custom_secrets + assert ( + loaded_org1.custom_secrets['personal_secret'].secret.get_secret_value() + == 'personal_secret_value' + ) + + # Switch to org2 and store secrets there + mock_user.current_org_id = org2_id + mock_get_user.return_value = mock_user + org2_secrets = Secrets( + custom_secrets=MappingProxyType( + { + 'org2_secret': CustomSecret.from_value( + {'secret': 'org2_secret_value', 'description': 'Org2 secret'} + ), + } + ) + ) + await secrets_store.store(org2_secrets) + + # Verify org2 secrets are stored + loaded_org2 = await secrets_store.load() + assert loaded_org2 is not None + assert 'org2_secret' in loaded_org2.custom_secrets + assert ( + loaded_org2.custom_secrets['org2_secret'].secret.get_secret_value() + == 'org2_secret_value' + ) + + # Switch back to org1 and verify secrets are still there + mock_user.current_org_id = org1_id + mock_get_user.return_value = mock_user + loaded_org1_again = await secrets_store.load() + assert loaded_org1_again is not None + assert 'personal_secret' in loaded_org1_again.custom_secrets + assert ( + loaded_org1_again.custom_secrets[ + 'personal_secret' + ].secret.get_secret_value() + == 'personal_secret_value' + ) + # Verify org2 secrets are NOT visible in org1 + assert 'org2_secret' not in loaded_org1_again.custom_secrets From 63956c329270e376c37755c0cc326558633db1cc Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Thu, 19 Mar 2026 20:27:10 -0400 Subject: [PATCH 26/28] Fix FastAPI Query parameter validation: lte -> le (#13502) Co-authored-by: openhands --- enterprise/server/routes/orgs.py | 4 +- .../sharing/shared_conversation_router.py | 12 +- .../server/sharing/shared_event_router.py | 12 +- .../app_conversation_router.py | 14 +- openhands/app_server/event/event_router.py | 12 +- .../app_server/sandbox/sandbox_router.py | 10 +- .../app_server/sandbox/sandbox_spec_router.py | 12 +- tests/unit/app_server/test_event_router.py | 200 ++++++++++++++++++ 8 files changed, 243 insertions(+), 33 deletions(-) create mode 100644 tests/unit/app_server/test_event_router.py diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py index a39f959864..3a49f23d70 100644 --- a/enterprise/server/routes/orgs.py +++ b/enterprise/server/routes/orgs.py @@ -68,7 +68,7 @@ async def list_user_orgs( ] = None, limit: Annotated[ int, - Query(title='The max number of results in the page', gt=0, lte=100), + Query(title='The max number of results in the page', gt=0, le=100), ] = 100, user_id: str = Depends(get_user_id), ) -> OrgPage: @@ -734,7 +734,7 @@ async def get_org_members( Query( title='The max number of results in the page', gt=0, - lte=100, + le=100, ), ] = 10, email: Annotated[ diff --git a/enterprise/server/sharing/shared_conversation_router.py b/enterprise/server/sharing/shared_conversation_router.py index 26fe047e6d..529dca3914 100644 --- a/enterprise/server/sharing/shared_conversation_router.py +++ b/enterprise/server/sharing/shared_conversation_router.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Annotated from uuid import UUID -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, HTTPException, Query from server.sharing.shared_conversation_info_service import ( SharedConversationInfoService, ) @@ -60,7 +60,7 @@ async def search_shared_conversations( Query( title='The max number of results in the page', gt=0, - lte=100, + le=100, ), ] = 100, include_sub_conversations: Annotated[ @@ -72,8 +72,6 @@ async def search_shared_conversations( shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency, ) -> SharedConversationPage: """Search / List shared conversations.""" - assert limit > 0 - assert limit <= 100 return await shared_conversation_service.search_shared_conversation_info( title__contains=title__contains, created_at__gte=created_at__gte, @@ -127,7 +125,11 @@ async def batch_get_shared_conversations( shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency, ) -> list[SharedConversation | None]: """Get a batch of shared conversations given their ids. Return None for any missing or non-shared.""" - assert len(ids) <= 100 + if len(ids) > 100: + raise HTTPException( + status_code=400, + detail=f'Cannot request more than 100 conversations at once, got {len(ids)}', + ) uuids = [UUID(id_) for id_ in ids] shared_conversation_info = ( await shared_conversation_service.batch_get_shared_conversation_info(uuids) diff --git a/enterprise/server/sharing/shared_event_router.py b/enterprise/server/sharing/shared_event_router.py index 1f42d1d32e..2d8b500126 100644 --- a/enterprise/server/sharing/shared_event_router.py +++ b/enterprise/server/sharing/shared_event_router.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Annotated from uuid import UUID -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, HTTPException, Query from server.sharing.shared_event_service import ( SharedEventService, SharedEventServiceInjector, @@ -77,13 +77,11 @@ async def search_shared_events( ] = None, limit: Annotated[ int, - Query(title='The max number of results in the page', gt=0, lte=100), + Query(title='The max number of results in the page', gt=0, le=100), ] = 100, shared_event_service: SharedEventService = shared_event_service_dependency, ) -> EventPage: """Search / List events for a shared conversation.""" - assert limit > 0 - assert limit <= 100 return await shared_event_service.search_shared_events( conversation_id=UUID(conversation_id), kind__eq=kind__eq, @@ -134,7 +132,11 @@ async def batch_get_shared_events( shared_event_service: SharedEventService = shared_event_service_dependency, ) -> list[Event | None]: """Get a batch of events for a shared conversation given their ids, returning null for any missing event.""" - assert len(id) <= 100 + if len(id) > 100: + raise HTTPException( + status_code=400, + detail=f'Cannot request more than 100 events at once, got {len(id)}', + ) event_ids = [UUID(id_) for id_ in id] events = await shared_event_service.batch_get_shared_events( UUID(conversation_id), event_ids diff --git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index 582de93761..02fb97986d 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -234,7 +234,7 @@ async def search_app_conversations( Query( title='The max number of results in the page', gt=0, - lte=100, + le=100, ), ] = 100, include_sub_conversations: Annotated[ @@ -248,8 +248,6 @@ async def search_app_conversations( ), ) -> AppConversationPage: """Search / List sandboxed conversations.""" - assert limit > 0 - assert limit <= 100 return await app_conversation_service.search_app_conversations( title__contains=title__contains, created_at__gte=created_at__gte, @@ -422,7 +420,7 @@ async def search_app_conversation_start_tasks( Query( title='The max number of results in the page', gt=0, - lte=100, + le=100, ), ] = 100, app_conversation_start_task_service: AppConversationStartTaskService = ( @@ -430,8 +428,6 @@ async def search_app_conversation_start_tasks( ), ) -> AppConversationStartTaskPage: """Search / List conversation start tasks.""" - assert limit > 0 - assert limit <= 100 return ( await app_conversation_start_task_service.search_app_conversation_start_tasks( conversation_id__eq=conversation_id__eq, @@ -472,7 +468,11 @@ async def batch_get_app_conversation_start_tasks( ), ) -> list[AppConversationStartTask | None]: """Get a batch of start app conversation tasks given their ids. Return None for any missing.""" - assert len(ids) < 100 + if len(ids) > 100: + raise HTTPException( + status_code=400, + detail=f'Cannot request more than 100 start tasks at once, got {len(ids)}', + ) start_tasks = await app_conversation_start_task_service.batch_get_app_conversation_start_tasks( ids ) diff --git a/openhands/app_server/event/event_router.py b/openhands/app_server/event/event_router.py index 522a53c273..ae525d3e04 100644 --- a/openhands/app_server/event/event_router.py +++ b/openhands/app_server/event/event_router.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Annotated from uuid import UUID -from fastapi import APIRouter, Query +from fastapi import APIRouter, HTTPException, Query from openhands.agent_server.models import EventPage, EventSortOrder from openhands.app_server.config import depends_event_service @@ -51,13 +51,11 @@ async def search_events( ] = None, limit: Annotated[ int, - Query(title='The max number of results in the page', gt=0, lte=100), + Query(title='The max number of results in the page', gt=0, le=100), ] = 100, event_service: EventService = event_service_dependency, ) -> EventPage: """Search / List events.""" - assert limit > 0 - assert limit <= 100 return await event_service.search_events( conversation_id=UUID(conversation_id), kind__eq=kind__eq, @@ -102,7 +100,11 @@ async def batch_get_events( event_service: EventService = event_service_dependency, ) -> list[Event | None]: """Get a batch of events given their ids, returning null for any missing event.""" + if len(id) > 100: + raise HTTPException( + status_code=400, + detail=f'Cannot request more than 100 events at once, got {len(id)}', + ) event_ids = [UUID(id_) for id_ in id] - assert len(id) <= 100 events = await event_service.batch_get_events(UUID(conversation_id), event_ids) return events diff --git a/openhands/app_server/sandbox/sandbox_router.py b/openhands/app_server/sandbox/sandbox_router.py index 7b2575c3e7..54dedecfbb 100644 --- a/openhands/app_server/sandbox/sandbox_router.py +++ b/openhands/app_server/sandbox/sandbox_router.py @@ -44,13 +44,11 @@ async def search_sandboxes( ] = None, limit: Annotated[ int, - Query(title='The max number of results in the page', gt=0, lte=100), + Query(title='The max number of results in the page', gt=0, le=100), ] = 100, sandbox_service: SandboxService = sandbox_service_dependency, ) -> SandboxPage: """Search / list sandboxes owned by the current user.""" - assert limit > 0 - assert limit <= 100 return await sandbox_service.search_sandboxes(page_id=page_id, limit=limit) @@ -60,7 +58,11 @@ async def batch_get_sandboxes( sandbox_service: SandboxService = sandbox_service_dependency, ) -> list[SandboxInfo | None]: """Get a batch of sandboxes given their ids, returning null for any missing.""" - assert len(id) < 100 + if len(id) > 100: + raise HTTPException( + status_code=400, + detail=f'Cannot request more than 100 sandboxes at once, got {len(id)}', + ) sandboxes = await sandbox_service.batch_get_sandboxes(id) return sandboxes diff --git a/openhands/app_server/sandbox/sandbox_spec_router.py b/openhands/app_server/sandbox/sandbox_spec_router.py index 6da3353f39..1a1f98bb96 100644 --- a/openhands/app_server/sandbox/sandbox_spec_router.py +++ b/openhands/app_server/sandbox/sandbox_spec_router.py @@ -2,7 +2,7 @@ from typing import Annotated -from fastapi import APIRouter, Query +from fastapi import APIRouter, HTTPException, Query from openhands.app_server.config import depends_sandbox_spec_service from openhands.app_server.sandbox.sandbox_spec_models import ( @@ -35,13 +35,11 @@ async def search_sandbox_specs( ] = None, limit: Annotated[ int, - Query(title='The max number of results in the page', gt=0, lte=100), + Query(title='The max number of results in the page', gt=0, le=100), ] = 100, sandbox_spec_service: SandboxSpecService = sandbox_spec_service_dependency, ) -> SandboxSpecInfoPage: """Search / List sandbox specs.""" - assert limit > 0 - assert limit <= 100 return await sandbox_spec_service.search_sandbox_specs(page_id=page_id, limit=limit) @@ -51,6 +49,10 @@ async def batch_get_sandbox_specs( sandbox_spec_service: SandboxSpecService = sandbox_spec_service_dependency, ) -> list[SandboxSpecInfo | None]: """Get a batch of sandbox specs given their ids, returning null for any missing.""" - assert len(id) <= 100 + if len(id) > 100: + raise HTTPException( + status_code=400, + detail=f'Cannot request more than 100 sandbox specs at once, got {len(id)}', + ) sandbox_specs = await sandbox_spec_service.batch_get_sandbox_specs(id) return sandbox_specs diff --git a/tests/unit/app_server/test_event_router.py b/tests/unit/app_server/test_event_router.py new file mode 100644 index 0000000000..a3b94dd0ff --- /dev/null +++ b/tests/unit/app_server/test_event_router.py @@ -0,0 +1,200 @@ +"""Unit tests for the event_router endpoints. + +This module tests the event router endpoints, +focusing on limit validation and error handling. +""" + +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest +from fastapi import FastAPI, HTTPException, status +from fastapi.testclient import TestClient + +from openhands.app_server.event.event_router import batch_get_events, router +from openhands.server.dependencies import check_session_api_key + + +def _make_mock_event_service(search_return=None, batch_get_return=None): + """Create a mock EventService for testing.""" + service = MagicMock() + service.search_events = AsyncMock(return_value=search_return) + service.batch_get_events = AsyncMock(return_value=batch_get_return or []) + return service + + +@pytest.fixture +def test_client(): + """Create a test client with the actual event router and mocked dependencies. + + We override check_session_api_key to bypass auth checks. + This allows us to test the actual Query parameter validation in the router. + """ + app = FastAPI() + app.include_router(router) + + # Override the auth dependency to always pass + app.dependency_overrides[check_session_api_key] = lambda: None + + client = TestClient(app, raise_server_exceptions=False) + yield client + + # Clean up + app.dependency_overrides.clear() + + +class TestSearchEventsValidation: + """Test suite for search_events endpoint limit validation via FastAPI.""" + + def test_returns_422_for_limit_exceeding_100(self, test_client): + """Test that limit > 100 returns 422 Unprocessable Entity. + + FastAPI's Query validation (le=100) should reject limit=200. + """ + conversation_id = str(uuid4()) + + response = test_client.get( + f'/conversation/{conversation_id}/events/search', + params={'limit': 200}, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + # Verify the error message mentions the constraint + error_detail = response.json()['detail'] + assert any( + 'less than or equal to 100' in str(err).lower() or 'le' in str(err).lower() + for err in error_detail + ) + + def test_returns_422_for_limit_zero(self, test_client): + """Test that limit=0 returns 422 Unprocessable Entity. + + FastAPI's Query validation (gt=0) should reject limit=0. + """ + conversation_id = str(uuid4()) + + response = test_client.get( + f'/conversation/{conversation_id}/events/search', + params={'limit': 0}, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_returns_422_for_negative_limit(self, test_client): + """Test that negative limit returns 422 Unprocessable Entity. + + FastAPI's Query validation (gt=0) should reject limit=-1. + """ + conversation_id = str(uuid4()) + + response = test_client.get( + f'/conversation/{conversation_id}/events/search', + params={'limit': -1}, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_accepts_valid_limit_100(self, test_client): + """Test that limit=100 is accepted (boundary case). + + Verify that limit=100 passes FastAPI validation and doesn't return 422. + """ + conversation_id = str(uuid4()) + + response = test_client.get( + f'/conversation/{conversation_id}/events/search', + params={'limit': 100}, + ) + + # Should pass validation (not 422) - may fail on other errors like missing service + assert response.status_code != status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_accepts_valid_limit_1(self, test_client): + """Test that limit=1 is accepted (boundary case). + + Verify that limit=1 passes FastAPI validation and doesn't return 422. + """ + conversation_id = str(uuid4()) + + response = test_client.get( + f'/conversation/{conversation_id}/events/search', + params={'limit': 1}, + ) + + # Should pass validation (not 422) - may fail on other errors like missing service + assert response.status_code != status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +class TestBatchGetEvents: + """Test suite for batch_get_events endpoint.""" + + async def test_returns_400_for_more_than_100_ids(self): + """Test that requesting more than 100 IDs returns 400 Bad Request. + + Arrange: Create list with 101 IDs + Act: Call batch_get_events + Assert: HTTPException is raised with 400 status + """ + # Arrange + conversation_id = str(uuid4()) + ids = [str(uuid4()) for _ in range(101)] + mock_service = _make_mock_event_service() + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await batch_get_events( + conversation_id=conversation_id, + id=ids, + event_service=mock_service, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert 'Cannot request more than 100 events' in exc_info.value.detail + assert '101' in exc_info.value.detail + + async def test_accepts_exactly_100_ids(self): + """Test that exactly 100 IDs is accepted. + + Arrange: Create list with 100 IDs + Act: Call batch_get_events + Assert: No exception is raised and service is called + """ + # Arrange + conversation_id = str(uuid4()) + ids = [str(uuid4()) for _ in range(100)] + mock_return = [None] * 100 + mock_service = _make_mock_event_service(batch_get_return=mock_return) + + # Act + result = await batch_get_events( + conversation_id=conversation_id, + id=ids, + event_service=mock_service, + ) + + # Assert + assert result == mock_return + mock_service.batch_get_events.assert_called_once() + + async def test_accepts_empty_list(self): + """Test that empty list of IDs is accepted. + + Arrange: Create empty list of IDs + Act: Call batch_get_events + Assert: No exception is raised + """ + # Arrange + conversation_id = str(uuid4()) + mock_service = _make_mock_event_service(batch_get_return=[]) + + # Act + result = await batch_get_events( + conversation_id=conversation_id, + id=[], + event_service=mock_service, + ) + + # Assert + assert result == [] + mock_service.batch_get_events.assert_called_once() From a75b576f1cf087650227c8d4df805f9f45e6758e Mon Sep 17 00:00:00 2001 From: Abi Date: Fri, 20 Mar 2026 15:44:15 +0530 Subject: [PATCH 27/28] fix: treat llm_base_url="" as explicit clear in store_llm_settings (#13471) Co-authored-by: Claude Sonnet 4.6 --- openhands/server/routes/settings.py | 10 ++++--- .../routes/test_settings_store_functions.py | 29 +++++++++++++++++-- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 62944d11ce..4affad1d61 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -123,10 +123,9 @@ async def store_llm_settings( settings.llm_api_key = existing_settings.llm_api_key if settings.llm_model is None: settings.llm_model = existing_settings.llm_model - # if llm_base_url is missing or empty, try to preserve existing or determine appropriate URL - if not settings.llm_base_url: - if settings.llm_base_url is None and existing_settings.llm_base_url: - # Not provided at all (e.g. MCP config save) - preserve existing + if settings.llm_base_url is None: + # Not provided at all (e.g. MCP config save) - preserve existing or auto-detect + if existing_settings.llm_base_url: settings.llm_base_url = existing_settings.llm_base_url elif is_openhands_model(settings.llm_model): # OpenHands models use the LiteLLM proxy @@ -145,6 +144,9 @@ async def store_llm_settings( logger.error( f'Failed to get api_base from litellm for model {settings.llm_model}: {e}' ) + elif settings.llm_base_url == '': + # Explicitly cleared by the user (basic view save or advanced view clear) + settings.llm_base_url = None # Keep search API key if missing or empty if not settings.search_api_key: settings.search_api_key = existing_settings.search_api_key diff --git a/tests/unit/server/routes/test_settings_store_functions.py b/tests/unit/server/routes/test_settings_store_functions.py index f51a5b506a..48bc79a280 100644 --- a/tests/unit/server/routes/test_settings_store_functions.py +++ b/tests/unit/server/routes/test_settings_store_functions.py @@ -211,9 +211,32 @@ async def test_store_llm_settings_partial_update(): assert result.llm_model == 'gpt-4' # For SecretStr objects, we need to compare the secret value assert result.llm_api_key.get_secret_value() == 'existing-api-key' - # llm_base_url was explicitly cleared (""), so auto-detection runs - # OpenAI models: litellm.get_api_base() returns https://api.openai.com - assert result.llm_base_url == 'https://api.openai.com' + # llm_base_url="" is an explicit clear — must not be repopulated via auto-detection + assert result.llm_base_url is None + + +@pytest.mark.asyncio +async def test_store_llm_settings_advanced_view_clear_removes_base_url(): + """Regression test for #13420: clearing Base URL in Advanced view must persist. + + Before the fix, llm_base_url="" was treated identically to llm_base_url=None, + causing the backend to re-run auto-detection and overwrite the user's intent. + """ + settings = Settings( + llm_model='gpt-4', + llm_base_url='', # User deleted the field in Advanced view + ) + + existing_settings = Settings( + llm_model='gpt-4', + llm_api_key=SecretStr('my-api-key'), + llm_base_url='https://my-custom-proxy.example.com', + ) + + result = await store_llm_settings(settings, existing_settings) + + # The old custom URL must not come back + assert result.llm_base_url is None @pytest.mark.asyncio From fb776ef6509bd9015403f42ebb4c2fe9aca30009 Mon Sep 17 00:00:00 2001 From: Vasco Schiavo <115561717+VascoSch92@users.noreply.github.com> Date: Fri, 20 Mar 2026 12:20:25 +0100 Subject: [PATCH 28/28] feat(frontend): Add copy button to code blocks (#13458) Co-authored-by: openhands --- .../buttons/copyable-content-wrapper.test.tsx | 60 +++++++++++++++++++ .../features/markdown/code.test.tsx | 37 ++++++++++++ .../src/components/features/markdown/code.tsx | 46 +++++++------- .../buttons/copyable-content-wrapper.tsx | 44 ++++++++++++++ 4 files changed, 167 insertions(+), 20 deletions(-) create mode 100644 frontend/__tests__/components/buttons/copyable-content-wrapper.test.tsx create mode 100644 frontend/__tests__/components/features/markdown/code.test.tsx create mode 100644 frontend/src/components/shared/buttons/copyable-content-wrapper.tsx diff --git a/frontend/__tests__/components/buttons/copyable-content-wrapper.test.tsx b/frontend/__tests__/components/buttons/copyable-content-wrapper.test.tsx new file mode 100644 index 0000000000..7c7aaee48b --- /dev/null +++ b/frontend/__tests__/components/buttons/copyable-content-wrapper.test.tsx @@ -0,0 +1,60 @@ +import { render, screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, it, expect } from "vitest"; +import { CopyableContentWrapper } from "#/components/shared/buttons/copyable-content-wrapper"; + +describe("CopyableContentWrapper", () => { + it("should hide the copy button by default", () => { + render( + +

content

+
, + ); + + expect(screen.getByTestId("copy-to-clipboard")).not.toBeVisible(); + }); + + it("should show the copy button on hover", async () => { + const user = userEvent.setup(); + render( + +

content

+
, + ); + + await user.hover(screen.getByText("content")); + + expect(screen.getByTestId("copy-to-clipboard")).toBeVisible(); + }); + + it("should copy text to clipboard on click", async () => { + const user = userEvent.setup(); + render( + +

content

+
, + ); + + await user.click(screen.getByTestId("copy-to-clipboard")); + + await waitFor(() => + expect(navigator.clipboard.readText()).resolves.toBe("copy me"), + ); + }); + + it("should show copied state after clicking", async () => { + const user = userEvent.setup(); + render( + +

content

+
, + ); + + await user.click(screen.getByTestId("copy-to-clipboard")); + + expect(screen.getByTestId("copy-to-clipboard")).toHaveAttribute( + "aria-label", + "BUTTON$COPIED", + ); + }); +}); diff --git a/frontend/__tests__/components/features/markdown/code.test.tsx b/frontend/__tests__/components/features/markdown/code.test.tsx new file mode 100644 index 0000000000..c5ba1562af --- /dev/null +++ b/frontend/__tests__/components/features/markdown/code.test.tsx @@ -0,0 +1,37 @@ +import { render, screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, it, expect } from "vitest"; +import { code as Code } from "#/components/features/markdown/code"; + +describe("code (markdown)", () => { + it("should render inline code without a copy button", () => { + render(inline snippet); + + expect(screen.getByText("inline snippet")).toBeInTheDocument(); + expect(screen.queryByTestId("copy-to-clipboard")).not.toBeInTheDocument(); + }); + + it("should render a multiline code block with a copy button", () => { + render({"line1\nline2"}); + + expect(screen.getByText("line1 line2")).toBeInTheDocument(); + expect(screen.getByTestId("copy-to-clipboard")).toBeInTheDocument(); + }); + + it("should render a syntax-highlighted block with a copy button", () => { + render({"console.log('hi')"}); + + expect(screen.getByTestId("copy-to-clipboard")).toBeInTheDocument(); + }); + + it("should copy code block content to clipboard", async () => { + const user = userEvent.setup(); + render({"line1\nline2"}); + + await user.click(screen.getByTestId("copy-to-clipboard")); + + await waitFor(() => + expect(navigator.clipboard.readText()).resolves.toBe("line1\nline2"), + ); + }); +}); diff --git a/frontend/src/components/features/markdown/code.tsx b/frontend/src/components/features/markdown/code.tsx index 2a801f6848..ee04ce53b5 100644 --- a/frontend/src/components/features/markdown/code.tsx +++ b/frontend/src/components/features/markdown/code.tsx @@ -2,6 +2,7 @@ import React from "react"; import { ExtraProps } from "react-markdown"; import { Prism as SyntaxHighlighter } from "react-syntax-highlighter"; import { vscDarkPlus } from "react-syntax-highlighter/dist/esm/styles/prism"; +import { CopyableContentWrapper } from "#/components/shared/buttons/copyable-content-wrapper"; // See https://github.com/remarkjs/react-markdown?tab=readme-ov-file#use-custom-components-syntax-highlight @@ -15,6 +16,7 @@ export function code({ React.HTMLAttributes & ExtraProps) { const match = /language-(\w+)/.exec(className || ""); // get the language + const codeString = String(children).replace(/\n$/, ""); if (!match) { const isMultiline = String(children).includes("\n"); @@ -37,29 +39,33 @@ export function code({ } return ( -
-        {String(children).replace(/\n$/, "")}
-      
+ +
+          {codeString}
+        
+
); } return ( - - {String(children).replace(/\n$/, "")} - + + + {codeString} + + ); } diff --git a/frontend/src/components/shared/buttons/copyable-content-wrapper.tsx b/frontend/src/components/shared/buttons/copyable-content-wrapper.tsx new file mode 100644 index 0000000000..fe9a4d837a --- /dev/null +++ b/frontend/src/components/shared/buttons/copyable-content-wrapper.tsx @@ -0,0 +1,44 @@ +import React from "react"; +import { CopyToClipboardButton } from "./copy-to-clipboard-button"; + +export function CopyableContentWrapper({ + text, + children, +}: { + text: string; + children: React.ReactNode; +}) { + const [isHovering, setIsHovering] = React.useState(false); + const [isCopied, setIsCopied] = React.useState(false); + + const handleCopy = async () => { + await navigator.clipboard.writeText(text); + setIsCopied(true); + }; + + React.useEffect(() => { + let timeout: NodeJS.Timeout; + if (isCopied) { + timeout = setTimeout(() => setIsCopied(false), 2000); + } + return () => clearTimeout(timeout); + }, [isCopied]); + + return ( +
setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + > +
+ +
+ {children} +
+ ); +}