mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Merge branch 'main' into APP-972/lead-capture-form
This commit is contained in:
8
.github/workflows/ghcr-build.yml
vendored
8
.github/workflows/ghcr-build.yml
vendored
@@ -219,11 +219,9 @@ jobs:
|
||||
- name: Determine app image tag
|
||||
shell: bash
|
||||
run: |
|
||||
# Duplicated with build.sh
|
||||
sanitized_ref_name=$(echo "$GITHUB_REF_NAME" | sed 's/[^a-zA-Z0-9.-]\+/-/g')
|
||||
OPENHANDS_BUILD_VERSION=$sanitized_ref_name
|
||||
sanitized_ref_name=$(echo "$sanitized_ref_name" | tr '[:upper:]' '[:lower:]') # lower case is required in tagging
|
||||
echo "OPENHANDS_DOCKER_TAG=${sanitized_ref_name}" >> $GITHUB_ENV
|
||||
# Use the commit SHA to pin the exact app image built by ghcr_build_app,
|
||||
# rather than a mutable branch tag like "main" which can serve stale cached layers.
|
||||
echo "OPENHANDS_DOCKER_TAG=${RELEVANT_SHA}" >> $GITHUB_ENV
|
||||
- name: Build and push Docker image
|
||||
uses: useblacksmith/build-push-action@v1
|
||||
with:
|
||||
|
||||
12
enterprise/poetry.lock
generated
12
enterprise/poetry.lock
generated
@@ -7597,14 +7597,14 @@ wrappers-encryption = ["cryptography (>=45.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.2"
|
||||
version = "0.6.3"
|
||||
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf"},
|
||||
{file = "pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b"},
|
||||
{file = "pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde"},
|
||||
{file = "pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -11587,14 +11587,14 @@ diagrams = ["jinja2", "railroad-diagrams"]
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.8.0"
|
||||
version = "6.9.1"
|
||||
description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7"},
|
||||
{file = "pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b"},
|
||||
{file = "pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f"},
|
||||
{file = "pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
|
||||
@@ -46,6 +46,7 @@ from server.routes.org_invitations import ( # noqa: E402
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.service import service_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
@@ -112,6 +113,7 @@ if GITLAB_APP_CLIENT_ID:
|
||||
base_app.include_router(gitlab_integration_router)
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(service_router) # Add routes for internal service API
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
|
||||
@@ -35,7 +35,7 @@ Usage:
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
@@ -214,6 +214,19 @@ def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
return permission in permissions
|
||||
|
||||
|
||||
async def get_api_key_org_id_from_request(request: Request) -> UUID | None:
|
||||
"""Get the org_id bound to the API key used for authentication.
|
||||
|
||||
Returns None if:
|
||||
- Not authenticated via API key (cookie auth)
|
||||
- API key is a legacy key without org binding
|
||||
"""
|
||||
user_auth = getattr(request.state, 'user_auth', None)
|
||||
if user_auth and hasattr(user_auth, 'get_api_key_org_id'):
|
||||
return user_auth.get_api_key_org_id()
|
||||
return None
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
@@ -221,8 +234,9 @@ def require_permission(permission: Permission):
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
3. Validates API key org binding (if using API key auth)
|
||||
4. Checks if the user has the required permission in the organization
|
||||
5. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
@@ -240,6 +254,7 @@ def require_permission(permission: Permission):
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
request: Request,
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
@@ -249,6 +264,23 @@ def require_permission(permission: Permission):
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
# Validate API key organization binding
|
||||
api_key_org_id = await get_api_key_org_id_from_request(request)
|
||||
if api_key_org_id is not None and org_id is not None:
|
||||
if api_key_org_id != org_id:
|
||||
logger.warning(
|
||||
'API key organization mismatch',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'api_key_org_id': str(api_key_org_id),
|
||||
'target_org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='API key is not authorized for this organization',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
@@ -59,6 +60,19 @@ class SaasUserAuth(UserAuth):
|
||||
_secrets: Secrets | None = None
|
||||
accepted_tos: bool | None = None
|
||||
auth_type: AuthType = AuthType.COOKIE
|
||||
# API key context fields - populated when authenticated via API key
|
||||
api_key_org_id: UUID | None = None # Org bound to the API key used for auth
|
||||
api_key_id: int | None = None
|
||||
api_key_name: str | None = None
|
||||
|
||||
def get_api_key_org_id(self) -> UUID | None:
|
||||
"""Get the organization ID bound to the API key used for authentication.
|
||||
|
||||
Returns:
|
||||
The org_id if authenticated via API key with org binding, None otherwise
|
||||
(cookie auth or legacy API keys without org binding).
|
||||
"""
|
||||
return self.api_key_org_id
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return self.user_id
|
||||
@@ -283,14 +297,19 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
return None
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
if not user_id:
|
||||
validation_result = await api_key_store.validate_api_key(api_key)
|
||||
if not validation_result:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
offline_token = await token_manager.load_offline_token(
|
||||
validation_result.user_id
|
||||
)
|
||||
saas_user_auth = SaasUserAuth(
|
||||
user_id=user_id,
|
||||
user_id=validation_result.user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=validation_result.org_id,
|
||||
api_key_id=validation_result.key_id,
|
||||
api_key_name=validation_result.key_name,
|
||||
)
|
||||
await saas_user_auth.refresh()
|
||||
return saas_user_auth
|
||||
|
||||
@@ -182,6 +182,10 @@ class SetAuthCookieMiddleware:
|
||||
if path.startswith('/api/v1/webhooks/'):
|
||||
return False
|
||||
|
||||
# Service API uses its own authentication (X-Service-API-Key header)
|
||||
if path.startswith('/api/service/'):
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
@@ -11,7 +13,8 @@ from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
@@ -150,6 +153,16 @@ class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class CurrentApiKeyResponse(BaseModel):
|
||||
"""Response model for the current API key endpoint."""
|
||||
|
||||
id: int
|
||||
name: str | None
|
||||
org_id: str
|
||||
user_id: str
|
||||
auth_type: str
|
||||
|
||||
|
||||
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
|
||||
"""Convert an ApiKey model to an ApiKeyResponse."""
|
||||
return ApiKeyResponse(
|
||||
@@ -262,6 +275,46 @@ async def delete_api_key(
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/current', tags=['Keys'])
|
||||
async def get_current_api_key(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CurrentApiKeyResponse:
|
||||
"""Get information about the currently authenticated API key.
|
||||
|
||||
This endpoint returns metadata about the API key used for the current request,
|
||||
including the org_id associated with the key. This is useful for API key
|
||||
callers who need to know which organization context their key operates in.
|
||||
|
||||
Returns 400 if not authenticated via API key (e.g., using cookie auth).
|
||||
"""
|
||||
user_auth = await get_user_auth(request)
|
||||
|
||||
# Check if authenticated via API key
|
||||
if user_auth.get_auth_type() != AuthType.BEARER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This endpoint requires API key authentication. Not available for cookie-based auth.',
|
||||
)
|
||||
|
||||
# In SaaS context, bearer auth always produces SaasUserAuth
|
||||
saas_user_auth = cast(SaasUserAuth, user_auth)
|
||||
|
||||
if saas_user_auth.api_key_org_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This API key was created before organization support. Please regenerate your API key to use this endpoint.',
|
||||
)
|
||||
|
||||
return CurrentApiKeyResponse(
|
||||
id=saas_user_auth.api_key_id,
|
||||
name=saas_user_auth.api_key_name,
|
||||
org_id=str(saas_user_auth.api_key_org_id),
|
||||
user_id=user_id,
|
||||
auth_type=saas_user_auth.auth_type.value,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor', tags=['Keys'])
|
||||
async def get_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
|
||||
@@ -68,7 +68,7 @@ async def list_user_orgs(
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
Query(title='The max number of results in the page', gt=0, le=100),
|
||||
] = 100,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgPage:
|
||||
@@ -734,7 +734,7 @@ async def get_org_members(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
le=100,
|
||||
),
|
||||
] = 10,
|
||||
email: Annotated[
|
||||
|
||||
270
enterprise/server/routes/service.py
Normal file
270
enterprise/server/routes/service.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Service API routes for internal service-to-service communication.
|
||||
|
||||
This module provides endpoints for trusted internal services (e.g., automations service)
|
||||
to perform privileged operations like creating API keys on behalf of users.
|
||||
|
||||
Authentication is via a shared secret (X-Service-API-Key header) configured
|
||||
through the AUTOMATIONS_SERVICE_API_KEY environment variable.
|
||||
"""
|
||||
|
||||
import os
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Environment variable for the service API key
|
||||
AUTOMATIONS_SERVICE_API_KEY = os.getenv('AUTOMATIONS_SERVICE_API_KEY', '').strip()
|
||||
|
||||
service_router = APIRouter(prefix='/api/service', tags=['Service'])
|
||||
|
||||
|
||||
class CreateUserApiKeyRequest(BaseModel):
|
||||
"""Request model for creating an API key on behalf of a user."""
|
||||
|
||||
name: str # Required - used to identify the key
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError('name is required and cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
|
||||
class CreateUserApiKeyResponse(BaseModel):
|
||||
"""Response model for created API key."""
|
||||
|
||||
key: str
|
||||
user_id: str
|
||||
org_id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ServiceInfoResponse(BaseModel):
|
||||
"""Response model for service info endpoint."""
|
||||
|
||||
service: str
|
||||
authenticated: bool
|
||||
|
||||
|
||||
async def validate_service_api_key(
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> str:
|
||||
"""
|
||||
Validate the service API key from the request header.
|
||||
|
||||
Args:
|
||||
x_service_api_key: The service API key from the X-Service-API-Key header
|
||||
|
||||
Returns:
|
||||
str: Service identifier for audit logging
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if key is missing or invalid
|
||||
HTTPException: 503 if service auth is not configured
|
||||
"""
|
||||
if not AUTOMATIONS_SERVICE_API_KEY:
|
||||
logger.warning(
|
||||
'Service authentication not configured (AUTOMATIONS_SERVICE_API_KEY not set)'
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail='Service authentication not configured',
|
||||
)
|
||||
|
||||
if not x_service_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='X-Service-API-Key header is required',
|
||||
)
|
||||
|
||||
if x_service_api_key != AUTOMATIONS_SERVICE_API_KEY:
|
||||
logger.warning('Invalid service API key attempted')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Invalid service API key',
|
||||
)
|
||||
|
||||
return 'automations-service'
|
||||
|
||||
|
||||
@service_router.get('/health')
|
||||
async def service_health() -> dict:
|
||||
"""Health check endpoint for the service API.
|
||||
|
||||
This endpoint does not require authentication and can be used
|
||||
to verify the service routes are accessible.
|
||||
"""
|
||||
return {
|
||||
'status': 'ok',
|
||||
'service_auth_configured': bool(AUTOMATIONS_SERVICE_API_KEY),
|
||||
}
|
||||
|
||||
|
||||
@service_router.post('/users/{user_id}/orgs/{org_id}/api-keys')
|
||||
async def get_or_create_api_key_for_user(
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
request: CreateUserApiKeyRequest,
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> CreateUserApiKeyResponse:
|
||||
"""
|
||||
Get or create an API key for a user on behalf of the automations service.
|
||||
|
||||
If a key with the given name already exists for the user/org and is not expired,
|
||||
returns the existing key. Otherwise, creates a new key.
|
||||
|
||||
The created/returned keys are system keys and are:
|
||||
- Not visible to the user in their API keys list
|
||||
- Not deletable by the user
|
||||
- Never expire
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
org_id: The organization ID
|
||||
request: Request body containing name (required)
|
||||
x_service_api_key: Service API key header for authentication
|
||||
|
||||
Returns:
|
||||
CreateUserApiKeyResponse: The API key and metadata
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if service key is invalid
|
||||
HTTPException: 404 if user not found
|
||||
HTTPException: 403 if user is not a member of the specified org
|
||||
"""
|
||||
# Validate service API key
|
||||
service_id = await validate_service_api_key(x_service_api_key)
|
||||
|
||||
# Verify user exists
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
logger.warning(
|
||||
'Service attempted to create key for non-existent user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'User {user_id} not found',
|
||||
)
|
||||
|
||||
# Verify user is a member of the specified org
|
||||
org_member = await OrgMemberStore.get_org_member(org_id, UUID(user_id))
|
||||
if not org_member:
|
||||
logger.warning(
|
||||
'Service attempted to create key for user not in org',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'User {user_id} is not a member of org {org_id}',
|
||||
)
|
||||
|
||||
# Get or create the system API key
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
|
||||
try:
|
||||
api_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=request.name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Failed to get or create system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to get or create API key',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Service created API key for user',
|
||||
extra={
|
||||
'service_id': service_id,
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': request.name,
|
||||
},
|
||||
)
|
||||
|
||||
return CreateUserApiKeyResponse(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=str(org_id),
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
|
||||
@service_router.delete('/users/{user_id}/orgs/{org_id}/api-keys/{key_name}')
|
||||
async def delete_user_api_key(
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
key_name: str,
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a system API key created by the service.
|
||||
|
||||
This endpoint allows the automations service to clean up API keys
|
||||
it previously created for users.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
org_id: The organization ID
|
||||
key_name: The name of the key to delete (without __SYSTEM__: prefix)
|
||||
x_service_api_key: Service API key header for authentication
|
||||
|
||||
Returns:
|
||||
dict: Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if service key is invalid
|
||||
HTTPException: 404 if key not found
|
||||
"""
|
||||
# Validate service API key
|
||||
service_id = await validate_service_api_key(x_service_api_key)
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
|
||||
# Delete the key by name (wrap with system key prefix since service creates system keys)
|
||||
system_key_name = api_key_store.make_system_key_name(key_name)
|
||||
success = await api_key_store.delete_api_key_by_name(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
allow_system=True,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'API key with name "{key_name}" not found for user {user_id} in org {org_id}',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Service deleted API key for user',
|
||||
extra={
|
||||
'service_id': service_id,
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': key_name,
|
||||
},
|
||||
)
|
||||
|
||||
return {'message': 'API key deleted successfully'}
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
@@ -60,7 +60,7 @@ async def search_shared_conversations(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
le=100,
|
||||
),
|
||||
] = 100,
|
||||
include_sub_conversations: Annotated[
|
||||
@@ -72,8 +72,6 @@ async def search_shared_conversations(
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> SharedConversationPage:
|
||||
"""Search / List shared conversations."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_conversation_service.search_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
@@ -127,7 +125,11 @@ async def batch_get_shared_conversations(
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
|
||||
assert len(ids) <= 100
|
||||
if len(ids) > 100:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Cannot request more than 100 conversations at once, got {len(ids)}',
|
||||
)
|
||||
uuids = [UUID(id_) for id_ in ids]
|
||||
shared_conversation_info = (
|
||||
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
@@ -77,13 +77,11 @@ async def search_shared_events(
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
Query(title='The max number of results in the page', gt=0, le=100),
|
||||
] = 100,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> EventPage:
|
||||
"""Search / List events for a shared conversation."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_event_service.search_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
@@ -134,7 +132,11 @@ async def batch_get_shared_events(
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> list[Event | None]:
|
||||
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
|
||||
assert len(id) <= 100
|
||||
if len(id) > 100:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Cannot request more than 100 events at once, got {len(id)}',
|
||||
)
|
||||
event_ids = [UUID(id_) for id_ in id]
|
||||
events = await shared_event_service.batch_get_shared_events(
|
||||
UUID(conversation_id), event_ids
|
||||
|
||||
@@ -4,6 +4,7 @@ import secrets
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from storage.api_key import ApiKey
|
||||
@@ -13,9 +14,22 @@ from storage.user_store import UserStore
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyValidationResult:
|
||||
"""Result of API key validation containing user and organization info."""
|
||||
|
||||
user_id: str
|
||||
org_id: UUID | None # None for legacy API keys without org binding
|
||||
key_id: int
|
||||
key_name: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyStore:
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
# Prefix for system keys created by internal services (e.g., automations)
|
||||
# Keys with this prefix are hidden from users and cannot be deleted by users
|
||||
SYSTEM_KEY_NAME_PREFIX = '__SYSTEM__:'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
"""Generate a random API key with the sk-oh- prefix."""
|
||||
@@ -23,6 +37,19 @@ class ApiKeyStore:
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{self.API_KEY_PREFIX}{random_part}'
|
||||
|
||||
@classmethod
|
||||
def is_system_key_name(cls, name: str | None) -> bool:
|
||||
"""Check if a key name indicates a system key."""
|
||||
return name is not None and name.startswith(cls.SYSTEM_KEY_NAME_PREFIX)
|
||||
|
||||
@classmethod
|
||||
def make_system_key_name(cls, name: str) -> str:
|
||||
"""Create a system key name with the appropriate prefix.
|
||||
|
||||
Format: __SYSTEM__:<name>
|
||||
"""
|
||||
return f'{cls.SYSTEM_KEY_NAME_PREFIX}{name}'
|
||||
|
||||
async def create_api_key(
|
||||
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
||||
) -> str:
|
||||
@@ -60,8 +87,120 @@ class ApiKeyStore:
|
||||
|
||||
return api_key
|
||||
|
||||
async def validate_api_key(self, api_key: str) -> str | None:
|
||||
"""Validate an API key and return the associated user_id if valid."""
|
||||
async def get_or_create_system_api_key(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
) -> str:
|
||||
"""Get or create a system API key for a user on behalf of an internal service.
|
||||
|
||||
If a key with the given name already exists for this user/org and is not expired,
|
||||
returns the existing key. Otherwise, creates a new key (and deletes any expired one).
|
||||
|
||||
System keys are:
|
||||
- Not visible to users in their API keys list (filtered by name prefix)
|
||||
- Not deletable by users (protected by name prefix check)
|
||||
- Associated with a specific org (not the user's current org)
|
||||
- Never expire (no expiration date)
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to create the key for
|
||||
org_id: The organization ID to associate the key with
|
||||
name: Required name for the key (will be prefixed with __SYSTEM__:)
|
||||
|
||||
Returns:
|
||||
The API key (existing or newly created)
|
||||
"""
|
||||
# Create system key name with prefix
|
||||
system_key_name = self.make_system_key_name(name)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
# Check if key already exists for this user/org/name
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id,
|
||||
ApiKey.org_id == org_id,
|
||||
ApiKey.name == system_key_name,
|
||||
)
|
||||
)
|
||||
existing_key = result.scalars().first()
|
||||
|
||||
if existing_key:
|
||||
# Check if expired
|
||||
if existing_key.expires_at:
|
||||
now = datetime.now(UTC)
|
||||
expires_at = existing_key.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
# Key is expired, delete it and create new one
|
||||
logger.info(
|
||||
'System API key expired, re-issuing',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
await session.delete(existing_key)
|
||||
await session.commit()
|
||||
else:
|
||||
# Key exists and is not expired, return it
|
||||
logger.debug(
|
||||
'Returning existing system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
return existing_key.key
|
||||
else:
|
||||
# Key exists and has no expiration, return it
|
||||
logger.debug(
|
||||
'Returning existing system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
return existing_key.key
|
||||
|
||||
# Create new key (no expiration)
|
||||
api_key = self.generate_api_key()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
expires_at=None, # System keys never expire
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'Created system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None:
|
||||
"""Validate an API key and return the associated user_id and org_id if valid.
|
||||
|
||||
Returns:
|
||||
ApiKeyValidationResult if the key is valid, None otherwise.
|
||||
The org_id may be None for legacy API keys that weren't bound to an organization.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
@@ -89,7 +228,12 @@ class ApiKeyStore:
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return key_record.user_id
|
||||
return ApiKeyValidationResult(
|
||||
user_id=key_record.user_id,
|
||||
org_id=key_record.org_id,
|
||||
key_id=key_record.id,
|
||||
key_name=key_record.name,
|
||||
)
|
||||
|
||||
async def delete_api_key(self, api_key: str) -> bool:
|
||||
"""Delete an API key by the key value."""
|
||||
@@ -105,8 +249,18 @@ class ApiKeyStore:
|
||||
|
||||
return True
|
||||
|
||||
async def delete_api_key_by_id(self, key_id: int) -> bool:
|
||||
"""Delete an API key by its ID."""
|
||||
async def delete_api_key_by_id(
|
||||
self, key_id: int, allow_system: bool = False
|
||||
) -> bool:
|
||||
"""Delete an API key by its ID.
|
||||
|
||||
Args:
|
||||
key_id: The ID of the key to delete
|
||||
allow_system: If False (default), system keys cannot be deleted
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found or is a protected system key
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
@@ -114,13 +268,26 @@ class ApiKeyStore:
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
# Protect system keys from deletion unless explicitly allowed
|
||||
if self.is_system_key_name(key_record.name) and not allow_system:
|
||||
logger.warning(
|
||||
'Attempted to delete system API key',
|
||||
extra={'key_id': key_id, 'user_id': key_record.user_id},
|
||||
)
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
|
||||
"""List all API keys for a user."""
|
||||
"""List all user-visible API keys for a user.
|
||||
|
||||
This excludes:
|
||||
- System keys (name starts with __SYSTEM__:) - created by internal services
|
||||
- MCP_API_KEY - internal MCP key
|
||||
"""
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f'User not found: {user_id}')
|
||||
@@ -129,11 +296,17 @@ class ApiKeyStore:
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id, ApiKey.org_id == org_id
|
||||
ApiKey.user_id == user_id,
|
||||
ApiKey.org_id == org_id,
|
||||
)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
# Filter out system keys and MCP_API_KEY
|
||||
return [
|
||||
key
|
||||
for key in keys
|
||||
if key.name != 'MCP_API_KEY' and not self.is_system_key_name(key.name)
|
||||
]
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
@@ -163,17 +336,44 @@ class ApiKeyStore:
|
||||
key_record = result.scalars().first()
|
||||
return key_record.key if key_record else None
|
||||
|
||||
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
async def delete_api_key_by_name(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
org_id: UUID | None = None,
|
||||
allow_system: bool = False,
|
||||
) -> bool:
|
||||
"""Delete an API key by name for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose key to delete
|
||||
name: The name of the key to delete
|
||||
org_id: Optional organization ID to filter by (required for system keys)
|
||||
allow_system: If False (default), system keys cannot be deleted
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found or is a protected system key
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
)
|
||||
# Build the query filters
|
||||
filters = [ApiKey.user_id == user_id, ApiKey.name == name]
|
||||
if org_id is not None:
|
||||
filters.append(ApiKey.org_id == org_id)
|
||||
|
||||
result = await session.execute(select(ApiKey).filter(*filters))
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
# Protect system keys from deletion unless explicitly allowed
|
||||
if self.is_system_key_name(key_record.name) and not allow_system:
|
||||
logger.warning(
|
||||
'Attempted to delete system API key',
|
||||
extra={'user_id': user_id, 'key_name': name},
|
||||
)
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -164,9 +164,33 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
if create_user:
|
||||
await LiteLlmManager._create_user(
|
||||
user_created = await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
if not user_created:
|
||||
logger.error(
|
||||
'create_entries_failed_user_creation',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Verify user exists before proceeding with key generation
|
||||
user_exists = await LiteLlmManager._user_exists(
|
||||
client, keycloak_user_id
|
||||
)
|
||||
if not user_exists:
|
||||
logger.error(
|
||||
'create_entries_user_not_found_before_key_generation',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'create_user_flag': create_user,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
@@ -565,20 +589,26 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget, # None disables budget enforcement
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
@@ -655,15 +685,48 @@ class LiteLlmManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _user_exists(
|
||||
client: httpx.AsyncClient,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""Check if a user exists in LiteLLM.
|
||||
|
||||
Returns True if the user exists, False otherwise.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return False
|
||||
try:
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
if response.is_success:
|
||||
user_data = response.json()
|
||||
# Check that user_info exists and has the user_id
|
||||
user_info = user_data.get('user_info', {})
|
||||
return user_info.get('user_id') == user_id
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'litellm_user_exists_check_failed',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
) -> bool:
|
||||
"""Create a user in LiteLLM.
|
||||
|
||||
Returns True if the user was created or already exists and is verified,
|
||||
False if creation failed and user does not exist.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
return False
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
@@ -716,17 +779,33 @@ class LiteLlmManager:
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return
|
||||
# Verify the user actually exists before returning success
|
||||
user_exists = await LiteLlmManager._user_exists(
|
||||
client, keycloak_user_id
|
||||
)
|
||||
if not user_exists:
|
||||
logger.error(
|
||||
'litellm_user_claimed_exists_but_not_found',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
return False
|
||||
return True
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'user_id': keycloak_user_id,
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
return False
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
@@ -967,14 +1046,20 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget_in_team'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget, # None disables budget enforcement
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
@@ -1056,14 +1141,20 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget_in_team'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget, # None disables budget enforcement
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
@@ -1450,6 +1541,7 @@ class LiteLlmManager:
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
user_exists = staticmethod(with_http_client(_user_exists))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
|
||||
@@ -15,25 +15,27 @@ class SaasConversationValidator(ConversationValidator):
|
||||
|
||||
async def _validate_api_key(self, api_key: str) -> str | None:
|
||||
"""
|
||||
Validate an API key and return the user_id and github_user_id if valid.
|
||||
Validate an API key and return the user_id if valid.
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (user_id, github_user_id) if the API key is valid, None otherwise
|
||||
The user_id if the API key is valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
token_manager = TokenManager()
|
||||
|
||||
# Validate the API key and get the user_id
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
validation_result = await api_key_store.validate_api_key(api_key)
|
||||
|
||||
if not user_id:
|
||||
if not validation_result:
|
||||
logger.warning('Invalid API key')
|
||||
return None
|
||||
|
||||
user_id = validation_result.user_id
|
||||
|
||||
# Get the offline token for the user
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
if not offline_token:
|
||||
|
||||
@@ -59,12 +59,15 @@ class SaasSecretsStore(SecretsStore):
|
||||
|
||||
async with a_session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
await session.execute(
|
||||
delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
# Delete existing records for this user AND organization only
|
||||
delete_query = delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
if org_id is not None:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
else:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id.is_(None))
|
||||
await session.execute(delete_query)
|
||||
|
||||
# Prepare the new secrets data
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
|
||||
0
enterprise/tests/unit/routes/__init__.py
Normal file
0
enterprise/tests/unit/routes/__init__.py
Normal file
331
enterprise/tests/unit/routes/test_service.py
Normal file
331
enterprise/tests/unit/routes/test_service.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Unit tests for service API routes."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.service import (
|
||||
CreateUserApiKeyRequest,
|
||||
delete_user_api_key,
|
||||
get_or_create_api_key_for_user,
|
||||
validate_service_api_key,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateServiceApiKey:
|
||||
"""Test cases for validate_service_api_key."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_service_key(self):
|
||||
"""Test validation with valid service API key."""
|
||||
with patch(
|
||||
'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key'
|
||||
):
|
||||
result = await validate_service_api_key('test-service-key')
|
||||
assert result == 'automations-service'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_service_key(self):
|
||||
"""Test validation with missing service API key header."""
|
||||
with patch(
|
||||
'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key'
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'X-Service-API-Key header is required' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_service_key(self):
|
||||
"""Test validation with invalid service API key."""
|
||||
with patch(
|
||||
'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key'
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key('wrong-key')
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'Invalid service API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_auth_not_configured(self):
|
||||
"""Test validation when service auth is not configured."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', ''):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key('any-key')
|
||||
assert exc_info.value.status_code == 503
|
||||
assert 'Service authentication not configured' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestCreateUserApiKeyRequest:
|
||||
"""Test cases for CreateUserApiKeyRequest validation."""
|
||||
|
||||
def test_valid_request(self):
|
||||
"""Test valid request with all fields."""
|
||||
request = CreateUserApiKeyRequest(
|
||||
name='automation',
|
||||
)
|
||||
assert request.name == 'automation'
|
||||
|
||||
def test_name_is_required(self):
|
||||
"""Test that name field is required."""
|
||||
with pytest.raises(ValueError):
|
||||
CreateUserApiKeyRequest(
|
||||
name='', # Empty name should fail
|
||||
)
|
||||
|
||||
def test_name_is_stripped(self):
|
||||
"""Test that name field is stripped of whitespace."""
|
||||
request = CreateUserApiKeyRequest(
|
||||
name=' automation ',
|
||||
)
|
||||
assert request.name == 'automation'
|
||||
|
||||
def test_whitespace_only_name_fails(self):
|
||||
"""Test that whitespace-only name fails validation."""
|
||||
with pytest.raises(ValueError):
|
||||
CreateUserApiKeyRequest(
|
||||
name=' ',
|
||||
)
|
||||
|
||||
|
||||
class TestGetOrCreateApiKeyForUser:
|
||||
"""Test cases for get_or_create_api_key_for_user endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def valid_user_id(self):
|
||||
"""Return a valid user ID."""
|
||||
return '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
|
||||
@pytest.fixture
|
||||
def valid_org_id(self):
|
||||
"""Return a valid org ID."""
|
||||
return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
@pytest.fixture
|
||||
def valid_request(self):
|
||||
"""Create a valid request object."""
|
||||
return CreateUserApiKeyRequest(
|
||||
name='automation',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_not_found(self, valid_user_id, valid_org_id, valid_request):
|
||||
"""Test error when user doesn't exist."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = None
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'not found' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_not_in_org(self, valid_user_id, valid_org_id, valid_request):
|
||||
"""Test error when user is not a member of the org."""
|
||||
mock_user = MagicMock()
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = None
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member of org' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_key_creation(
|
||||
self, valid_user_id, valid_org_id, valid_request
|
||||
):
|
||||
"""Test successful API key creation."""
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.get_or_create_system_api_key = AsyncMock(
|
||||
return_value='sk-oh-test-key-12345678901234567890'
|
||||
)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = mock_org_member
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
response = await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert response.key == 'sk-oh-test-key-12345678901234567890'
|
||||
assert response.user_id == valid_user_id
|
||||
assert response.org_id == str(valid_org_id)
|
||||
assert response.name == 'automation'
|
||||
|
||||
# Verify the store was called with correct arguments
|
||||
mock_api_key_store.get_or_create_system_api_key.assert_called_once_with(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
name='automation',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_exception_handling(
|
||||
self, valid_user_id, valid_org_id, valid_request
|
||||
):
|
||||
"""Test error handling when store raises exception."""
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.get_or_create_system_api_key = AsyncMock(
|
||||
side_effect=Exception('Database error')
|
||||
)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = mock_org_member
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to get or create API key' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteUserApiKey:
|
||||
"""Test cases for delete_user_api_key endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def valid_org_id(self):
|
||||
"""Return a valid org ID."""
|
||||
return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_delete(self, valid_org_id):
|
||||
"""Test successful deletion of a system API key."""
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:automation'
|
||||
mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=True)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
response = await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert response == {'message': 'API key deleted successfully'}
|
||||
|
||||
# Verify the store was called with correct arguments
|
||||
mock_api_key_store.make_system_key_name.assert_called_once_with('automation')
|
||||
mock_api_key_store.delete_api_key_by_name.assert_called_once_with(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
allow_system=True,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_not_found(self, valid_org_id):
|
||||
"""Test error when key to delete is not found."""
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:nonexistent'
|
||||
mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=False)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='nonexistent',
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'not found' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_invalid_service_key(self, valid_org_id):
|
||||
"""Test error when service API key is invalid."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key='wrong-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'Invalid service API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_missing_service_key(self, valid_org_id):
|
||||
"""Test error when service API key header is missing."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'X-Service-API-Key header is required' in exc_info.value.detail
|
||||
@@ -1,19 +1,26 @@
|
||||
"""Unit tests for API keys routes, focusing on BYOR key validation and retrieval."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.api_keys import (
|
||||
ByorPermittedResponse,
|
||||
CurrentApiKeyResponse,
|
||||
LlmApiKeyResponse,
|
||||
check_byor_permitted,
|
||||
delete_byor_key_from_litellm,
|
||||
get_current_api_key,
|
||||
get_llm_api_key_for_byor,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
|
||||
|
||||
class TestVerifyByorKeyInLitellm:
|
||||
"""Test the verify_byor_key_in_litellm function."""
|
||||
@@ -512,3 +519,81 @@ class TestCheckByorPermitted:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to check BYOR export permission' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestGetCurrentApiKey:
|
||||
"""Test the get_current_api_key endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_api_key_info_for_bearer_auth(self, mock_get_user_auth):
|
||||
"""Test that API key metadata including org_id is returned for bearer token auth."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
org_id = uuid.uuid4()
|
||||
mock_request = MagicMock()
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
refresh_token=SecretStr('mock-token'),
|
||||
user_id=user_id,
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=org_id,
|
||||
api_key_id=42,
|
||||
api_key_name='My Production Key',
|
||||
)
|
||||
mock_get_user_auth.return_value = user_auth
|
||||
|
||||
# Act
|
||||
result = await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, CurrentApiKeyResponse)
|
||||
assert result.org_id == str(org_id)
|
||||
assert result.id == 42
|
||||
assert result.name == 'My Production Key'
|
||||
assert result.user_id == user_id
|
||||
assert result.auth_type == 'bearer'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_400_for_cookie_auth(self, mock_get_user_auth):
|
||||
"""Test that 400 Bad Request is returned when using cookie authentication."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_request = MagicMock()
|
||||
|
||||
mock_user_auth = MagicMock()
|
||||
mock_user_auth.get_auth_type.return_value = AuthType.COOKIE
|
||||
mock_get_user_auth.return_value = mock_user_auth
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'API key authentication' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_400_when_api_key_org_id_is_none(self, mock_get_user_auth):
|
||||
"""Test that 400 is returned when API key has no org_id (legacy key)."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_request = MagicMock()
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
refresh_token=SecretStr('mock-token'),
|
||||
user_id=user_id,
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=None, # No org_id - legacy key
|
||||
api_key_id=42,
|
||||
api_key_name='Legacy Key',
|
||||
)
|
||||
mock_get_user_auth.return_value = user_auth
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'created before organization support' in exc_info.value.detail
|
||||
|
||||
314
enterprise/tests/unit/storage/test_api_key_store.py
Normal file
314
enterprise/tests/unit/storage/test_api_key_store.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""Unit tests for ApiKeyStore system key functionality."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store():
|
||||
"""Create ApiKeyStore instance."""
|
||||
return ApiKeyStore()
|
||||
|
||||
|
||||
class TestApiKeyStoreSystemKeys:
|
||||
"""Test cases for system API key functionality."""
|
||||
|
||||
def test_is_system_key_name_with_prefix(self, api_key_store):
|
||||
"""Test that names with __SYSTEM__: prefix are identified as system keys."""
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:automation') is True
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:test-key') is True
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:') is True
|
||||
|
||||
def test_is_system_key_name_without_prefix(self, api_key_store):
|
||||
"""Test that names without __SYSTEM__: prefix are not system keys."""
|
||||
assert api_key_store.is_system_key_name('my-key') is False
|
||||
assert api_key_store.is_system_key_name('automation') is False
|
||||
assert api_key_store.is_system_key_name('MCP_API_KEY') is False
|
||||
assert api_key_store.is_system_key_name('') is False
|
||||
|
||||
def test_is_system_key_name_none(self, api_key_store):
|
||||
"""Test that None is not a system key."""
|
||||
assert api_key_store.is_system_key_name(None) is False
|
||||
|
||||
def test_make_system_key_name(self, api_key_store):
|
||||
"""Test system key name generation."""
|
||||
assert (
|
||||
api_key_store.make_system_key_name('automation') == '__SYSTEM__:automation'
|
||||
)
|
||||
assert api_key_store.make_system_key_name('test-key') == '__SYSTEM__:test-key'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_creates_new(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test creating a new system API key when none exists."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
api_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert api_key.startswith('sk-oh-')
|
||||
assert len(api_key) == len('sk-oh-') + 32
|
||||
|
||||
# Verify the key was created in the database
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is not None
|
||||
assert key_record.user_id == user_id
|
||||
assert key_record.org_id == org_id
|
||||
assert key_record.name == '__SYSTEM__:automation'
|
||||
assert key_record.expires_at is None # System keys never expire
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_returns_existing(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that existing valid system key is returned."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Create the first key
|
||||
first_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
# Request again - should return the same key
|
||||
second_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert first_key == second_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_different_names(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that different names create different keys."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
key1 = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='automation-1',
|
||||
)
|
||||
|
||||
key2 = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='automation-2',
|
||||
)
|
||||
|
||||
assert key1 != key2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_reissues_expired(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that expired system key is replaced with a new one."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
system_key_name = '__SYSTEM__:automation'
|
||||
|
||||
# First, manually create an expired key
|
||||
expired_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
async with async_session_maker() as session:
|
||||
expired_key = ApiKey(
|
||||
key='sk-oh-expired-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
expires_at=expired_time.replace(tzinfo=None),
|
||||
)
|
||||
session.add(expired_key)
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Request the key - should create a new one
|
||||
new_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert new_key != 'sk-oh-expired-key-12345678901234567890'
|
||||
assert new_key.startswith('sk-oh-')
|
||||
|
||||
# Verify old key was deleted and new key exists
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.name == system_key_name)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
assert len(keys) == 1
|
||||
assert keys[0].key == new_key
|
||||
assert keys[0].expires_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_api_keys_excludes_system_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that list_api_keys excludes system keys."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a user key and a system key
|
||||
async with async_session_maker() as session:
|
||||
user_key = ApiKey(
|
||||
key='sk-oh-user-key-123456789012345678901',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='my-user-key',
|
||||
)
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
mcp_key = ApiKey(
|
||||
key='sk-oh-mcp-key-1234567890123456789012',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='MCP_API_KEY',
|
||||
)
|
||||
session.add(user_key)
|
||||
session.add(system_key)
|
||||
session.add(mcp_key)
|
||||
await session.commit()
|
||||
|
||||
# Mock UserStore.get_user_by_id to return a user with the correct org
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
with patch(
|
||||
'storage.api_key_store.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = mock_user
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Should only return the user key
|
||||
assert len(keys) == 1
|
||||
assert keys[0].name == 'my-user-key'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_protects_system_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that system keys cannot be deleted by users."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a system key
|
||||
async with async_session_maker() as session:
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
session.add(system_key)
|
||||
await session.commit()
|
||||
key_id = system_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Attempt to delete without allow_system flag
|
||||
result = await api_key_store.delete_api_key_by_id(
|
||||
key_id, allow_system=False
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
# Verify the key still exists
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_allows_system_with_flag(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that system keys can be deleted with allow_system=True."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a system key
|
||||
async with async_session_maker() as session:
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
session.add(system_key)
|
||||
await session.commit()
|
||||
key_id = system_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Delete with allow_system=True
|
||||
result = await api_key_store.delete_api_key_by_id(key_id, allow_system=True)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the key was deleted
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_allows_regular_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that regular keys can be deleted normally."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a regular key
|
||||
async with async_session_maker() as session:
|
||||
regular_key = ApiKey(
|
||||
key='sk-oh-regular-key-1234567890123456789',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='my-regular-key',
|
||||
)
|
||||
session.add(regular_key)
|
||||
await session.commit()
|
||||
key_id = regular_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Delete without allow_system flag - should work for regular keys
|
||||
result = await api_key_store.delete_api_key_by_id(
|
||||
key_id, allow_system=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the key was deleted
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is None
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.api_key_store import ApiKeyStore, ApiKeyValidationResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -110,8 +110,8 @@ async def test_create_api_key(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
"""Test validating a valid API key."""
|
||||
# Setup - create an API key in the database
|
||||
"""Test validating a valid API key returns user_id and org_id."""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-api-key'
|
||||
@@ -126,13 +126,19 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
key_id = key_record.id
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
# Act
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result is not None
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id == org_id
|
||||
assert result.key_id == key_id
|
||||
assert result.key_name == 'Test Key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -197,7 +203,7 @@ async def test_validate_api_key_valid_timezone_naive(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup - create a valid API key with timezone-naive datetime (future date)
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-valid-naive-key'
|
||||
@@ -214,12 +220,44 @@ async def test_validate_api_key_valid_timezone_naive(
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
# Act
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id == org_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_legacy_without_org_id(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a legacy API key without org_id returns None for org_id."""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
api_key_value = 'test-legacy-key-no-org'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=None, # Legacy key without org binding
|
||||
name='Legacy Key',
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result is not None
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -13,6 +13,7 @@ from server.auth.authorization import (
|
||||
ROLE_PERMISSIONS,
|
||||
Permission,
|
||||
RoleName,
|
||||
get_api_key_org_id_from_request,
|
||||
get_role_permissions,
|
||||
get_user_org_role,
|
||||
has_permission,
|
||||
@@ -444,6 +445,15 @@ class TestGetUserOrgRole:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _create_mock_request(api_key_org_id=None):
|
||||
"""Helper to create a mock request with optional api_key_org_id."""
|
||||
mock_request = MagicMock()
|
||||
mock_user_auth = MagicMock()
|
||||
mock_user_auth.get_api_key_org_id.return_value = api_key_org_id
|
||||
mock_request.state.user_auth = mock_user_auth
|
||||
return mock_request
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
"""Tests for require_permission dependency factory."""
|
||||
|
||||
@@ -456,6 +466,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -465,7 +476,9 @@ class TestRequirePermission:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -476,10 +489,11 @@ class TestRequirePermission:
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=None)
|
||||
await permission_checker(request=mock_request, org_id=org_id, user_id=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
@@ -493,6 +507,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
@@ -500,7 +515,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail.lower()
|
||||
@@ -514,6 +531,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -524,7 +542,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'delete_organization' in exc_info.value.detail.lower()
|
||||
@@ -538,6 +558,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
@@ -547,7 +568,9 @@ class TestRequirePermission:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -559,6 +582,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -569,7 +593,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@@ -582,6 +608,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -595,7 +622,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException):
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
call_args = mock_logger.warning.call_args
|
||||
@@ -611,6 +640,7 @@ class TestRequirePermission:
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -620,7 +650,9 @@ class TestRequirePermission:
|
||||
AsyncMock(return_value=mock_role),
|
||||
) as mock_get_role:
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=None, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=None, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
mock_get_role.assert_called_once_with(user_id, None)
|
||||
|
||||
@@ -632,6 +664,7 @@ class TestRequirePermission:
|
||||
THEN: HTTPException with 403 status is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
@@ -639,7 +672,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=None, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=None, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail
|
||||
@@ -662,6 +697,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -671,7 +707,9 @@ class TestPermissionScenarios:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.MANAGE_SECRETS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -683,6 +721,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -695,7 +734,9 @@ class TestPermissionScenarios:
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@@ -708,6 +749,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -719,7 +761,9 @@ class TestPermissionScenarios:
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -731,6 +775,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -741,7 +786,9 @@ class TestPermissionScenarios:
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@@ -754,6 +801,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
@@ -763,5 +811,200 @@ class TestPermissionScenarios:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for API key organization validation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestApiKeyOrgValidation:
|
||||
"""Tests for API key organization binding validation in require_permission."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_access_when_api_key_org_matches_target_org(self):
|
||||
"""
|
||||
GIVEN: API key with org_id that matches the target org_id in the request
|
||||
WHEN: Permission checker is called
|
||||
THEN: User ID is returned (access allowed)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=org_id)
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_denies_access_when_api_key_org_mismatches_target_org(self):
|
||||
"""
|
||||
GIVEN: API key created for Org A, but user tries to access Org B
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised with org mismatch message
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
api_key_org_id = uuid4() # Org A - where API key was created
|
||||
target_org_id = uuid4() # Org B - where user is trying to access
|
||||
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
|
||||
|
||||
# Act & Assert
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=target_org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert (
|
||||
'API key is not authorized for this organization' in exc_info.value.detail
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_access_for_legacy_api_key_without_org_binding(self):
|
||||
"""
|
||||
GIVEN: Legacy API key without org_id binding (org_id is None)
|
||||
WHEN: Permission checker is called
|
||||
THEN: Falls through to normal permission check (backward compatible)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=None)
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_access_for_cookie_auth_without_api_key_org_id(self):
|
||||
"""
|
||||
GIVEN: Cookie-based authentication (no api_key_org_id in user_auth)
|
||||
WHEN: Permission checker is called
|
||||
THEN: Falls through to normal permission check
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=None)
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_warning_on_api_key_org_mismatch(self):
|
||||
"""
|
||||
GIVEN: API key org_id doesn't match target org_id
|
||||
WHEN: Permission checker is called
|
||||
THEN: Warning is logged with org mismatch details
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
api_key_org_id = uuid4()
|
||||
target_org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
|
||||
|
||||
# Act & Assert
|
||||
with patch('server.auth.authorization.logger') as mock_logger:
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException):
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=target_org_id, user_id=user_id
|
||||
)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
assert call_args[1]['extra']['api_key_org_id'] == str(api_key_org_id)
|
||||
assert call_args[1]['extra']['target_org_id'] == str(target_org_id)
|
||||
|
||||
|
||||
class TestGetApiKeyOrgIdFromRequest:
|
||||
"""Tests for get_api_key_org_id_from_request helper function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_org_id_when_user_auth_has_api_key_org_id(self):
|
||||
"""
|
||||
GIVEN: Request with user_auth that has api_key_org_id
|
||||
WHEN: get_api_key_org_id_from_request is called
|
||||
THEN: Returns the api_key_org_id
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=org_id)
|
||||
|
||||
# Act
|
||||
result = await get_api_key_org_id_from_request(mock_request)
|
||||
|
||||
# Assert
|
||||
assert result == org_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_user_auth_has_no_api_key_org_id(self):
|
||||
"""
|
||||
GIVEN: Request with user_auth that has no api_key_org_id (cookie auth)
|
||||
WHEN: get_api_key_org_id_from_request is called
|
||||
THEN: Returns None
|
||||
"""
|
||||
# Arrange
|
||||
mock_request = _create_mock_request(api_key_org_id=None)
|
||||
|
||||
# Act
|
||||
result = await get_api_key_org_id_from_request(mock_request)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_user_auth_in_request(self):
|
||||
"""
|
||||
GIVEN: Request without user_auth in state
|
||||
WHEN: get_api_key_org_id_from_request is called
|
||||
THEN: Returns None
|
||||
"""
|
||||
# Arrange
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.user_auth = None
|
||||
|
||||
# Act
|
||||
result = await get_api_key_org_id_from_request(mock_request)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@@ -239,6 +239,16 @@ class TestLiteLlmManager:
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
@@ -246,12 +256,8 @@ class TestLiteLlmManager:
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
# First GET is for _get_team (404), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_404_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -274,8 +280,8 @@ class TestLiteLlmManager:
|
||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify API calls were made (get_team + 4 posts)
|
||||
assert mock_client.get.call_count == 1 # get_team
|
||||
# Verify API calls were made (get_team + user_exists + 4 posts)
|
||||
assert mock_client.get.call_count == 2 # get_team + user_exists
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, add_user_to_team, delete_key_by_alias, generate_key
|
||||
@@ -294,13 +300,21 @@ class TestLiteLlmManager:
|
||||
}
|
||||
mock_team_response.raise_for_status = MagicMock()
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_team_response
|
||||
# First GET is for _get_team (success), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_team_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -320,8 +334,8 @@ class TestLiteLlmManager:
|
||||
assert result is not None
|
||||
|
||||
# Verify _get_team was called first
|
||||
mock_client.get.assert_called_once()
|
||||
get_call_url = mock_client.get.call_args[0][0]
|
||||
assert mock_client.get.call_count == 2 # get_team + user_exists
|
||||
get_call_url = mock_client.get.call_args_list[0][0][0]
|
||||
assert 'team/info' in get_call_url
|
||||
assert 'test-org-id' in get_call_url
|
||||
|
||||
@@ -343,19 +357,25 @@ class TestLiteLlmManager:
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
# First GET is for _get_team (404), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_404_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -393,6 +413,16 @@ class TestLiteLlmManager:
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
@@ -400,12 +430,8 @@ class TestLiteLlmManager:
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
# First GET is for _get_team (404), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_404_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -833,15 +859,16 @@ class TestLiteLlmManager:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_user operation."""
|
||||
"""Test successful _create_user operation returns True."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/user/new' in call_args[0]
|
||||
@@ -850,7 +877,7 @@ class TestLiteLlmManager:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_duplicate_email(self, mock_http_client, mock_response):
|
||||
"""Test _create_user with duplicate email handling."""
|
||||
"""Test _create_user with duplicate email handling returns True."""
|
||||
# First call fails with duplicate email
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
@@ -862,23 +889,81 @@ class TestLiteLlmManager:
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert mock_http_client.post.call_count == 2
|
||||
# Second call should have None email
|
||||
second_call_args = mock_http_client.post.call_args_list[1]
|
||||
assert second_call_args[1]['json']['user_email'] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_user_exists_returns_true(self, mock_http_client):
|
||||
"""Test _user_exists returns True when user exists in LiteLLM."""
|
||||
# Arrange
|
||||
user_response = MagicMock()
|
||||
user_response.is_success = True
|
||||
user_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id', 'email': 'test@example.com'}
|
||||
}
|
||||
mock_http_client.get.return_value = user_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_http_client.get.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_user_exists_returns_false_when_not_found(self, mock_http_client):
|
||||
"""Test _user_exists returns False when user not found."""
|
||||
# Arrange
|
||||
user_response = MagicMock()
|
||||
user_response.is_success = False
|
||||
mock_http_client.get.return_value = user_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_user_exists_returns_false_on_mismatched_user_id(
|
||||
self, mock_http_client
|
||||
):
|
||||
"""Test _user_exists returns False when returned user_id doesn't match."""
|
||||
# Arrange
|
||||
user_response = MagicMock()
|
||||
user_response.is_success = True
|
||||
user_response.json.return_value = {
|
||||
'user_info': {'user_id': 'different-user-id'}
|
||||
}
|
||||
mock_http_client.get.return_value = user_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_already_exists_with_409_status_code(
|
||||
async def test_create_user_already_exists_and_verified(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user handles 409 Conflict when user already exists."""
|
||||
"""Test _create_user returns True when user already exists and is verified."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
@@ -890,14 +975,141 @@ class TestLiteLlmManager:
|
||||
second_response.status_code = 409
|
||||
second_response.text = 'User with id test-user-id already exists'
|
||||
|
||||
user_exists_response = MagicMock()
|
||||
user_exists_response.is_success = True
|
||||
user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_exists_response
|
||||
|
||||
# Act
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_logger.warning.assert_any_call(
|
||||
'litellm_user_already_exists',
|
||||
extra={'user_id': 'test-user-id'},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_already_exists_but_not_found_returns_false(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user returns False when LiteLLM claims user exists but verification fails."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
first_response.status_code = 400
|
||||
first_response.text = 'duplicate email'
|
||||
|
||||
second_response = MagicMock()
|
||||
second_response.is_success = False
|
||||
second_response.status_code = 409
|
||||
second_response.text = 'User with id test-user-id already exists'
|
||||
|
||||
user_not_exists_response = MagicMock()
|
||||
user_not_exists_response.is_success = False
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_not_exists_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_logger.error.assert_any_call(
|
||||
'litellm_user_claimed_exists_but_not_found',
|
||||
extra={
|
||||
'user_id': 'test-user-id',
|
||||
'status_code': 409,
|
||||
'text': 'User with id test-user-id already exists',
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_failure_returns_false(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user returns False when creation fails with non-'already exists' error."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
first_response.status_code = 400
|
||||
first_response.text = 'duplicate email'
|
||||
|
||||
second_response = MagicMock()
|
||||
second_response.is_success = False
|
||||
second_response.status_code = 500
|
||||
second_response.text = 'Internal server error'
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_logger.error.assert_any_call(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': 500,
|
||||
'text': 'Internal server error',
|
||||
'user_id': 'test-user-id',
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_already_exists_with_409_status_code(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user handles 409 Conflict when user already exists and verifies."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
first_response.status_code = 400
|
||||
first_response.text = 'duplicate email'
|
||||
|
||||
second_response = MagicMock()
|
||||
second_response.is_success = False
|
||||
second_response.status_code = 409
|
||||
second_response.text = 'User with id test-user-id already exists'
|
||||
|
||||
user_exists_response = MagicMock()
|
||||
user_exists_response.is_success = True
|
||||
user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_exists_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_logger.warning.assert_any_call(
|
||||
'litellm_user_already_exists',
|
||||
extra={'user_id': 'test-user-id'},
|
||||
@@ -910,7 +1122,7 @@ class TestLiteLlmManager:
|
||||
async def test_create_user_already_exists_with_400_status_code(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user handles 400 Bad Request when user already exists."""
|
||||
"""Test _create_user handles 400 Bad Request when user already exists and verifies."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
@@ -922,14 +1134,22 @@ class TestLiteLlmManager:
|
||||
second_response.status_code = 400
|
||||
second_response.text = 'User already exists'
|
||||
|
||||
user_exists_response = MagicMock()
|
||||
user_exists_response.is_success = True
|
||||
user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_exists_response
|
||||
|
||||
# Act
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_logger.warning.assert_any_call(
|
||||
'litellm_user_already_exists',
|
||||
extra={'user_id': 'test-user-id'},
|
||||
@@ -2164,3 +2384,195 @@ class TestVerifyExistingKey:
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestBudgetPayloadHandling:
|
||||
"""Test cases for budget field handling in API payloads.
|
||||
|
||||
These tests verify that when max_budget is None, the budget field is NOT
|
||||
included in the JSON payload (which tells LiteLLM to disable budget
|
||||
enforcement), and when max_budget has a value, it IS included.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_excludes_max_budget_when_none(self):
|
||||
"""Test that _create_team does NOT include max_budget when it is None."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_client,
|
||||
team_alias='test-team',
|
||||
team_id='test-team-id',
|
||||
max_budget=None, # None = no budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify URL
|
||||
assert call_args[0][0] == 'http://test.com/team/new'
|
||||
|
||||
# Verify that max_budget is NOT in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert 'max_budget' not in json_payload, (
|
||||
'max_budget should NOT be in payload when None '
|
||||
'(omitting it tells LiteLLM to disable budget enforcement)'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_includes_max_budget_when_set(self):
|
||||
"""Test that _create_team includes max_budget when it has a value."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_client,
|
||||
team_alias='test-team',
|
||||
team_id='test-team-id',
|
||||
max_budget=100.0, # Explicit budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify that max_budget IS in the JSON payload with the correct value
|
||||
json_payload = call_args[1]['json']
|
||||
assert (
|
||||
'max_budget' in json_payload
|
||||
), 'max_budget should be in payload when set to a value'
|
||||
assert json_payload['max_budget'] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_to_team_excludes_max_budget_when_none(self):
|
||||
"""Test that _add_user_to_team does NOT include max_budget_in_team when None."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=None, # None = no budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify URL
|
||||
assert call_args[0][0] == 'http://test.com/team/member_add'
|
||||
|
||||
# Verify that max_budget_in_team is NOT in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert 'max_budget_in_team' not in json_payload, (
|
||||
'max_budget_in_team should NOT be in payload when None '
|
||||
'(omitting it tells LiteLLM to disable budget enforcement)'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_to_team_includes_max_budget_when_set(self):
|
||||
"""Test that _add_user_to_team includes max_budget_in_team when set."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=50.0, # Explicit budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify that max_budget_in_team IS in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert (
|
||||
'max_budget_in_team' in json_payload
|
||||
), 'max_budget_in_team should be in payload when set to a value'
|
||||
assert json_payload['max_budget_in_team'] == 50.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_in_team_excludes_max_budget_when_none(self):
|
||||
"""Test that _update_user_in_team does NOT include max_budget_in_team when None."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=None, # None = no budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify URL
|
||||
assert call_args[0][0] == 'http://test.com/team/member_update'
|
||||
|
||||
# Verify that max_budget_in_team is NOT in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert 'max_budget_in_team' not in json_payload, (
|
||||
'max_budget_in_team should NOT be in payload when None '
|
||||
'(omitting it tells LiteLLM to disable budget enforcement)'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_in_team_includes_max_budget_when_set(self):
|
||||
"""Test that _update_user_in_team includes max_budget_in_team when set."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=75.0, # Explicit budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify that max_budget_in_team IS in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert (
|
||||
'max_budget_in_team' in json_payload
|
||||
), 'max_budget_in_team should be in payload when set to a value'
|
||||
assert json_payload['max_budget_in_team'] == 75.0
|
||||
|
||||
@@ -246,3 +246,82 @@ class TestSaasSecretsStore:
|
||||
assert isinstance(store, SaasSecretsStore)
|
||||
assert store.user_id == 'test-user-id'
|
||||
assert store.config == mock_config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_secrets_isolation_between_organizations(
|
||||
self, mock_get_user, secrets_store, mock_user
|
||||
):
|
||||
"""Test that secrets from one organization are not deleted when storing
|
||||
secrets in another organization. This reproduces a bug where switching
|
||||
organizations and creating a secret would delete all secrets from the
|
||||
user's personal workspace."""
|
||||
org1_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
org2_id = UUID('b2222222-2222-2222-2222-222222222222')
|
||||
|
||||
# Store secrets in org1 (personal workspace)
|
||||
mock_user.current_org_id = org1_id
|
||||
mock_get_user.return_value = mock_user
|
||||
org1_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
{
|
||||
'personal_secret': CustomSecret.from_value(
|
||||
{
|
||||
'secret': 'personal_secret_value',
|
||||
'description': 'My personal secret',
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
await secrets_store.store(org1_secrets)
|
||||
|
||||
# Verify org1 secrets are stored
|
||||
loaded_org1 = await secrets_store.load()
|
||||
assert loaded_org1 is not None
|
||||
assert 'personal_secret' in loaded_org1.custom_secrets
|
||||
assert (
|
||||
loaded_org1.custom_secrets['personal_secret'].secret.get_secret_value()
|
||||
== 'personal_secret_value'
|
||||
)
|
||||
|
||||
# Switch to org2 and store secrets there
|
||||
mock_user.current_org_id = org2_id
|
||||
mock_get_user.return_value = mock_user
|
||||
org2_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
{
|
||||
'org2_secret': CustomSecret.from_value(
|
||||
{'secret': 'org2_secret_value', 'description': 'Org2 secret'}
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
await secrets_store.store(org2_secrets)
|
||||
|
||||
# Verify org2 secrets are stored
|
||||
loaded_org2 = await secrets_store.load()
|
||||
assert loaded_org2 is not None
|
||||
assert 'org2_secret' in loaded_org2.custom_secrets
|
||||
assert (
|
||||
loaded_org2.custom_secrets['org2_secret'].secret.get_secret_value()
|
||||
== 'org2_secret_value'
|
||||
)
|
||||
|
||||
# Switch back to org1 and verify secrets are still there
|
||||
mock_user.current_org_id = org1_id
|
||||
mock_get_user.return_value = mock_user
|
||||
loaded_org1_again = await secrets_store.load()
|
||||
assert loaded_org1_again is not None
|
||||
assert 'personal_secret' in loaded_org1_again.custom_secrets
|
||||
assert (
|
||||
loaded_org1_again.custom_secrets[
|
||||
'personal_secret'
|
||||
].secret.get_secret_value()
|
||||
== 'personal_secret_value'
|
||||
)
|
||||
# Verify org2 secrets are NOT visible in org1
|
||||
assert 'org2_secret' not in loaded_org1_again.custom_secrets
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import jwt
|
||||
@@ -18,6 +19,7 @@ from server.auth.saas_user_auth import (
|
||||
saas_user_auth_from_cookie,
|
||||
saas_user_auth_from_signed_token,
|
||||
)
|
||||
from storage.api_key_store import ApiKeyValidationResult
|
||||
from storage.user_authorization import UserAuthorizationType
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
@@ -457,7 +459,8 @@ async def test_get_instance_no_auth(mock_request):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_bearer_success():
|
||||
"""Test successful authentication from bearer token."""
|
||||
"""Test successful authentication from bearer token sets user_id and api_key_org_id."""
|
||||
# Arrange
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {'Authorization': 'Bearer test_api_key'}
|
||||
|
||||
@@ -468,12 +471,22 @@ async def test_saas_user_auth_from_bearer_success():
|
||||
algorithm='HS256',
|
||||
)
|
||||
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_validation_result = ApiKeyValidationResult(
|
||||
user_id='test_user_id',
|
||||
org_id=mock_org_id,
|
||||
key_id=42,
|
||||
key_name='Test Key',
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls,
|
||||
patch('server.auth.saas_user_auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.validate_api_key = AsyncMock(return_value='test_user_id')
|
||||
mock_api_key_store.validate_api_key = AsyncMock(
|
||||
return_value=mock_validation_result
|
||||
)
|
||||
mock_api_key_store_cls.get_instance.return_value = mock_api_key_store
|
||||
|
||||
mock_token_manager.load_offline_token = AsyncMock(return_value=offline_token)
|
||||
@@ -485,6 +498,9 @@ async def test_saas_user_auth_from_bearer_success():
|
||||
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.api_key_org_id == mock_org_id
|
||||
assert result.api_key_id == 42
|
||||
assert result.api_key_name == 'Test Key'
|
||||
mock_api_key_store.validate_api_key.assert_called_once_with('test_api_key')
|
||||
mock_token_manager.load_offline_token.assert_called_once_with('test_user_id')
|
||||
mock_token_manager.refresh.assert_called_once_with(offline_token)
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { CopyableContentWrapper } from "#/components/shared/buttons/copyable-content-wrapper";
|
||||
|
||||
describe("CopyableContentWrapper", () => {
|
||||
it("should hide the copy button by default", () => {
|
||||
render(
|
||||
<CopyableContentWrapper text="hello">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).not.toBeVisible();
|
||||
});
|
||||
|
||||
it("should show the copy button on hover", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<CopyableContentWrapper text="hello">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
await user.hover(screen.getByText("content"));
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toBeVisible();
|
||||
});
|
||||
|
||||
it("should copy text to clipboard on click", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<CopyableContentWrapper text="copy me">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
await user.click(screen.getByTestId("copy-to-clipboard"));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(navigator.clipboard.readText()).resolves.toBe("copy me"),
|
||||
);
|
||||
});
|
||||
|
||||
it("should show copied state after clicking", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<CopyableContentWrapper text="hello">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
await user.click(screen.getByTestId("copy-to-clipboard"));
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toHaveAttribute(
|
||||
"aria-label",
|
||||
"BUTTON$COPIED",
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1,11 +1,16 @@
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
describe("AnalyticsConsentFormModal", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
it("should call saveUserSettings with consent", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
@@ -10,9 +10,12 @@ import {
|
||||
import { OpenHandsObservation } from "#/types/core/observations";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import { Conversation } from "#/api/open-hands.types";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
vi.mock("react-router", () => ({
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useParams: () => ({ conversationId: "123" }),
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
let queryClient: QueryClient;
|
||||
@@ -47,6 +50,7 @@ const renderMessages = ({
|
||||
describe("Messages", () => {
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient();
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const assistantMessage: AssistantMessageAction = {
|
||||
|
||||
@@ -10,6 +10,7 @@ import OptionService from "#/api/option-service/option-service.api";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import { RepoConnector } from "#/components/features/home/repo-connector";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
const renderRepoConnector = () => {
|
||||
const mockRepoSelection = vi.fn();
|
||||
@@ -65,6 +66,7 @@ const MOCK_RESPOSITORIES: GitRepository[] = [
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { code as Code } from "#/components/features/markdown/code";
|
||||
|
||||
describe("code (markdown)", () => {
|
||||
it("should render inline code without a copy button", () => {
|
||||
render(<Code>inline snippet</Code>);
|
||||
|
||||
expect(screen.getByText("inline snippet")).toBeInTheDocument();
|
||||
expect(screen.queryByTestId("copy-to-clipboard")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render a multiline code block with a copy button", () => {
|
||||
render(<Code>{"line1\nline2"}</Code>);
|
||||
|
||||
expect(screen.getByText("line1 line2")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render a syntax-highlighted block with a copy button", () => {
|
||||
render(<Code className="language-js">{"console.log('hi')"}</Code>);
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should copy code block content to clipboard", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Code>{"line1\nline2"}</Code>);
|
||||
|
||||
await user.click(screen.getByTestId("copy-to-clipboard"));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(navigator.clipboard.readText()).resolves.toBe("line1\nline2"),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,351 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { AddCreditsModal } from "#/components/features/org/add-credits-modal";
|
||||
import BillingService from "#/api/billing-service/billing-service.api";
|
||||
|
||||
vi.mock("react-i18next", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-i18next")>()),
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
i18n: {
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("AddCreditsModal", () => {
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
const renderModal = () => {
|
||||
const user = userEvent.setup();
|
||||
renderWithProviders(<AddCreditsModal onClose={onCloseMock} />);
|
||||
return { user };
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Rendering", () => {
|
||||
it("should render the form with correct elements", () => {
|
||||
renderModal();
|
||||
|
||||
expect(screen.getByTestId("add-credits-form")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("amount-input")).toBeInTheDocument();
|
||||
expect(screen.getByRole("button", { name: /ORG\$NEXT/i })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display the title", () => {
|
||||
renderModal();
|
||||
|
||||
expect(screen.getByText("ORG$ADD_CREDITS")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Button State Management", () => {
|
||||
it("should enable submit button initially when modal opens", () => {
|
||||
renderModal();
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains invalid value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains valid value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "100");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button after validation error is shown", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Input Attributes & Placeholder", () => {
|
||||
it("should have min attribute set to 10", () => {
|
||||
renderModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("min", "10");
|
||||
});
|
||||
|
||||
it("should have max attribute set to 25000", () => {
|
||||
renderModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("max", "25000");
|
||||
});
|
||||
|
||||
it("should have step attribute set to 1", () => {
|
||||
renderModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("step", "1");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Message Display", () => {
|
||||
it("should not display error message initially when modal opens", () => {
|
||||
renderModal();
|
||||
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display error message after submitting amount above maximum", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MAXIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting decimal value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "50.5");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER");
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting amount below minimum", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting negative amount", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_NEGATIVE_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should replace error message when submitting different invalid value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MAXIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Submission Behavior", () => {
|
||||
it("should prevent submission when amount is invalid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should call createCheckoutSession with correct amount when valid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not call createCheckoutSession when validation fails", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_NEGATIVE_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should close modal on successful submission", async () => {
|
||||
vi.spyOn(BillingService, "createCheckoutSession").mockResolvedValue(
|
||||
"https://checkout.stripe.com/test-session",
|
||||
);
|
||||
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onCloseMock).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow API call when validation passes and clear any previous errors", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
// First submit invalid value
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Then submit valid value
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "100");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(100);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edge Cases", () => {
|
||||
it("should handle zero value correctly", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "0");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle whitespace-only input correctly", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
// Number inputs typically don't accept spaces, but test the behavior
|
||||
await user.type(amountInput, " ");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Should not call API (empty/invalid input)
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Modal Interaction", () => {
|
||||
it("should call onClose when cancel button is clicked", async () => {
|
||||
const { user } = renderModal();
|
||||
|
||||
const cancelButton = screen.getByRole("button", { name: /close/i });
|
||||
await user.click(cancelButton);
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,8 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { ApiKeysManager } from "#/components/features/settings/api-keys-manager";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock the react-i18next
|
||||
vi.mock("react-i18next", async () => {
|
||||
@@ -37,6 +38,10 @@ vi.mock("#/hooks/query/use-api-keys", () => ({
|
||||
}));
|
||||
|
||||
describe("ApiKeysManager", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const renderComponent = () => {
|
||||
const queryClient = new QueryClient();
|
||||
return render(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
renderWithProviders,
|
||||
createAxiosNotFoundErrorObject,
|
||||
@@ -10,6 +10,7 @@ import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Helper to create mock config with sensible defaults
|
||||
const createMockConfig = (
|
||||
@@ -76,6 +77,10 @@ describe("Sidebar", () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
@@ -156,11 +156,19 @@ describe("UserContextMenu", () => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
it("should render the default context items for a user", () => {
|
||||
it("should render the default context items for a user", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
screen.getByTestId("org-selector");
|
||||
screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
|
||||
// Wait for config to load so logout button appears
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("ACCOUNT_SETTINGS$LOGOUT")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(
|
||||
screen.queryByText("ORG$INVITE_ORG_MEMBERS"),
|
||||
@@ -304,6 +312,20 @@ describe("UserContextMenu", () => {
|
||||
screen.queryByText("Organization Members"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display logout button in OSS mode", async () => {
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for the config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("SETTINGS$NAV_LLM")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify logout button is NOT rendered in OSS mode
|
||||
expect(
|
||||
screen.queryByText("ACCOUNT_SETTINGS$LOGOUT"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("HIDE_LLM_SETTINGS feature flag", () => {
|
||||
@@ -382,10 +404,15 @@ describe("UserContextMenu", () => {
|
||||
});
|
||||
|
||||
it("should call the logout handler when Logout is clicked", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
const logoutSpy = vi.spyOn(AuthService, "logout");
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
// Wait for config to load so logout button appears
|
||||
const logoutButton = await screen.findByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await userEvent.click(logoutButton);
|
||||
|
||||
expect(logoutSpy).toHaveBeenCalledOnce();
|
||||
@@ -488,6 +515,10 @@ describe("UserContextMenu", () => {
|
||||
});
|
||||
|
||||
it("should call the onClose handler after each action", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
// Mock a team org so org management buttons are visible
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
@@ -497,7 +528,8 @@ describe("UserContextMenu", () => {
|
||||
const onCloseMock = vi.fn();
|
||||
renderUserContextMenu({ type: "owner", onClose: onCloseMock, onOpenInviteModal: vi.fn });
|
||||
|
||||
const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
// Wait for config to load so logout button appears
|
||||
const logoutButton = await screen.findByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await userEvent.click(logoutButton);
|
||||
expect(onCloseMock).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
||||
@@ -1,26 +1,25 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeAll, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-box";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock React Router hooks
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
};
|
||||
});
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
// Mock the useActiveConversation hook
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
@@ -52,6 +51,10 @@ vi.mock("#/hooks/use-conversation-name-context-menu", () => ({
|
||||
describe("InteractiveChatBox", () => {
|
||||
const onSubmitMock = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const mockStores = (agentState: AgentState = AgentState.INIT) => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: agentState,
|
||||
@@ -213,6 +216,36 @@ describe("InteractiveChatBox", () => {
|
||||
expect(onSubmitMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should lock the text input field when disabled prop is true (isNewConversationPending)", () => {
|
||||
mockStores(AgentState.INIT);
|
||||
|
||||
renderInteractiveChatBox({
|
||||
onSubmit: onSubmitMock,
|
||||
disabled: true,
|
||||
});
|
||||
|
||||
const chatInput = screen.getByTestId("chat-input");
|
||||
// When disabled=true, the text field should not be editable
|
||||
expect(chatInput).toHaveAttribute("contenteditable", "false");
|
||||
// Should show visual disabled state
|
||||
expect(chatInput.className).toContain("cursor-not-allowed");
|
||||
expect(chatInput.className).toContain("opacity-50");
|
||||
});
|
||||
|
||||
it("should keep the text input field editable when disabled prop is false", () => {
|
||||
mockStores(AgentState.INIT);
|
||||
|
||||
renderInteractiveChatBox({
|
||||
onSubmit: onSubmitMock,
|
||||
disabled: false,
|
||||
});
|
||||
|
||||
const chatInput = screen.getByTestId("chat-input");
|
||||
expect(chatInput).toHaveAttribute("contenteditable", "true");
|
||||
expect(chatInput.className).not.toContain("cursor-not-allowed");
|
||||
expect(chatInput.className).not.toContain("opacity-50");
|
||||
});
|
||||
|
||||
it("should handle image upload and message submission correctly", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onSubmit = vi.fn();
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { render, screen, waitFor, fireEvent, act } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, afterEach, beforeEach, test } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { MemoryRouter, createRoutesStub } from "react-router";
|
||||
import { ReactElement } from "react";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { UserActions } from "#/components/features/sidebar/user-actions";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME } from "#/mocks/org-handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { server } from "#/mocks/node";
|
||||
import { createMockWebClientConfig } from "#/mocks/settings-handlers";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
|
||||
vi.mock("react-router", async (importActual) => ({
|
||||
@@ -59,6 +62,20 @@ const renderUserActions = (props = { hasAvatar: true }) => {
|
||||
);
|
||||
};
|
||||
|
||||
// RouterStub and render helper for menu close delay tests
|
||||
const RouterStubForMenuCloseDelay = createRoutesStub([
|
||||
{
|
||||
path: "/",
|
||||
Component: () => (
|
||||
<UserActions user={{ avatar_url: "https://example.com/avatar.png" }} />
|
||||
),
|
||||
},
|
||||
]);
|
||||
|
||||
const renderUserActionsForMenuCloseDelay = () => {
|
||||
return renderWithProviders(<RouterStubForMenuCloseDelay initialEntries={["/"]} />);
|
||||
};
|
||||
|
||||
// Create mocks for all the hooks we need
|
||||
const useIsAuthedMock = vi
|
||||
.fn()
|
||||
@@ -347,7 +364,7 @@ describe("UserActions", () => {
|
||||
expect(contextMenu).toBeVisible();
|
||||
});
|
||||
|
||||
it("should have pointer-events-none on hover bridge pseudo-element to allow menu item clicks", async () => {
|
||||
it("should use state-based visibility for hover behavior instead of CSS pseudo-element", async () => {
|
||||
renderUserActions();
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
@@ -356,19 +373,17 @@ describe("UserActions", () => {
|
||||
const contextMenu = screen.getByTestId("user-context-menu");
|
||||
const hoverBridgeContainer = contextMenu.parentElement;
|
||||
|
||||
// The hover bridge uses a ::before pseudo-element for diagonal mouse movement
|
||||
// This pseudo-element MUST have pointer-events-none to allow clicks through to menu items
|
||||
// The class should include "before:pointer-events-none" to prevent the hover bridge from blocking clicks
|
||||
expect(hoverBridgeContainer?.className).toContain(
|
||||
"before:pointer-events-none",
|
||||
);
|
||||
// The component uses state-based visibility with a 500ms delay for diagonal mouse movement
|
||||
// When visible, the container should have opacity-100 and pointer-events-auto
|
||||
expect(hoverBridgeContainer?.className).toContain("opacity-100");
|
||||
expect(hoverBridgeContainer?.className).toContain("pointer-events-auto");
|
||||
});
|
||||
|
||||
describe("Org selector dropdown state reset when context menu hides", () => {
|
||||
// These tests verify that the org selector dropdown resets its internal
|
||||
// state (search text, open/closed) when the context menu hides and
|
||||
// reappears. Without this, stale state persists because the context
|
||||
// menu is hidden via CSS (opacity/pointer-events) rather than unmounted.
|
||||
// reappears. The component uses a 500ms delay before hiding (to support
|
||||
// diagonal mouse movement).
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
@@ -400,8 +415,22 @@ describe("UserActions", () => {
|
||||
await user.type(input, "search text");
|
||||
expect(input).toHaveValue("search text");
|
||||
|
||||
// Unhover to hide context menu, then hover again
|
||||
// Unhover to trigger hide timeout, then wait for the 500ms delay to complete
|
||||
await user.unhover(userActions);
|
||||
|
||||
// Wait for the 500ms hide delay to complete and menu to actually hide
|
||||
await waitFor(
|
||||
() => {
|
||||
// The menu resets when it actually hides (after 500ms delay)
|
||||
// After hiding, hovering again should show a fresh menu
|
||||
},
|
||||
{ timeout: 600 },
|
||||
);
|
||||
|
||||
// Wait a bit more for the timeout to fire
|
||||
await new Promise((resolve) => setTimeout(resolve, 550));
|
||||
|
||||
// Now hover again to show the menu
|
||||
await user.hover(userActions);
|
||||
|
||||
// Org selector should be reset — showing selected org name, not search text
|
||||
@@ -434,8 +463,13 @@ describe("UserActions", () => {
|
||||
await user.type(input, "Acme");
|
||||
expect(input).toHaveValue("Acme");
|
||||
|
||||
// Unhover to hide context menu, then hover again
|
||||
// Unhover to trigger hide timeout
|
||||
await user.unhover(userActions);
|
||||
|
||||
// Wait for the 500ms hide delay to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 550));
|
||||
|
||||
// Now hover again to show the menu
|
||||
await user.hover(userActions);
|
||||
|
||||
// Wait for fresh component with org data
|
||||
@@ -454,4 +488,83 @@ describe("UserActions", () => {
|
||||
expect(screen.queryAllByRole("option")).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("menu close delay", () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
// Mock config to return SaaS mode so useShouldShowUserFeatures returns true
|
||||
server.use(
|
||||
http.get("/api/v1/web-client/config", () =>
|
||||
HttpResponse.json(createMockWebClientConfig({ app_mode: "saas" })),
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
server.resetHandlers();
|
||||
});
|
||||
|
||||
it("should keep menu visible when mouse leaves and re-enters within 500ms", async () => {
|
||||
// Arrange - render and wait for queries to settle
|
||||
renderUserActionsForMenuCloseDelay();
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
|
||||
// Act - open menu
|
||||
await act(async () => {
|
||||
fireEvent.mouseEnter(userActions);
|
||||
});
|
||||
|
||||
// Assert - menu is visible
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
|
||||
// Act - leave and re-enter within 500ms
|
||||
await act(async () => {
|
||||
fireEvent.mouseLeave(userActions);
|
||||
await vi.advanceTimersByTimeAsync(200);
|
||||
fireEvent.mouseEnter(userActions);
|
||||
});
|
||||
|
||||
// Assert - menu should still be visible after waiting (pending close was cancelled)
|
||||
await act(async () => {
|
||||
await vi.advanceTimersByTimeAsync(500);
|
||||
});
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not close menu before 500ms delay when mouse leaves", async () => {
|
||||
// Arrange - render and wait for queries to settle
|
||||
renderUserActionsForMenuCloseDelay();
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
|
||||
// Act - open menu
|
||||
await act(async () => {
|
||||
fireEvent.mouseEnter(userActions);
|
||||
});
|
||||
|
||||
// Assert - menu is visible
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
|
||||
// Act - leave without re-entering, but check before timeout expires
|
||||
await act(async () => {
|
||||
fireEvent.mouseLeave(userActions);
|
||||
await vi.advanceTimersByTimeAsync(400); // Before the 500ms delay
|
||||
});
|
||||
|
||||
// Assert - menu should still be visible (delay hasn't expired yet)
|
||||
// Note: The menu is always in DOM but with opacity-0 when closed.
|
||||
// This test verifies the state hasn't changed yet (delay is working).
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
WsClientProvider,
|
||||
useWsClient,
|
||||
} from "#/context/ws-client-provider";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
describe("Propagate error message", () => {
|
||||
it("should do nothing when no message was passed from server", () => {
|
||||
@@ -56,6 +57,7 @@ function TestComponent() {
|
||||
describe("WsClientProvider", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
useActiveConversation: () => {
|
||||
return { data: {
|
||||
|
||||
@@ -40,6 +40,7 @@ import {
|
||||
import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
import { isV1Event } from "#/types/v1/type-guards";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock useUserConversation to return V1 conversation data
|
||||
vi.mock("#/hooks/query/use-user-conversation", () => ({
|
||||
@@ -62,6 +63,10 @@ beforeAll(() => {
|
||||
mswServer.listen({ onUnhandledRequest: "bypass" });
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
mswServer.resetHandlers();
|
||||
// Clean up any React components
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command";
|
||||
|
||||
const mockNavigate = vi.fn();
|
||||
|
||||
vi.mock("react-router", () => ({
|
||||
useNavigate: () => mockNavigate,
|
||||
useParams: () => ({ conversationId: "conv-123" }),
|
||||
}));
|
||||
|
||||
vi.mock("react-i18next", () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}));
|
||||
|
||||
const { mockToast } = vi.hoisted(() => {
|
||||
const mockToast = Object.assign(vi.fn(), {
|
||||
loading: vi.fn(),
|
||||
dismiss: vi.fn(),
|
||||
});
|
||||
return { mockToast };
|
||||
});
|
||||
|
||||
vi.mock("react-hot-toast", () => ({
|
||||
default: mockToast,
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displaySuccessToast: vi.fn(),
|
||||
displayErrorToast: vi.fn(),
|
||||
TOAST_OPTIONS: { position: "top-right" },
|
||||
}));
|
||||
|
||||
const mockConversation = {
|
||||
conversation_id: "conv-123",
|
||||
sandbox_id: "sandbox-456",
|
||||
title: "Test Conversation",
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
last_updated_at: new Date().toISOString(),
|
||||
created_at: new Date().toISOString(),
|
||||
status: "RUNNING" as const,
|
||||
runtime_status: null,
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
conversation_version: "V1" as const,
|
||||
};
|
||||
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
useActiveConversation: () => ({
|
||||
data: mockConversation,
|
||||
}),
|
||||
}));
|
||||
|
||||
function makeStartTask(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
id: "task-789",
|
||||
created_by_user_id: null,
|
||||
status: "READY",
|
||||
detail: null,
|
||||
app_conversation_id: "new-conv-999",
|
||||
sandbox_id: "sandbox-456",
|
||||
agent_server_url: "http://agent-server.local",
|
||||
request: {
|
||||
sandbox_id: null,
|
||||
initial_message: null,
|
||||
processors: [],
|
||||
llm_model: null,
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
suggested_task: null,
|
||||
title: null,
|
||||
trigger: null,
|
||||
pr_number: [],
|
||||
parent_conversation_id: null,
|
||||
agent_type: "default",
|
||||
},
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("useNewConversationCommand", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: { mutations: { retry: false } },
|
||||
});
|
||||
// Mock batchGetAppConversations to return V1 data with llm_model
|
||||
vi.spyOn(
|
||||
V1ConversationService,
|
||||
"batchGetAppConversations",
|
||||
).mockResolvedValue([
|
||||
{
|
||||
id: "conv-123",
|
||||
title: "Test Conversation",
|
||||
sandbox_id: "sandbox-456",
|
||||
sandbox_status: "RUNNING",
|
||||
execution_status: "IDLE",
|
||||
conversation_url: null,
|
||||
session_api_key: null,
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
trigger: null,
|
||||
pr_number: [],
|
||||
llm_model: "gpt-4o",
|
||||
metrics: null,
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
sub_conversation_ids: [],
|
||||
public: false,
|
||||
} as never,
|
||||
]);
|
||||
});
|
||||
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
|
||||
it("calls createConversation with sandbox_id and navigates on success", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
const createSpy = vi
|
||||
.spyOn(V1ConversationService, "createConversation")
|
||||
.mockResolvedValue(readyTask as never);
|
||||
const getStartTaskSpy = vi
|
||||
.spyOn(V1ConversationService, "getStartTask")
|
||||
.mockResolvedValue(readyTask as never);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(createSpy).toHaveBeenCalledWith(
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
"sandbox-456",
|
||||
"gpt-4o",
|
||||
);
|
||||
expect(getStartTaskSpy).toHaveBeenCalledWith("task-789");
|
||||
expect(mockNavigate).toHaveBeenCalledWith(
|
||||
"/conversations/new-conv-999",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("polls getStartTask until status is READY", async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true });
|
||||
|
||||
const workingTask = makeStartTask({
|
||||
status: "WORKING",
|
||||
app_conversation_id: null,
|
||||
});
|
||||
const readyTask = makeStartTask({ status: "READY" });
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
workingTask as never,
|
||||
);
|
||||
const getStartTaskSpy = vi
|
||||
.spyOn(V1ConversationService, "getStartTask")
|
||||
.mockResolvedValueOnce(workingTask as never)
|
||||
.mockResolvedValueOnce(readyTask as never);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
const mutatePromise = result.current.mutateAsync();
|
||||
|
||||
await vi.advanceTimersByTimeAsync(2000);
|
||||
await mutatePromise;
|
||||
|
||||
await waitFor(() => {
|
||||
expect(getStartTaskSpy).toHaveBeenCalledTimes(2);
|
||||
expect(mockNavigate).toHaveBeenCalledWith(
|
||||
"/conversations/new-conv-999",
|
||||
);
|
||||
});
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it("throws when task status is ERROR", async () => {
|
||||
const errorTask = makeStartTask({
|
||||
status: "ERROR",
|
||||
detail: "Sandbox crashed",
|
||||
app_conversation_id: null,
|
||||
});
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
errorTask as never,
|
||||
);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
errorTask as never,
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await expect(result.current.mutateAsync()).rejects.toThrow(
|
||||
"Sandbox crashed",
|
||||
);
|
||||
});
|
||||
|
||||
it("invalidates conversation list queries on success", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
|
||||
const invalidateSpy = vi.spyOn(queryClient, "invalidateQueries");
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(invalidateSpy).toHaveBeenCalledWith({
|
||||
queryKey: ["user", "conversations"],
|
||||
});
|
||||
expect(invalidateSpy).toHaveBeenCalledWith({
|
||||
queryKey: ["v1-batch-get-app-conversations"],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("creates a standalone conversation (not a sub-conversation) so it appears in the list", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
const createSpy = vi
|
||||
.spyOn(V1ConversationService, "createConversation")
|
||||
.mockResolvedValue(readyTask as never);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
// parent_conversation_id should be undefined so the new conversation
|
||||
// is NOT a sub-conversation and will appear in the conversation list.
|
||||
expect(createSpy).toHaveBeenCalledWith(
|
||||
undefined, // selectedRepository (null from mock)
|
||||
undefined, // git_provider (null from mock)
|
||||
undefined, // initialUserMsg
|
||||
undefined, // selected_branch (null from mock)
|
||||
undefined, // conversationInstructions
|
||||
undefined, // suggestedTask
|
||||
undefined, // trigger
|
||||
undefined, // parent_conversation_id is NOT set
|
||||
undefined, // agent_type
|
||||
"sandbox-456", // sandbox_id IS set to reuse the sandbox
|
||||
"gpt-4o", // llm_model IS inherited from the original conversation
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("shows a loading toast immediately and dismisses it on success", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToast.loading).toHaveBeenCalledWith(
|
||||
"CONVERSATION$CLEARING",
|
||||
expect.objectContaining({ id: "clear-conversation" }),
|
||||
);
|
||||
expect(mockToast.dismiss).toHaveBeenCalledWith("clear-conversation");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,10 +1,15 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
describe("useSaveSettings", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
it("should send an empty string for llm_api_key if an empty string is passed, otherwise undefined", async () => {
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
const { result } = renderHook(() => useSaveSettings(), {
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import React from "react";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { useGetSecrets } from "#/hooks/query/use-get-secrets";
|
||||
import { useApiKeys } from "#/hooks/query/use-api-keys";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { SecretsService } from "#/api/secrets-service";
|
||||
import ApiKeysClient from "#/api/api-keys";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({
|
||||
data: { app_mode: "saas" },
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => ({
|
||||
data: true,
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-is-on-intermediate-page", () => ({
|
||||
useIsOnIntermediatePage: () => false,
|
||||
}));
|
||||
|
||||
describe("Organization-scoped query hooks", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
const createWrapper = () => {
|
||||
return ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("useSettings", () => {
|
||||
it("should include organizationId in query key for proper cache isolation", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
|
||||
|
||||
const { result } = renderHook(() => useSettings(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
|
||||
// Verify the query was cached with the org-specific key
|
||||
const cachedData = queryClient.getQueryData(["settings", "org-1"]);
|
||||
expect(cachedData).toBeDefined();
|
||||
|
||||
// Verify no data is cached under the old key without org ID
|
||||
const oldKeyData = queryClient.getQueryData(["settings"]);
|
||||
expect(oldKeyData).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should refetch when organization changes", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "en",
|
||||
});
|
||||
|
||||
// First render with org-1
|
||||
const { result, rerender } = renderHook(() => useSettings(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
expect(getSettingsSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Change organization
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-2" });
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "es",
|
||||
});
|
||||
|
||||
// Rerender to pick up the new org ID
|
||||
rerender();
|
||||
|
||||
await waitFor(() => {
|
||||
// Should have fetched again for the new org
|
||||
expect(getSettingsSpy).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
// Verify both org caches exist independently
|
||||
const org1Data = queryClient.getQueryData(["settings", "org-1"]);
|
||||
const org2Data = queryClient.getQueryData(["settings", "org-2"]);
|
||||
expect(org1Data).toBeDefined();
|
||||
expect(org2Data).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("useGetSecrets", () => {
|
||||
it("should include organizationId in query key for proper cache isolation", async () => {
|
||||
const getSecretsSpy = vi.spyOn(SecretsService, "getSecrets");
|
||||
getSecretsSpy.mockResolvedValue([]);
|
||||
|
||||
const { result } = renderHook(() => useGetSecrets(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
|
||||
// Verify the query was cached with the org-specific key
|
||||
const cachedData = queryClient.getQueryData(["secrets", "org-1"]);
|
||||
expect(cachedData).toBeDefined();
|
||||
|
||||
// Verify no data is cached under the old key without org ID
|
||||
const oldKeyData = queryClient.getQueryData(["secrets"]);
|
||||
expect(oldKeyData).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should fetch different data when organization changes", async () => {
|
||||
const getSecretsSpy = vi.spyOn(SecretsService, "getSecrets");
|
||||
|
||||
// Mock different secrets for different orgs
|
||||
getSecretsSpy.mockResolvedValueOnce([
|
||||
{ name: "SECRET_ORG_1", description: "Org 1 secret" },
|
||||
]);
|
||||
|
||||
const { result, rerender } = renderHook(() => useGetSecrets(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
expect(result.current.data).toHaveLength(1);
|
||||
expect(result.current.data?.[0].name).toBe("SECRET_ORG_1");
|
||||
|
||||
// Change organization
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-2" });
|
||||
getSecretsSpy.mockResolvedValueOnce([
|
||||
{ name: "SECRET_ORG_2", description: "Org 2 secret" },
|
||||
]);
|
||||
|
||||
rerender();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data?.[0]?.name).toBe("SECRET_ORG_2");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("useApiKeys", () => {
|
||||
it("should include organizationId in query key for proper cache isolation", async () => {
|
||||
const getApiKeysSpy = vi.spyOn(ApiKeysClient, "getApiKeys");
|
||||
getApiKeysSpy.mockResolvedValue([]);
|
||||
|
||||
const { result } = renderHook(() => useApiKeys(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
|
||||
// Verify the query was cached with the org-specific key
|
||||
const cachedData = queryClient.getQueryData(["api-keys", "org-1"]);
|
||||
expect(cachedData).toBeDefined();
|
||||
|
||||
// Verify no data is cached under the old key without org ID
|
||||
const oldKeyData = queryClient.getQueryData(["api-keys"]);
|
||||
expect(oldKeyData).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cache isolation between organizations", () => {
|
||||
it("should maintain separate caches for each organization", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
|
||||
// Simulate fetching for org-1
|
||||
getSettingsSpy.mockResolvedValueOnce({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "en",
|
||||
});
|
||||
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
const { rerender } = renderHook(() => useSettings(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(queryClient.getQueryData(["settings", "org-1"])).toBeDefined();
|
||||
});
|
||||
|
||||
// Switch to org-2
|
||||
getSettingsSpy.mockResolvedValueOnce({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "fr",
|
||||
});
|
||||
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-2" });
|
||||
rerender();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(queryClient.getQueryData(["settings", "org-2"])).toBeDefined();
|
||||
});
|
||||
|
||||
// Switch back to org-1 - should use cached data, not refetch
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
rerender();
|
||||
|
||||
// org-1 data should still be in cache
|
||||
const org1Cache = queryClient.getQueryData(["settings", "org-1"]) as any;
|
||||
expect(org1Cache?.language).toBe("en");
|
||||
|
||||
// org-2 data should also still be in cache
|
||||
const org2Cache = queryClient.getQueryData(["settings", "org-2"]) as any;
|
||||
expect(org2Cache?.language).toBe("fr");
|
||||
});
|
||||
});
|
||||
});
|
||||
64
frontend/__tests__/hooks/use-runtime-is-ready.test.tsx
Normal file
64
frontend/__tests__/hooks/use-runtime-is-ready.test.tsx
Normal file
@@ -0,0 +1,64 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { Conversation } from "#/api/open-hands.types";
|
||||
import { useRuntimeIsReady } from "#/hooks/use-runtime-is-ready";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
|
||||
vi.mock("#/hooks/use-agent-state");
|
||||
vi.mock("#/hooks/query/use-active-conversation");
|
||||
|
||||
function asMockReturnValue<T>(value: Partial<T>): T {
|
||||
return value as T;
|
||||
}
|
||||
|
||||
function makeConversation(): Conversation {
|
||||
return {
|
||||
conversation_id: "conv-123",
|
||||
title: "Test Conversation",
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
last_updated_at: new Date().toISOString(),
|
||||
created_at: new Date().toISOString(),
|
||||
status: "RUNNING",
|
||||
runtime_status: null,
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
};
|
||||
}
|
||||
|
||||
describe("useRuntimeIsReady", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
vi.mocked(useActiveConversation).mockReturnValue(
|
||||
asMockReturnValue<ReturnType<typeof useActiveConversation>>({
|
||||
data: makeConversation(),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("treats agent errors as not ready by default", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.ERROR,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useRuntimeIsReady());
|
||||
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
|
||||
it("allows runtime-backed tabs to stay ready when the agent errors", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.ERROR,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useRuntimeIsReady({ allowAgentError: true }),
|
||||
);
|
||||
|
||||
expect(result.current).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -205,7 +205,9 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.isConnected).toBe(true);
|
||||
});
|
||||
|
||||
expect(onCloseSpy).not.toHaveBeenCalled();
|
||||
// Reset spy after connection is established to ignore any spurious
|
||||
// close events fired by the MSW mock during the handshake.
|
||||
onCloseSpy.mockClear();
|
||||
|
||||
// Unmount to trigger close
|
||||
unmount();
|
||||
|
||||
@@ -5,9 +5,11 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import AcceptTOS from "#/routes/accept-tos";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
import { openHands } from "#/api/open-hands-axios";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock the react-router hooks
|
||||
vi.mock("react-router", () => ({
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useNavigate: () => vi.fn(),
|
||||
useSearchParams: () => [
|
||||
{
|
||||
@@ -19,6 +21,7 @@ vi.mock("react-router", () => ({
|
||||
},
|
||||
},
|
||||
],
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
// Mock the axios instance
|
||||
@@ -54,6 +57,7 @@ const createWrapper = () => {
|
||||
|
||||
describe("AcceptTOS", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
vi.stubGlobal("location", { href: "" });
|
||||
});
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import AppSettingsScreen, { clientLoader } from "#/routes/app-settings";
|
||||
@@ -8,6 +8,11 @@ import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { AvailableLanguages } from "#/i18n";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const renderAppSettingsScreen = () =>
|
||||
render(<AppSettingsScreen />, {
|
||||
|
||||
@@ -32,6 +32,7 @@ describe("Changes Tab", () => {
|
||||
vi.mocked(useUnifiedGetGitChanges).mockReturnValue({
|
||||
data: [],
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isSuccess: true,
|
||||
isError: false,
|
||||
error: null,
|
||||
@@ -50,6 +51,7 @@ describe("Changes Tab", () => {
|
||||
vi.mocked(useUnifiedGetGitChanges).mockReturnValue({
|
||||
data: [{ path: "src/file.ts", status: "M" }],
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isSuccess: true,
|
||||
isError: false,
|
||||
error: null,
|
||||
|
||||
@@ -283,305 +283,6 @@ describe("Manage Org Route", () => {
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe("AddCreditsModal", () => {
|
||||
const openAddCreditsModal = async () => {
|
||||
const user = userEvent.setup();
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 }); // user is owner in org 1
|
||||
|
||||
const addCreditsButton = await waitFor(() => screen.getByText(/add/i));
|
||||
await user.click(addCreditsButton);
|
||||
|
||||
const addCreditsForm = screen.getByTestId("add-credits-form");
|
||||
expect(addCreditsForm).toBeInTheDocument();
|
||||
|
||||
return { user, addCreditsForm };
|
||||
};
|
||||
|
||||
describe("Button State Management", () => {
|
||||
it("should enable submit button initially when modal opens", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains invalid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains valid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "100");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button after validation error is shown", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Input Attributes & Placeholder", () => {
|
||||
it("should have min attribute set to 10", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("min", "10");
|
||||
});
|
||||
|
||||
it("should have max attribute set to 25000", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("max", "25000");
|
||||
});
|
||||
|
||||
it("should have step attribute set to 1", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("step", "1");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Message Display", () => {
|
||||
it("should not display error message initially when modal opens", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display error message after submitting amount above maximum", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MAXIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting decimal value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "50.5");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should replace error message when submitting different invalid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MAXIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Submission Behavior", () => {
|
||||
it("should prevent submission when amount is invalid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should call createCheckoutSession with correct amount when valid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not call createCheckoutSession when validation fails", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Verify mutation was not called
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_NEGATIVE_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should close modal on successful submission", async () => {
|
||||
const createCheckoutSessionSpy = vi
|
||||
.spyOn(BillingService, "createCheckoutSession")
|
||||
.mockResolvedValue("https://checkout.stripe.com/test-session");
|
||||
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("add-credits-form"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow API call when validation passes and clear any previous errors", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
// First submit invalid value
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Then submit valid value
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "100");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(100);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edge Cases", () => {
|
||||
it("should handle zero value correctly", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "0");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle whitespace-only input correctly", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
// Number inputs typically don't accept spaces, but test the behavior
|
||||
await user.type(amountInput, " ");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Should not call API (empty/invalid input)
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should show add credits option for ADMIN role", async () => {
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
65
frontend/__tests__/routes/vscode-tab.test.tsx
Normal file
65
frontend/__tests__/routes/vscode-tab.test.tsx
Normal file
@@ -0,0 +1,65 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import VSCodeTab from "#/routes/vscode-tab";
|
||||
import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
|
||||
vi.mock("#/hooks/query/use-unified-vscode-url");
|
||||
vi.mock("#/hooks/use-agent-state");
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
VSCODE_IN_NEW_TAB: () => false,
|
||||
}));
|
||||
|
||||
function mockVSCodeUrlHook(
|
||||
value: Partial<ReturnType<typeof useUnifiedVSCodeUrl>>,
|
||||
) {
|
||||
vi.mocked(useUnifiedVSCodeUrl).mockReturnValue({
|
||||
data: { url: "http://localhost:3000/vscode", error: null },
|
||||
error: null,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
isSuccess: true,
|
||||
status: "success",
|
||||
refetch: vi.fn(),
|
||||
...value,
|
||||
} as ReturnType<typeof useUnifiedVSCodeUrl>);
|
||||
}
|
||||
|
||||
describe("VSCodeTab", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("keeps VSCode accessible when the agent is in an error state", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.ERROR,
|
||||
});
|
||||
mockVSCodeUrlHook({});
|
||||
|
||||
renderWithProviders(<VSCodeTab />);
|
||||
|
||||
expect(
|
||||
screen.queryByText("DIFF_VIEWER$WAITING_FOR_RUNTIME"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.getByTitle("VSCODE$TITLE")).toHaveAttribute(
|
||||
"src",
|
||||
"http://localhost:3000/vscode",
|
||||
);
|
||||
});
|
||||
|
||||
it("still waits while the runtime is starting", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.LOADING,
|
||||
});
|
||||
mockVSCodeUrlHook({});
|
||||
|
||||
renderWithProviders(<VSCodeTab />);
|
||||
|
||||
expect(
|
||||
screen.getByText("DIFF_VIEWER$WAITING_FOR_RUNTIME"),
|
||||
).toBeInTheDocument();
|
||||
expect(screen.queryByTitle("VSCODE$TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,8 +1,13 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { test, expect, describe, vi } from "vitest";
|
||||
import { test, expect, describe, vi, beforeEach } from "vitest";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-box";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
// Mock the translation function
|
||||
vi.mock("react-i18next", async () => {
|
||||
@@ -29,14 +34,12 @@ vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
}));
|
||||
|
||||
// Mock React Router hooks
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
};
|
||||
});
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
// Mock other hooks that might be used by the component
|
||||
vi.mock("#/hooks/use-user-providers", () => ({
|
||||
|
||||
@@ -4,27 +4,96 @@ import { getGitPath } from "#/utils/get-git-path";
|
||||
describe("getGitPath", () => {
|
||||
const conversationId = "abc123";
|
||||
|
||||
it("should return /workspace/project/{conversationId} when no repository is selected", () => {
|
||||
expect(getGitPath(conversationId, null)).toBe(`/workspace/project/${conversationId}`);
|
||||
expect(getGitPath(conversationId, undefined)).toBe(`/workspace/project/${conversationId}`);
|
||||
describe("without sandbox grouping (NO_GROUPING)", () => {
|
||||
it("should return /workspace/project when no repository is selected", () => {
|
||||
expect(getGitPath(conversationId, null, false)).toBe("/workspace/project");
|
||||
expect(getGitPath(conversationId, undefined, false)).toBe(
|
||||
"/workspace/project",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle standard owner/repo format (GitHub)", () => {
|
||||
expect(getGitPath(conversationId, "OpenHands/OpenHands", false)).toBe(
|
||||
"/workspace/project/OpenHands",
|
||||
);
|
||||
expect(getGitPath(conversationId, "facebook/react", false)).toBe(
|
||||
"/workspace/project/react",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle nested group paths (GitLab)", () => {
|
||||
expect(
|
||||
getGitPath(conversationId, "modernhealth/frontend-guild/pan", false),
|
||||
).toBe("/workspace/project/pan");
|
||||
expect(getGitPath(conversationId, "group/subgroup/repo", false)).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
expect(getGitPath(conversationId, "a/b/c/d/repo", false)).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle single segment paths", () => {
|
||||
expect(getGitPath(conversationId, "repo", false)).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle empty string", () => {
|
||||
expect(getGitPath(conversationId, "", false)).toBe("/workspace/project");
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle standard owner/repo format (GitHub)", () => {
|
||||
expect(getGitPath(conversationId, "OpenHands/OpenHands")).toBe(`/workspace/project/${conversationId}/OpenHands`);
|
||||
expect(getGitPath(conversationId, "facebook/react")).toBe(`/workspace/project/${conversationId}/react`);
|
||||
describe("with sandbox grouping enabled", () => {
|
||||
it("should return /workspace/project/{conversationId} when no repository is selected", () => {
|
||||
expect(getGitPath(conversationId, null, true)).toBe(
|
||||
`/workspace/project/${conversationId}`,
|
||||
);
|
||||
expect(getGitPath(conversationId, undefined, true)).toBe(
|
||||
`/workspace/project/${conversationId}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle standard owner/repo format (GitHub)", () => {
|
||||
expect(getGitPath(conversationId, "OpenHands/OpenHands", true)).toBe(
|
||||
`/workspace/project/${conversationId}/OpenHands`,
|
||||
);
|
||||
expect(getGitPath(conversationId, "facebook/react", true)).toBe(
|
||||
`/workspace/project/${conversationId}/react`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle nested group paths (GitLab)", () => {
|
||||
expect(
|
||||
getGitPath(conversationId, "modernhealth/frontend-guild/pan", true),
|
||||
).toBe(`/workspace/project/${conversationId}/pan`);
|
||||
expect(getGitPath(conversationId, "group/subgroup/repo", true)).toBe(
|
||||
`/workspace/project/${conversationId}/repo`,
|
||||
);
|
||||
expect(getGitPath(conversationId, "a/b/c/d/repo", true)).toBe(
|
||||
`/workspace/project/${conversationId}/repo`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle single segment paths", () => {
|
||||
expect(getGitPath(conversationId, "repo", true)).toBe(
|
||||
`/workspace/project/${conversationId}/repo`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle empty string", () => {
|
||||
expect(getGitPath(conversationId, "", true)).toBe(
|
||||
`/workspace/project/${conversationId}`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle nested group paths (GitLab)", () => {
|
||||
expect(getGitPath(conversationId, "modernhealth/frontend-guild/pan")).toBe(`/workspace/project/${conversationId}/pan`);
|
||||
expect(getGitPath(conversationId, "group/subgroup/repo")).toBe(`/workspace/project/${conversationId}/repo`);
|
||||
expect(getGitPath(conversationId, "a/b/c/d/repo")).toBe(`/workspace/project/${conversationId}/repo`);
|
||||
});
|
||||
|
||||
it("should handle single segment paths", () => {
|
||||
expect(getGitPath(conversationId, "repo")).toBe(`/workspace/project/${conversationId}/repo`);
|
||||
});
|
||||
|
||||
it("should handle empty string", () => {
|
||||
expect(getGitPath(conversationId, "")).toBe(`/workspace/project/${conversationId}`);
|
||||
describe("default behavior (useSandboxGrouping defaults to false)", () => {
|
||||
it("should default to no sandbox grouping", () => {
|
||||
expect(getGitPath(conversationId, null)).toBe("/workspace/project");
|
||||
expect(getGitPath(conversationId, "owner/repo")).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,7 +11,7 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
// Mock translations
|
||||
const t = (key: string) => {
|
||||
const translations: { [key: string]: string } = {
|
||||
COMMON$WAITING_FOR_SANDBOX: "Waiting For Sandbox",
|
||||
COMMON$WAITING_FOR_SANDBOX: "Waiting for sandbox",
|
||||
COMMON$STOPPING: "Stopping",
|
||||
COMMON$STARTING: "Starting",
|
||||
COMMON$SERVER_STOPPED: "Server stopped",
|
||||
@@ -69,7 +69,7 @@ describe("getStatusText", () => {
|
||||
t,
|
||||
});
|
||||
|
||||
expect(result).toBe(t(I18nKey.COMMON$WAITING_FOR_SANDBOX));
|
||||
expect(result).toBe("Waiting for sandbox");
|
||||
});
|
||||
|
||||
it("returns task detail when task status is ERROR and detail exists", () => {
|
||||
|
||||
7
frontend/package-lock.json
generated
7
frontend/package-lock.json
generated
@@ -15325,10 +15325,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/socket.io-parser": {
|
||||
"version": "4.2.5",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.5.tgz",
|
||||
"integrity": "sha512-bPMmpy/5WWKHea5Y/jYAP6k74A+hvmRCQaJuJB6I/ML5JZq/KfNieUVo/3Mh7SAqn7TyFdIo6wqYHInG1MU1bQ==",
|
||||
"license": "MIT",
|
||||
"version": "4.2.6",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.6.tgz",
|
||||
"integrity": "sha512-asJqbVBDsBCJx0pTqw3WfesSY0iRX+2xzWEWzrpcH7L6fLzrhyF8WPI8UaeM4YCuDfpwA/cgsdugMsmtz8EJeg==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.4.1"
|
||||
|
||||
@@ -68,6 +68,8 @@ class V1ConversationService {
|
||||
trigger?: ConversationTrigger,
|
||||
parent_conversation_id?: string,
|
||||
agent_type?: "default" | "plan",
|
||||
sandbox_id?: string,
|
||||
llm_model?: string,
|
||||
): Promise<V1AppConversationStartTask> {
|
||||
const body: V1AppConversationStartRequest = {
|
||||
selected_repository: selectedRepository,
|
||||
@@ -78,6 +80,8 @@ class V1ConversationService {
|
||||
trigger,
|
||||
parent_conversation_id: parent_conversation_id || null,
|
||||
agent_type,
|
||||
sandbox_id: sandbox_id || null,
|
||||
llm_model: llm_model || null,
|
||||
};
|
||||
|
||||
// suggested_task implies the backend will construct the initial_message
|
||||
|
||||
@@ -38,6 +38,8 @@ import { useTaskPolling } from "#/hooks/query/use-task-polling";
|
||||
import { useConversationWebSocket } from "#/contexts/conversation-websocket-context";
|
||||
import ChatStatusIndicator from "./chat-status-indicator";
|
||||
import { getStatusColor, getStatusText } from "#/utils/utils";
|
||||
import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
|
||||
function getEntryPoint(
|
||||
hasRepository: boolean | null,
|
||||
@@ -80,6 +82,10 @@ export function ChatInterface() {
|
||||
setHitBottom,
|
||||
} = useScrollToBottom(scrollRef);
|
||||
const { data: config } = useConfig();
|
||||
const {
|
||||
mutate: newConversationCommand,
|
||||
isPending: isNewConversationPending,
|
||||
} = useNewConversationCommand();
|
||||
|
||||
const { curAgentState } = useAgentState();
|
||||
const { handleBuildPlanClick } = useHandleBuildPlanClick();
|
||||
@@ -146,6 +152,27 @@ export function ChatInterface() {
|
||||
originalImages: File[],
|
||||
originalFiles: File[],
|
||||
) => {
|
||||
// Handle /new command for V1 conversations
|
||||
if (content.trim() === "/new") {
|
||||
if (!isV1Conversation) {
|
||||
displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_V1_ONLY));
|
||||
return;
|
||||
}
|
||||
if (!params.conversationId) {
|
||||
displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_NO_ID));
|
||||
return;
|
||||
}
|
||||
if (totalEvents === 0) {
|
||||
displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_EMPTY));
|
||||
return;
|
||||
}
|
||||
if (isNewConversationPending) {
|
||||
return;
|
||||
}
|
||||
newConversationCommand();
|
||||
return;
|
||||
}
|
||||
|
||||
// Create mutable copies of the arrays
|
||||
const images = [...originalImages];
|
||||
const files = [...originalFiles];
|
||||
@@ -338,7 +365,10 @@ export function ChatInterface() {
|
||||
/>
|
||||
)}
|
||||
|
||||
<InteractiveChatBox onSubmit={handleSendMessage} />
|
||||
<InteractiveChatBox
|
||||
onSubmit={handleSendMessage}
|
||||
disabled={isNewConversationPending}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{config?.app_mode !== "saas" && !isV1Conversation && (
|
||||
|
||||
@@ -12,6 +12,7 @@ interface ChatInputContainerProps {
|
||||
chatContainerRef: React.RefObject<HTMLDivElement | null>;
|
||||
isDragOver: boolean;
|
||||
disabled: boolean;
|
||||
isNewConversationPending?: boolean;
|
||||
showButton: boolean;
|
||||
buttonClassName: string;
|
||||
chatInputRef: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -36,6 +37,7 @@ export function ChatInputContainer({
|
||||
chatContainerRef,
|
||||
isDragOver,
|
||||
disabled,
|
||||
isNewConversationPending = false,
|
||||
showButton,
|
||||
buttonClassName,
|
||||
chatInputRef,
|
||||
@@ -89,6 +91,7 @@ export function ChatInputContainer({
|
||||
<ChatInputRow
|
||||
chatInputRef={chatInputRef}
|
||||
disabled={disabled}
|
||||
isNewConversationPending={isNewConversationPending}
|
||||
showButton={showButton}
|
||||
buttonClassName={buttonClassName}
|
||||
handleFileIconClick={handleFileIconClick}
|
||||
|
||||
@@ -2,9 +2,11 @@ import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
interface ChatInputFieldProps {
|
||||
chatInputRef: React.RefObject<HTMLDivElement | null>;
|
||||
disabled?: boolean;
|
||||
onInput: () => void;
|
||||
onPaste: (e: React.ClipboardEvent) => void;
|
||||
onKeyDown: (e: React.KeyboardEvent) => void;
|
||||
@@ -14,6 +16,7 @@ interface ChatInputFieldProps {
|
||||
|
||||
export function ChatInputField({
|
||||
chatInputRef,
|
||||
disabled = false,
|
||||
onInput,
|
||||
onPaste,
|
||||
onKeyDown,
|
||||
@@ -36,8 +39,11 @@ export function ChatInputField({
|
||||
<div className="basis-0 flex flex-col font-normal grow justify-center leading-[0] min-h-px min-w-px overflow-ellipsis overflow-hidden relative shrink-0 text-[#d0d9fa] text-[16px] text-left">
|
||||
<div
|
||||
ref={chatInputRef}
|
||||
className="chat-input bg-transparent text-white text-[16px] font-normal leading-[20px] outline-none resize-none custom-scrollbar min-h-[20px] max-h-[400px] [text-overflow:inherit] [text-wrap-mode:inherit] [white-space-collapse:inherit] block whitespace-pre-wrap"
|
||||
contentEditable
|
||||
className={cn(
|
||||
"chat-input bg-transparent text-white text-[16px] font-normal leading-[20px] outline-none resize-none custom-scrollbar min-h-[20px] max-h-[400px] [text-overflow:inherit] [text-wrap-mode:inherit] [white-space-collapse:inherit] block whitespace-pre-wrap",
|
||||
disabled && "cursor-not-allowed opacity-50",
|
||||
)}
|
||||
contentEditable={!disabled}
|
||||
data-placeholder={
|
||||
isPlanMode
|
||||
? t(I18nKey.COMMON$LET_S_WORK_ON_A_PLAN)
|
||||
|
||||
@@ -7,6 +7,7 @@ import { ChatInputField } from "./chat-input-field";
|
||||
interface ChatInputRowProps {
|
||||
chatInputRef: React.RefObject<HTMLDivElement | null>;
|
||||
disabled: boolean;
|
||||
isNewConversationPending?: boolean;
|
||||
showButton: boolean;
|
||||
buttonClassName: string;
|
||||
handleFileIconClick: (isDisabled: boolean) => void;
|
||||
@@ -21,6 +22,7 @@ interface ChatInputRowProps {
|
||||
export function ChatInputRow({
|
||||
chatInputRef,
|
||||
disabled,
|
||||
isNewConversationPending = false,
|
||||
showButton,
|
||||
buttonClassName,
|
||||
handleFileIconClick,
|
||||
@@ -41,6 +43,7 @@ export function ChatInputRow({
|
||||
|
||||
<ChatInputField
|
||||
chatInputRef={chatInputRef}
|
||||
disabled={isNewConversationPending}
|
||||
onInput={onInput}
|
||||
onPaste={onPaste}
|
||||
onKeyDown={onKeyDown}
|
||||
|
||||
@@ -13,6 +13,7 @@ import { useConversationStore } from "#/stores/conversation-store";
|
||||
|
||||
export interface CustomChatInputProps {
|
||||
disabled?: boolean;
|
||||
isNewConversationPending?: boolean;
|
||||
showButton?: boolean;
|
||||
conversationStatus?: ConversationStatus | null;
|
||||
onSubmit: (message: string) => void;
|
||||
@@ -25,6 +26,7 @@ export interface CustomChatInputProps {
|
||||
|
||||
export function CustomChatInput({
|
||||
disabled = false,
|
||||
isNewConversationPending = false,
|
||||
showButton = true,
|
||||
conversationStatus = null,
|
||||
onSubmit,
|
||||
@@ -147,6 +149,7 @@ export function CustomChatInput({
|
||||
chatContainerRef={chatContainerRef}
|
||||
isDragOver={isDragOver}
|
||||
disabled={isDisabled}
|
||||
isNewConversationPending={isNewConversationPending}
|
||||
showButton={showButton}
|
||||
buttonClassName={buttonClassName}
|
||||
chatInputRef={chatInputRef}
|
||||
|
||||
@@ -13,9 +13,13 @@ import { isTaskPolling } from "#/utils/utils";
|
||||
|
||||
interface InteractiveChatBoxProps {
|
||||
onSubmit: (message: string, images: File[], files: File[]) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
|
||||
export function InteractiveChatBox({
|
||||
onSubmit,
|
||||
disabled = false,
|
||||
}: InteractiveChatBoxProps) {
|
||||
const {
|
||||
images,
|
||||
files,
|
||||
@@ -145,6 +149,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
|
||||
// Allow users to submit messages during LOADING state - they will be
|
||||
// queued server-side and delivered when the conversation becomes ready
|
||||
const isDisabled =
|
||||
disabled ||
|
||||
curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
|
||||
isTaskPolling(subConversationTaskStatus);
|
||||
|
||||
@@ -152,6 +157,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
|
||||
<div data-testid="interactive-chat-box">
|
||||
<CustomChatInput
|
||||
disabled={isDisabled}
|
||||
isNewConversationPending={disabled}
|
||||
onSubmit={handleSubmit}
|
||||
onFilesPaste={handleUpload}
|
||||
conversationStatus={conversation?.status || null}
|
||||
|
||||
@@ -20,7 +20,7 @@ export function ConversationTabTitle({
|
||||
conversationKey,
|
||||
}: ConversationTabTitleProps) {
|
||||
const { t } = useTranslation();
|
||||
const { refetch } = useUnifiedGetGitChanges();
|
||||
const { refetch, isFetching } = useUnifiedGetGitChanges();
|
||||
const { handleBuildPlanClick } = useHandleBuildPlanClick();
|
||||
const { curAgentState } = useAgentState();
|
||||
const { planContent } = useConversationStore();
|
||||
@@ -41,10 +41,16 @@ export function ConversationTabTitle({
|
||||
{conversationKey === "editor" && (
|
||||
<button
|
||||
type="button"
|
||||
className="flex w-[26px] py-1 justify-center items-center gap-[10px] rounded-[7px] hover:bg-[#474A54] cursor-pointer"
|
||||
className="flex w-[26px] py-1 justify-center items-center gap-[10px] rounded-[7px] hover:enabled:bg-[#474A54] cursor-pointer disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
onClick={handleRefresh}
|
||||
disabled={isFetching}
|
||||
>
|
||||
<RefreshIcon width={12.75} height={15} color="#ffffff" />
|
||||
<RefreshIcon
|
||||
width={12.75}
|
||||
height={15}
|
||||
color="#ffffff"
|
||||
className={isFetching ? "animate-spin" : ""}
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
{conversationKey === "planner" && (
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import { FaExternalLinkAlt } from "react-icons/fa";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url";
|
||||
import { RUNTIME_STARTING_STATES } from "#/types/agent-state";
|
||||
|
||||
export function VSCodeTooltipContent() {
|
||||
const { curAgentState } = useAgentState();
|
||||
const { t } = useTranslation();
|
||||
const { data, refetch } = useUnifiedVSCodeUrl();
|
||||
const isRuntimeStarting = RUNTIME_STARTING_STATES.includes(curAgentState);
|
||||
|
||||
const handleVSCodeClick = async (e: React.MouseEvent) => {
|
||||
e.preventDefault();
|
||||
@@ -29,7 +30,7 @@ export function VSCodeTooltipContent() {
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<span>{t(I18nKey.COMMON$CODE)}</span>
|
||||
{!RUNTIME_INACTIVE_STATES.includes(curAgentState) ? (
|
||||
{!isRuntimeStarting ? (
|
||||
<FaExternalLinkAlt
|
||||
className="w-3 h-3 text-inherit cursor-pointer"
|
||||
onClick={handleVSCodeClick}
|
||||
|
||||
@@ -2,6 +2,7 @@ import React from "react";
|
||||
import { ExtraProps } from "react-markdown";
|
||||
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
|
||||
import { vscDarkPlus } from "react-syntax-highlighter/dist/esm/styles/prism";
|
||||
import { CopyableContentWrapper } from "#/components/shared/buttons/copyable-content-wrapper";
|
||||
|
||||
// See https://github.com/remarkjs/react-markdown?tab=readme-ov-file#use-custom-components-syntax-highlight
|
||||
|
||||
@@ -15,6 +16,7 @@ export function code({
|
||||
React.HTMLAttributes<HTMLElement> &
|
||||
ExtraProps) {
|
||||
const match = /language-(\w+)/.exec(className || ""); // get the language
|
||||
const codeString = String(children).replace(/\n$/, "");
|
||||
|
||||
if (!match) {
|
||||
const isMultiline = String(children).includes("\n");
|
||||
@@ -37,29 +39,33 @@ export function code({
|
||||
}
|
||||
|
||||
return (
|
||||
<pre
|
||||
style={{
|
||||
backgroundColor: "#2a3038",
|
||||
padding: "1em",
|
||||
borderRadius: "4px",
|
||||
color: "#e6edf3",
|
||||
border: "1px solid #30363d",
|
||||
overflow: "auto",
|
||||
}}
|
||||
>
|
||||
<code className={className}>{String(children).replace(/\n$/, "")}</code>
|
||||
</pre>
|
||||
<CopyableContentWrapper text={codeString}>
|
||||
<pre
|
||||
style={{
|
||||
backgroundColor: "#2a3038",
|
||||
padding: "1em",
|
||||
borderRadius: "4px",
|
||||
color: "#e6edf3",
|
||||
border: "1px solid #30363d",
|
||||
overflow: "auto",
|
||||
}}
|
||||
>
|
||||
<code className={className}>{codeString}</code>
|
||||
</pre>
|
||||
</CopyableContentWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<SyntaxHighlighter
|
||||
className="rounded-lg"
|
||||
style={vscDarkPlus}
|
||||
language={match?.[1]}
|
||||
PreTag="div"
|
||||
>
|
||||
{String(children).replace(/\n$/, "")}
|
||||
</SyntaxHighlighter>
|
||||
<CopyableContentWrapper text={codeString}>
|
||||
<SyntaxHighlighter
|
||||
className="rounded-lg"
|
||||
style={vscDarkPlus}
|
||||
language={match?.[1]}
|
||||
PreTag="div"
|
||||
>
|
||||
{codeString}
|
||||
</SyntaxHighlighter>
|
||||
</CopyableContentWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
103
frontend/src/components/features/org/add-credits-modal.tsx
Normal file
103
frontend/src/components/features/org/add-credits-modal.tsx
Normal file
@@ -0,0 +1,103 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useCreateStripeCheckoutSession } from "#/hooks/mutation/stripe/use-create-stripe-checkout-session";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModalButtonGroup } from "#/components/shared/modals/modal-button-group";
|
||||
import { SettingsInput } from "#/components/features/settings/settings-input";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { amountIsValid } from "#/utils/amount-is-valid";
|
||||
|
||||
interface AddCreditsModalProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function AddCreditsModal({ onClose }: AddCreditsModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { mutate: addBalance } = useCreateStripeCheckoutSession();
|
||||
|
||||
const [inputValue, setInputValue] = React.useState("");
|
||||
const [errorMessage, setErrorMessage] = React.useState<string | null>(null);
|
||||
|
||||
const getErrorMessage = (value: string): string | null => {
|
||||
if (!value.trim()) return null;
|
||||
|
||||
const numValue = parseInt(value, 10);
|
||||
if (Number.isNaN(numValue)) {
|
||||
return t(I18nKey.PAYMENT$ERROR_INVALID_NUMBER);
|
||||
}
|
||||
if (numValue < 0) {
|
||||
return t(I18nKey.PAYMENT$ERROR_NEGATIVE_AMOUNT);
|
||||
}
|
||||
if (numValue < 10) {
|
||||
return t(I18nKey.PAYMENT$ERROR_MINIMUM_AMOUNT);
|
||||
}
|
||||
if (numValue > 25000) {
|
||||
return t(I18nKey.PAYMENT$ERROR_MAXIMUM_AMOUNT);
|
||||
}
|
||||
if (numValue !== parseFloat(value)) {
|
||||
return t(I18nKey.PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const formAction = (formData: FormData) => {
|
||||
const amount = formData.get("amount")?.toString();
|
||||
|
||||
if (amount?.trim()) {
|
||||
if (!amountIsValid(amount)) {
|
||||
const error = getErrorMessage(amount);
|
||||
setErrorMessage(error || "Invalid amount");
|
||||
return;
|
||||
}
|
||||
|
||||
const intValue = parseInt(amount, 10);
|
||||
|
||||
addBalance({ amount: intValue }, { onSuccess: onClose });
|
||||
|
||||
setErrorMessage(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAmountInputChange = (value: string) => {
|
||||
setInputValue(value);
|
||||
setErrorMessage(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<ModalBackdrop onClose={onClose}>
|
||||
<form
|
||||
data-testid="add-credits-form"
|
||||
action={formAction}
|
||||
noValidate
|
||||
className="w-sm rounded-xl bg-base-secondary flex flex-col p-6 gap-4 border border-tertiary"
|
||||
>
|
||||
<h3 className="text-xl font-bold">{t(I18nKey.ORG$ADD_CREDITS)}</h3>
|
||||
<div className="flex flex-col gap-2">
|
||||
<SettingsInput
|
||||
testId="amount-input"
|
||||
name="amount"
|
||||
label={t(I18nKey.PAYMENT$SPECIFY_AMOUNT_USD)}
|
||||
type="number"
|
||||
min={10}
|
||||
max={25000}
|
||||
step={1}
|
||||
value={inputValue}
|
||||
onChange={(value) => handleAmountInputChange(value)}
|
||||
className="w-full"
|
||||
/>
|
||||
{errorMessage && (
|
||||
<p className="text-red-500 text-sm mt-1" data-testid="amount-error">
|
||||
{errorMessage}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<ModalButtonGroup
|
||||
primaryText={t(I18nKey.ORG$NEXT)}
|
||||
onSecondaryClick={onClose}
|
||||
primaryType="submit"
|
||||
/>
|
||||
</form>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import { BrandButton } from "../brand-button";
|
||||
import { useGetSecrets } from "#/hooks/query/use-get-secrets";
|
||||
import { GetSecretsResponse } from "#/api/secrets-service.types";
|
||||
import { OptionalTag } from "../optional-tag";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
interface SecretFormProps {
|
||||
mode: "add" | "edit";
|
||||
@@ -24,6 +25,7 @@ export function SecretForm({
|
||||
}: SecretFormProps) {
|
||||
const queryClient = useQueryClient();
|
||||
const { t } = useTranslation();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
const { data: secrets } = useGetSecrets();
|
||||
const { mutate: createSecret } = useCreateSecret();
|
||||
@@ -49,7 +51,9 @@ export function SecretForm({
|
||||
{
|
||||
onSettled: onCancel,
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({ queryKey: ["secrets"] });
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: ["secrets", organizationId],
|
||||
});
|
||||
},
|
||||
},
|
||||
);
|
||||
@@ -61,7 +65,7 @@ export function SecretForm({
|
||||
description?: string,
|
||||
) => {
|
||||
queryClient.setQueryData<GetSecretsResponse["custom_secrets"]>(
|
||||
["secrets"],
|
||||
["secrets", organizationId],
|
||||
(oldSecrets) => {
|
||||
if (!oldSecrets) return [];
|
||||
return oldSecrets.map((secret) => {
|
||||
@@ -79,7 +83,7 @@ export function SecretForm({
|
||||
};
|
||||
|
||||
const revertOptimisticUpdate = () => {
|
||||
queryClient.invalidateQueries({ queryKey: ["secrets"] });
|
||||
queryClient.invalidateQueries({ queryKey: ["secrets", organizationId] });
|
||||
};
|
||||
|
||||
const handleEditSecret = (
|
||||
|
||||
@@ -22,20 +22,43 @@ export function UserActions({ user, isLoading }: UserActionsProps) {
|
||||
const [menuResetCount, setMenuResetCount] = React.useState(0);
|
||||
const [inviteMemberModalIsOpen, setInviteMemberModalIsOpen] =
|
||||
React.useState(false);
|
||||
const hideTimeoutRef = React.useRef<number | null>(null);
|
||||
|
||||
// Use the shared hook to determine if user actions should be shown
|
||||
const shouldShowUserActions = useShouldShowUserFeatures();
|
||||
|
||||
// Clean up timeout on unmount
|
||||
React.useEffect(
|
||||
() => () => {
|
||||
if (hideTimeoutRef.current) {
|
||||
clearTimeout(hideTimeoutRef.current);
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const showAccountMenu = () => {
|
||||
// Cancel any pending hide to allow diagonal mouse movement to menu
|
||||
if (hideTimeoutRef.current) {
|
||||
clearTimeout(hideTimeoutRef.current);
|
||||
hideTimeoutRef.current = null;
|
||||
}
|
||||
setAccountContextMenuIsVisible(true);
|
||||
};
|
||||
|
||||
const hideAccountMenu = () => {
|
||||
setAccountContextMenuIsVisible(false);
|
||||
setMenuResetCount((c) => c + 1);
|
||||
// Delay hiding to allow diagonal mouse movement to menu
|
||||
hideTimeoutRef.current = window.setTimeout(() => {
|
||||
setAccountContextMenuIsVisible(false);
|
||||
setMenuResetCount((c) => c + 1);
|
||||
}, 500);
|
||||
};
|
||||
|
||||
const closeAccountMenu = () => {
|
||||
if (hideTimeoutRef.current) {
|
||||
clearTimeout(hideTimeoutRef.current);
|
||||
hideTimeoutRef.current = null;
|
||||
}
|
||||
if (accountContextMenuIsVisible) {
|
||||
setAccountContextMenuIsVisible(false);
|
||||
setMenuResetCount((c) => c + 1);
|
||||
@@ -61,9 +84,6 @@ export function UserActions({ user, isLoading }: UserActionsProps) {
|
||||
className={cn(
|
||||
"opacity-0 pointer-events-none group-hover:opacity-100 group-hover:pointer-events-auto",
|
||||
accountContextMenuIsVisible && "opacity-100 pointer-events-auto",
|
||||
// Invisible hover bridge: extends hover zone to create a "safe corridor"
|
||||
// for diagonal mouse movement to the menu (only active when menu is visible)
|
||||
"group-hover:before:content-[''] group-hover:before:block group-hover:before:absolute group-hover:before:inset-[-320px] group-hover:before:z-50 before:pointer-events-none",
|
||||
)}
|
||||
>
|
||||
<UserContextMenu
|
||||
|
||||
@@ -156,13 +156,16 @@ export function UserContextMenu({
|
||||
{t(I18nKey.SIDEBAR$DOCS)}
|
||||
</a>
|
||||
|
||||
<ContextMenuListItem
|
||||
onClick={handleLogout}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<IoLogOutOutline className="text-white" size={16} />
|
||||
{t(I18nKey.ACCOUNT_SETTINGS$LOGOUT)}
|
||||
</ContextMenuListItem>
|
||||
{/* Only show logout in saas mode - oss mode has no session to invalidate */}
|
||||
{isSaasMode && (
|
||||
<ContextMenuListItem
|
||||
onClick={handleLogout}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<IoLogOutOutline className="text-white" size={16} />
|
||||
{t(I18nKey.ACCOUNT_SETTINGS$LOGOUT)}
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import React from "react";
|
||||
import { CopyToClipboardButton } from "./copy-to-clipboard-button";
|
||||
|
||||
export function CopyableContentWrapper({
|
||||
text,
|
||||
children,
|
||||
}: {
|
||||
text: string;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const [isHovering, setIsHovering] = React.useState(false);
|
||||
const [isCopied, setIsCopied] = React.useState(false);
|
||||
|
||||
const handleCopy = async () => {
|
||||
await navigator.clipboard.writeText(text);
|
||||
setIsCopied(true);
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
let timeout: NodeJS.Timeout;
|
||||
if (isCopied) {
|
||||
timeout = setTimeout(() => setIsCopied(false), 2000);
|
||||
}
|
||||
return () => clearTimeout(timeout);
|
||||
}, [isCopied]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className="relative"
|
||||
onMouseEnter={() => setIsHovering(true)}
|
||||
onMouseLeave={() => setIsHovering(false)}
|
||||
>
|
||||
<div className="absolute top-2 right-2 z-10">
|
||||
<CopyToClipboardButton
|
||||
isHidden={!isHovering}
|
||||
isDisabled={isCopied}
|
||||
onClick={handleCopy}
|
||||
mode={isCopied ? "copied" : "copy"}
|
||||
/>
|
||||
</div>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -2,10 +2,12 @@ import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { SecretsService } from "#/api/secrets-service";
|
||||
import { Provider, ProviderToken } from "#/types/settings";
|
||||
import { useTracking } from "#/hooks/use-tracking";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
export const useAddGitProviders = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const { trackGitProviderConnected } = useTracking();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: ({
|
||||
@@ -25,7 +27,9 @@ export const useAddGitProviders = () => {
|
||||
});
|
||||
}
|
||||
|
||||
await queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: ["settings", organizationId],
|
||||
});
|
||||
},
|
||||
meta: {
|
||||
disableToast: true,
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { MCPSSEServer, MCPStdioServer, MCPSHTTPServer } from "#/types/settings";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
type MCPServerType = "sse" | "stdio" | "shttp";
|
||||
|
||||
@@ -19,6 +20,7 @@ interface MCPServerConfig {
|
||||
export function useAddMcpServer() {
|
||||
const queryClient = useQueryClient();
|
||||
const { data: settings } = useSettings();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async (server: MCPServerConfig): Promise<void> => {
|
||||
@@ -64,7 +66,9 @@ export function useAddMcpServer() {
|
||||
},
|
||||
onSuccess: () => {
|
||||
// Invalidate the settings query to trigger a refetch
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["settings", organizationId],
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import ApiKeysClient, { CreateApiKeyResponse } from "#/api/api-keys";
|
||||
import { API_KEYS_QUERY_KEY } from "#/hooks/query/use-api-keys";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
export function useCreateApiKey() {
|
||||
const queryClient = useQueryClient();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async (name: string): Promise<CreateApiKeyResponse> =>
|
||||
ApiKeysClient.createApiKey(name),
|
||||
onSuccess: () => {
|
||||
// Invalidate the API keys query to trigger a refetch
|
||||
queryClient.invalidateQueries({ queryKey: [API_KEYS_QUERY_KEY] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [API_KEYS_QUERY_KEY, organizationId],
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import ApiKeysClient from "#/api/api-keys";
|
||||
import { API_KEYS_QUERY_KEY } from "#/hooks/query/use-api-keys";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
export function useDeleteApiKey() {
|
||||
const queryClient = useQueryClient();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async (id: string): Promise<void> => {
|
||||
@@ -11,7 +13,9 @@ export function useDeleteApiKey() {
|
||||
},
|
||||
onSuccess: () => {
|
||||
// Invalidate the API keys query to trigger a refetch
|
||||
queryClient.invalidateQueries({ queryKey: [API_KEYS_QUERY_KEY] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [API_KEYS_QUERY_KEY, organizationId],
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { MCPConfig } from "#/types/settings";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
export function useDeleteMcpServer() {
|
||||
const queryClient = useQueryClient();
|
||||
const { data: settings } = useSettings();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async (serverId: string): Promise<void> => {
|
||||
@@ -32,7 +34,9 @@ export function useDeleteMcpServer() {
|
||||
},
|
||||
onSuccess: () => {
|
||||
// Invalidate the settings query to trigger a refetch
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["settings", organizationId],
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
115
frontend/src/hooks/mutation/use-new-conversation-command.ts
Normal file
115
frontend/src/hooks/mutation/use-new-conversation-command.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useNavigate } from "react-router";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import toast from "react-hot-toast";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
import {
|
||||
displayErrorToast,
|
||||
displaySuccessToast,
|
||||
TOAST_OPTIONS,
|
||||
} from "#/utils/custom-toast-handlers";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
|
||||
export const useNewConversationCommand = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const navigate = useNavigate();
|
||||
const { t } = useTranslation();
|
||||
const { data: conversation } = useActiveConversation();
|
||||
|
||||
const mutation = useMutation({
|
||||
mutationFn: async () => {
|
||||
if (!conversation?.conversation_id || !conversation.sandbox_id) {
|
||||
throw new Error("No active conversation or sandbox");
|
||||
}
|
||||
|
||||
// Fetch V1 conversation data to get llm_model (not available in legacy type)
|
||||
const v1Conversations =
|
||||
await V1ConversationService.batchGetAppConversations([
|
||||
conversation.conversation_id,
|
||||
]);
|
||||
const llmModel = v1Conversations?.[0]?.llm_model;
|
||||
|
||||
// Start a new conversation reusing the existing sandbox directly.
|
||||
// We pass sandbox_id instead of parent_conversation_id so that the
|
||||
// new conversation is NOT marked as a sub-conversation and will
|
||||
// appear in the conversation list.
|
||||
const startTask = await V1ConversationService.createConversation(
|
||||
conversation.selected_repository ?? undefined, // selectedRepository
|
||||
conversation.git_provider ?? undefined, // git_provider
|
||||
undefined, // initialUserMsg
|
||||
conversation.selected_branch ?? undefined, // selected_branch
|
||||
undefined, // conversationInstructions
|
||||
undefined, // suggestedTask
|
||||
undefined, // trigger
|
||||
undefined, // parent_conversation_id
|
||||
undefined, // agent_type
|
||||
conversation.sandbox_id ?? undefined, // sandbox_id - reuse the same sandbox
|
||||
llmModel ?? undefined, // llm_model - preserve the LLM model
|
||||
);
|
||||
|
||||
// Poll for the task to complete and get the new conversation ID
|
||||
let task = await V1ConversationService.getStartTask(startTask.id);
|
||||
const maxAttempts = 60; // 60 seconds timeout
|
||||
let attempts = 0;
|
||||
|
||||
/* eslint-disable no-await-in-loop */
|
||||
while (
|
||||
task &&
|
||||
!["READY", "ERROR"].includes(task.status) &&
|
||||
attempts < maxAttempts
|
||||
) {
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 1000);
|
||||
});
|
||||
task = await V1ConversationService.getStartTask(startTask.id);
|
||||
attempts += 1;
|
||||
}
|
||||
|
||||
if (!task || task.status !== "READY" || !task.app_conversation_id) {
|
||||
throw new Error(
|
||||
task?.detail || "Failed to create new conversation in sandbox",
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
newConversationId: task.app_conversation_id,
|
||||
oldConversationId: conversation.conversation_id,
|
||||
};
|
||||
},
|
||||
onMutate: () => {
|
||||
toast.loading(t(I18nKey.CONVERSATION$CLEARING), {
|
||||
...TOAST_OPTIONS,
|
||||
id: "clear-conversation",
|
||||
});
|
||||
},
|
||||
onSuccess: (data) => {
|
||||
toast.dismiss("clear-conversation");
|
||||
displaySuccessToast(t(I18nKey.CONVERSATION$CLEAR_SUCCESS));
|
||||
navigate(`/conversations/${data.newConversationId}`);
|
||||
|
||||
// Refresh the sidebar to show the new conversation.
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["user", "conversations"],
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["v1-batch-get-app-conversations"],
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast.dismiss("clear-conversation");
|
||||
let clearError = t(I18nKey.CONVERSATION$CLEAR_UNKNOWN_ERROR);
|
||||
if (error instanceof Error) {
|
||||
clearError = error.message;
|
||||
} else if (typeof error === "string") {
|
||||
clearError = error;
|
||||
}
|
||||
displayErrorToast(
|
||||
t(I18nKey.CONVERSATION$CLEAR_FAILED, { error: clearError }),
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
return mutation;
|
||||
};
|
||||
@@ -4,6 +4,7 @@ import { DEFAULT_SETTINGS } from "#/services/settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { Settings } from "#/types/settings";
|
||||
import { useSettings } from "../query/use-settings";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
const saveSettingsMutationFn = async (settings: Partial<Settings>) => {
|
||||
const settingsToSave: Partial<Settings> = {
|
||||
@@ -30,6 +31,7 @@ export const useSaveSettings = () => {
|
||||
const posthog = usePostHog();
|
||||
const queryClient = useQueryClient();
|
||||
const { data: currentSettings } = useSettings();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async (settings: Partial<Settings>) => {
|
||||
@@ -56,7 +58,9 @@ export const useSaveSettings = () => {
|
||||
await saveSettingsMutationFn(newSettings);
|
||||
},
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: ["settings", organizationId],
|
||||
});
|
||||
},
|
||||
meta: {
|
||||
disableToast: true,
|
||||
|
||||
@@ -17,10 +17,9 @@ export const useSwitchOrganization = () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["organizations", orgId, "me"],
|
||||
});
|
||||
// Update local state
|
||||
// Update local state - this triggers automatic refetch for all org-scoped queries
|
||||
// since their query keys include organizationId (e.g., ["settings", orgId], ["secrets", orgId])
|
||||
setOrganizationId(orgId);
|
||||
// Invalidate settings for the new org context
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
// Invalidate conversations to fetch data for the new org context
|
||||
queryClient.invalidateQueries({ queryKey: ["user", "conversations"] });
|
||||
// Remove all individual conversation queries to clear any stale/null data
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { MCPSSEServer, MCPStdioServer, MCPSHTTPServer } from "#/types/settings";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
type MCPServerType = "sse" | "stdio" | "shttp";
|
||||
|
||||
@@ -19,6 +20,7 @@ interface MCPServerConfig {
|
||||
export function useUpdateMcpServer() {
|
||||
const queryClient = useQueryClient();
|
||||
const { data: settings } = useSettings();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async ({
|
||||
@@ -66,7 +68,9 @@ export function useUpdateMcpServer() {
|
||||
},
|
||||
onSuccess: () => {
|
||||
// Invalidate the settings query to trigger a refetch
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["settings", organizationId],
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import ApiKeysClient from "#/api/api-keys";
|
||||
import { useConfig } from "./use-config";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
export const API_KEYS_QUERY_KEY = "api-keys";
|
||||
|
||||
export function useApiKeys() {
|
||||
const { data: config } = useConfig();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
return useQuery({
|
||||
queryKey: [API_KEYS_QUERY_KEY],
|
||||
enabled: config?.app_mode === "saas",
|
||||
queryKey: [API_KEYS_QUERY_KEY, organizationId],
|
||||
enabled: config?.app_mode === "saas" && !!organizationId,
|
||||
queryFn: async () => {
|
||||
const keys = await ApiKeysClient.getApiKeys();
|
||||
return Array.isArray(keys) ? keys : [];
|
||||
|
||||
@@ -2,16 +2,18 @@ import { useQuery } from "@tanstack/react-query";
|
||||
import { SecretsService } from "#/api/secrets-service";
|
||||
import { useConfig } from "./use-config";
|
||||
import { useIsAuthed } from "#/hooks/query/use-is-authed";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
export const useGetSecrets = () => {
|
||||
const { data: config } = useConfig();
|
||||
const { data: isAuthed } = useIsAuthed();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
const isOss = config?.app_mode === "oss";
|
||||
|
||||
return useQuery({
|
||||
queryKey: ["secrets"],
|
||||
queryKey: ["secrets", organizationId],
|
||||
queryFn: SecretsService.getSecrets,
|
||||
enabled: isOss || isAuthed, // Enable regardless of providers
|
||||
enabled: isOss || (isAuthed && !!organizationId),
|
||||
});
|
||||
};
|
||||
|
||||
@@ -35,13 +35,14 @@ export const useGitUser = () => {
|
||||
}
|
||||
}, [user.data]);
|
||||
|
||||
// If we get a 401 here, it means that the integration tokens need to be
|
||||
// In saas mode, a 401 means that the integration tokens need to be
|
||||
// refreshed. Since this happens at login, we log out.
|
||||
// In oss mode, skip auto-logout since there's no token refresh mechanism
|
||||
React.useEffect(() => {
|
||||
if (user?.error?.response?.status === 401) {
|
||||
if (user?.error?.response?.status === 401 && config?.app_mode === "saas") {
|
||||
logout.mutate();
|
||||
}
|
||||
}, [user.status]);
|
||||
}, [user.status, config?.app_mode]);
|
||||
|
||||
return user;
|
||||
};
|
||||
|
||||
@@ -4,6 +4,8 @@ import { DEFAULT_SETTINGS } from "#/services/settings";
|
||||
import { useIsOnIntermediatePage } from "#/hooks/use-is-on-intermediate-page";
|
||||
import { Settings } from "#/types/settings";
|
||||
import { useIsAuthed } from "./use-is-authed";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
import { useConfig } from "./use-config";
|
||||
|
||||
const getSettingsQueryFn = async (): Promise<Settings> => {
|
||||
const settings = await SettingsService.getSettings();
|
||||
@@ -27,9 +29,13 @@ const getSettingsQueryFn = async (): Promise<Settings> => {
|
||||
export const useSettings = () => {
|
||||
const isOnIntermediatePage = useIsOnIntermediatePage();
|
||||
const { data: userIsAuthenticated } = useIsAuthed();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
const { data: config } = useConfig();
|
||||
|
||||
const isOss = config?.app_mode === "oss";
|
||||
|
||||
const query = useQuery({
|
||||
queryKey: ["settings"],
|
||||
queryKey: ["settings", organizationId],
|
||||
queryFn: getSettingsQueryFn,
|
||||
// Only retry if the error is not a 404 because we
|
||||
// would want to show the modal immediately if the
|
||||
@@ -38,7 +44,10 @@ export const useSettings = () => {
|
||||
refetchOnWindowFocus: false,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
enabled: !isOnIntermediatePage && !!userIsAuthenticated,
|
||||
enabled:
|
||||
!isOnIntermediatePage &&
|
||||
!!userIsAuthenticated &&
|
||||
(isOss || !!organizationId),
|
||||
meta: {
|
||||
disableToast: true,
|
||||
},
|
||||
|
||||
@@ -5,6 +5,7 @@ import V1GitService from "#/api/git-service/v1-git-service.api";
|
||||
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
import { useRuntimeIsReady } from "#/hooks/use-runtime-is-ready";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { getGitPath } from "#/utils/get-git-path";
|
||||
import type { GitChange } from "#/api/open-hands.types";
|
||||
|
||||
@@ -16,6 +17,7 @@ import type { GitChange } from "#/api/open-hands.types";
|
||||
export const useUnifiedGetGitChanges = () => {
|
||||
const { conversationId } = useConversationId();
|
||||
const { data: conversation } = useActiveConversation();
|
||||
const { data: settings } = useSettings();
|
||||
const [orderedChanges, setOrderedChanges] = React.useState<GitChange[]>([]);
|
||||
const previousDataRef = React.useRef<GitChange[] | null>(null);
|
||||
const runtimeIsReady = useRuntimeIsReady();
|
||||
@@ -25,10 +27,15 @@ export const useUnifiedGetGitChanges = () => {
|
||||
const sessionApiKey = conversation?.session_api_key;
|
||||
const selectedRepository = conversation?.selected_repository;
|
||||
|
||||
// Calculate git path based on selected repository
|
||||
// Sandbox grouping is enabled when strategy is not NO_GROUPING
|
||||
const useSandboxGrouping =
|
||||
settings?.sandbox_grouping_strategy !== "NO_GROUPING" &&
|
||||
settings?.sandbox_grouping_strategy !== undefined;
|
||||
|
||||
// Calculate git path based on selected repository and sandbox grouping strategy
|
||||
const gitPath = React.useMemo(
|
||||
() => getGitPath(conversationId, selectedRepository),
|
||||
[selectedRepository],
|
||||
() => getGitPath(conversationId, selectedRepository, useSandboxGrouping),
|
||||
[conversationId, selectedRepository, useSandboxGrouping],
|
||||
);
|
||||
|
||||
const result = useQuery({
|
||||
@@ -57,6 +64,7 @@ export const useUnifiedGetGitChanges = () => {
|
||||
retry: false,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
refetchOnMount: "always", // Always refetch when mounting (e.g. navigating between conversations that share a sandbox)
|
||||
enabled: runtimeIsReady && !!conversationId,
|
||||
meta: {
|
||||
disableToast: true,
|
||||
@@ -100,6 +108,7 @@ export const useUnifiedGetGitChanges = () => {
|
||||
return {
|
||||
data: orderedChanges,
|
||||
isLoading: result.isLoading,
|
||||
isFetching: result.isFetching,
|
||||
isSuccess: result.isSuccess,
|
||||
isError: result.isError,
|
||||
error: result.error,
|
||||
|
||||
@@ -4,6 +4,7 @@ import GitService from "#/api/git-service/git-service.api";
|
||||
import V1GitService from "#/api/git-service/v1-git-service.api";
|
||||
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { getGitPath } from "#/utils/get-git-path";
|
||||
import type { GitChangeStatus } from "#/api/open-hands.types";
|
||||
|
||||
@@ -21,20 +22,36 @@ type UseUnifiedGitDiffConfig = {
|
||||
export const useUnifiedGitDiff = (config: UseUnifiedGitDiffConfig) => {
|
||||
const { conversationId } = useConversationId();
|
||||
const { data: conversation } = useActiveConversation();
|
||||
const { data: settings } = useSettings();
|
||||
|
||||
const isV1Conversation = conversation?.conversation_version === "V1";
|
||||
const conversationUrl = conversation?.url;
|
||||
const sessionApiKey = conversation?.session_api_key;
|
||||
const selectedRepository = conversation?.selected_repository;
|
||||
|
||||
// Sandbox grouping is enabled when strategy is not NO_GROUPING
|
||||
const useSandboxGrouping =
|
||||
settings?.sandbox_grouping_strategy !== "NO_GROUPING" &&
|
||||
settings?.sandbox_grouping_strategy !== undefined;
|
||||
|
||||
// For V1, we need to convert the relative file path to an absolute path
|
||||
// The diff endpoint expects: /workspace/project/RepoName/relative/path
|
||||
const absoluteFilePath = React.useMemo(() => {
|
||||
if (!isV1Conversation) return config.filePath;
|
||||
|
||||
const gitPath = getGitPath(conversationId, selectedRepository);
|
||||
const gitPath = getGitPath(
|
||||
conversationId,
|
||||
selectedRepository,
|
||||
useSandboxGrouping,
|
||||
);
|
||||
return `${gitPath}/${config.filePath}`;
|
||||
}, [isV1Conversation, selectedRepository, config.filePath]);
|
||||
}, [
|
||||
isV1Conversation,
|
||||
conversationId,
|
||||
selectedRepository,
|
||||
useSandboxGrouping,
|
||||
config.filePath,
|
||||
]);
|
||||
|
||||
return useQuery({
|
||||
queryKey: [
|
||||
|
||||
@@ -23,7 +23,7 @@ export const useUnifiedVSCodeUrl = () => {
|
||||
const { t } = useTranslation();
|
||||
const { conversationId } = useConversationId();
|
||||
const { data: conversation } = useActiveConversation();
|
||||
const runtimeIsReady = useRuntimeIsReady();
|
||||
const runtimeIsReady = useRuntimeIsReady({ allowAgentError: true });
|
||||
|
||||
const isV1Conversation = conversation?.conversation_version === "V1";
|
||||
|
||||
|
||||
@@ -1,18 +1,30 @@
|
||||
import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state";
|
||||
import { useActiveConversation } from "./query/use-active-conversation";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import {
|
||||
RUNTIME_INACTIVE_STATES,
|
||||
RUNTIME_STARTING_STATES,
|
||||
} from "#/types/agent-state";
|
||||
import { useActiveConversation } from "./query/use-active-conversation";
|
||||
|
||||
interface UseRuntimeIsReadyOptions {
|
||||
allowAgentError?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to determine if the runtime is ready for operations
|
||||
*
|
||||
* @returns boolean indicating if the runtime is ready
|
||||
*/
|
||||
export const useRuntimeIsReady = (): boolean => {
|
||||
export const useRuntimeIsReady = ({
|
||||
allowAgentError = false,
|
||||
}: UseRuntimeIsReadyOptions = {}): boolean => {
|
||||
const { data: conversation } = useActiveConversation();
|
||||
const { curAgentState } = useAgentState();
|
||||
const inactiveStates = allowAgentError
|
||||
? RUNTIME_STARTING_STATES
|
||||
: RUNTIME_INACTIVE_STATES;
|
||||
|
||||
return (
|
||||
conversation?.status === "RUNNING" &&
|
||||
!RUNTIME_INACTIVE_STATES.includes(curAgentState)
|
||||
!inactiveStates.includes(curAgentState)
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1084,6 +1084,14 @@ export enum I18nKey {
|
||||
CONVERSATION$NO_HISTORY_AVAILABLE = "CONVERSATION$NO_HISTORY_AVAILABLE",
|
||||
CONVERSATION$SHARED_CONVERSATION = "CONVERSATION$SHARED_CONVERSATION",
|
||||
CONVERSATION$LINK_COPIED = "CONVERSATION$LINK_COPIED",
|
||||
ONBOARDING$STEP1_TITLE = "ONBOARDING$STEP1_TITLE",
|
||||
ONBOARDING$STEP1_SUBTITLE = "ONBOARDING$STEP1_SUBTITLE",
|
||||
ONBOARDING$SOFTWARE_ENGINEER = "ONBOARDING$SOFTWARE_ENGINEER",
|
||||
ONBOARDING$ENGINEERING_MANAGER = "ONBOARDING$ENGINEERING_MANAGER",
|
||||
ONBOARDING$CTO_FOUNDER = "ONBOARDING$CTO_FOUNDER",
|
||||
ONBOARDING$PRODUCT_OPERATIONS = "ONBOARDING$PRODUCT_OPERATIONS",
|
||||
ONBOARDING$STUDENT_HOBBYIST = "ONBOARDING$STUDENT_HOBBYIST",
|
||||
ONBOARDING$OTHER = "ONBOARDING$OTHER",
|
||||
HOOKS_MODAL$TITLE = "HOOKS_MODAL$TITLE",
|
||||
HOOKS_MODAL$WARNING = "HOOKS_MODAL$WARNING",
|
||||
HOOKS_MODAL$MATCHER = "HOOKS_MODAL$MATCHER",
|
||||
@@ -1143,6 +1151,14 @@ export enum I18nKey {
|
||||
ONBOARDING$NEXT_BUTTON = "ONBOARDING$NEXT_BUTTON",
|
||||
ONBOARDING$BACK_BUTTON = "ONBOARDING$BACK_BUTTON",
|
||||
ONBOARDING$FINISH_BUTTON = "ONBOARDING$FINISH_BUTTON",
|
||||
CONVERSATION$CLEAR_V1_ONLY = "CONVERSATION$CLEAR_V1_ONLY",
|
||||
CONVERSATION$CLEAR_EMPTY = "CONVERSATION$CLEAR_EMPTY",
|
||||
CONVERSATION$CLEAR_NO_ID = "CONVERSATION$CLEAR_NO_ID",
|
||||
CONVERSATION$CLEAR_NO_NEW_ID = "CONVERSATION$CLEAR_NO_NEW_ID",
|
||||
CONVERSATION$CLEAR_UNKNOWN_ERROR = "CONVERSATION$CLEAR_UNKNOWN_ERROR",
|
||||
CONVERSATION$CLEAR_FAILED = "CONVERSATION$CLEAR_FAILED",
|
||||
CONVERSATION$CLEAR_SUCCESS = "CONVERSATION$CLEAR_SUCCESS",
|
||||
CONVERSATION$CLEARING = "CONVERSATION$CLEARING",
|
||||
CTA$ENTERPRISE = "CTA$ENTERPRISE",
|
||||
CTA$ENTERPRISE_DEPLOY = "CTA$ENTERPRISE_DEPLOY",
|
||||
CTA$FEATURE_ON_PREMISES = "CTA$FEATURE_ON_PREMISES",
|
||||
|
||||
@@ -19433,6 +19433,142 @@
|
||||
"uk": "Завершити",
|
||||
"ca": "Finalitza"
|
||||
},
|
||||
"CONVERSATION$CLEAR_V1_ONLY": {
|
||||
"en": "The /new command is only available for V1 conversations",
|
||||
"ja": "/newコマンドはV1会話でのみ使用できます",
|
||||
"zh-CN": "/new 命令仅适用于 V1 对话",
|
||||
"zh-TW": "/new 指令僅適用於 V1 對話",
|
||||
"ko-KR": "/new 명령은 V1 대화에서만 사용할 수 있습니다",
|
||||
"no": "/new-kommandoen er kun tilgjengelig for V1-samtaler",
|
||||
"it": "Il comando /new è disponibile solo per le conversazioni V1",
|
||||
"pt": "O comando /new está disponível apenas para conversas V1",
|
||||
"es": "El comando /new solo está disponible para conversaciones V1",
|
||||
"ar": "أمر /new متاح فقط لمحادثات V1",
|
||||
"fr": "La commande /new n'est disponible que pour les conversations V1",
|
||||
"tr": "/new komutu yalnızca V1 konuşmalarında kullanılabilir",
|
||||
"de": "Der /new-Befehl ist nur für V1-Konversationen verfügbar",
|
||||
"uk": "Команда /new доступна лише для розмов V1",
|
||||
"ca": "L'ordre /new només està disponible per a converses V1"
|
||||
},
|
||||
"CONVERSATION$CLEAR_EMPTY": {
|
||||
"en": "Nothing to clear. This conversation has no messages yet.",
|
||||
"ja": "クリアするものがありません。この会話にはまだメッセージがありません。",
|
||||
"zh-CN": "没有可清除的内容。此对话尚无消息。",
|
||||
"zh-TW": "沒有可清除的內容。此對話尚無訊息。",
|
||||
"ko-KR": "지울 내용이 없습니다. 이 대화에는 아직 메시지가 없습니다.",
|
||||
"no": "Ingenting å tømme. Denne samtalen har ingen meldinger ennå.",
|
||||
"it": "Niente da cancellare. Questa conversazione non ha ancora messaggi.",
|
||||
"pt": "Nada para limpar. Esta conversa ainda não tem mensagens.",
|
||||
"es": "Nada que borrar. Esta conversación aún no tiene mensajes.",
|
||||
"ar": "لا يوجد شيء لمسحه. لا تحتوي هذه المحادثة على رسائل بعد.",
|
||||
"fr": "Rien à effacer. Cette conversation n'a pas encore de messages.",
|
||||
"tr": "Temizlenecek bir şey yok. Bu konuşmada henüz mesaj yok.",
|
||||
"de": "Nichts zu löschen. Diese Konversation hat noch keine Nachrichten.",
|
||||
"uk": "Нічого очищувати. Ця розмова ще не має повідомлень.",
|
||||
"ca": "No hi ha res a esborrar. Aquesta conversa encara no té missatges."
|
||||
},
|
||||
"CONVERSATION$CLEAR_NO_ID": {
|
||||
"en": "No conversation ID found",
|
||||
"ja": "会話IDが見つかりません",
|
||||
"zh-CN": "未找到对话 ID",
|
||||
"zh-TW": "找不到對話 ID",
|
||||
"ko-KR": "대화 ID를 찾을 수 없습니다",
|
||||
"no": "Ingen samtale-ID funnet",
|
||||
"it": "Nessun ID conversazione trovato",
|
||||
"pt": "Nenhum ID de conversa encontrado",
|
||||
"es": "No se encontró el ID de conversación",
|
||||
"ar": "لم يتم العثور على معرف المحادثة",
|
||||
"fr": "Aucun identifiant de conversation trouvé",
|
||||
"tr": "Konuşma kimliği bulunamadı",
|
||||
"de": "Keine Konversations-ID gefunden",
|
||||
"uk": "Ідентифікатор розмови не знайдено",
|
||||
"ca": "No s'ha trobat l'identificador de la conversa"
|
||||
},
|
||||
"CONVERSATION$CLEAR_NO_NEW_ID": {
|
||||
"en": "Server did not return a new conversation ID",
|
||||
"ja": "サーバーが新しい会話IDを返しませんでした",
|
||||
"zh-CN": "服务器未返回新的对话 ID",
|
||||
"zh-TW": "伺服器未返回新的對話 ID",
|
||||
"ko-KR": "서버가 새 대화 ID를 반환하지 않았습니다",
|
||||
"no": "Serveren returnerte ikke en ny samtale-ID",
|
||||
"it": "Il server non ha restituito un nuovo ID conversazione",
|
||||
"pt": "O servidor não retornou um novo ID de conversa",
|
||||
"es": "El servidor no devolvió un nuevo ID de conversación",
|
||||
"ar": "لم يقم الخادم بإرجاع معرف محادثة جديد",
|
||||
"fr": "Le serveur n'a pas renvoyé un nouvel identifiant de conversation",
|
||||
"tr": "Sunucu yeni bir konuşma kimliği döndürmedi",
|
||||
"de": "Der Server hat keine neue Konversations-ID zurückgegeben",
|
||||
"uk": "Сервер не повернув новий ідентифікатор розмови",
|
||||
"ca": "El servidor no ha retornat un nou identificador de conversa"
|
||||
},
|
||||
"CONVERSATION$CLEAR_UNKNOWN_ERROR": {
|
||||
"en": "Unknown error",
|
||||
"ja": "不明なエラー",
|
||||
"zh-CN": "未知错误",
|
||||
"zh-TW": "未知錯誤",
|
||||
"ko-KR": "알 수 없는 오류",
|
||||
"no": "Ukjent feil",
|
||||
"it": "Errore sconosciuto",
|
||||
"pt": "Erro desconhecido",
|
||||
"es": "Error desconocido",
|
||||
"ar": "خطأ غير معروف",
|
||||
"fr": "Erreur inconnue",
|
||||
"tr": "Bilinmeyen hata",
|
||||
"de": "Unbekannter Fehler",
|
||||
"uk": "Невідома помилка",
|
||||
"ca": "Error desconegut"
|
||||
},
|
||||
"CONVERSATION$CLEAR_FAILED": {
|
||||
"en": "Failed to start new conversation: {{error}}",
|
||||
"ja": "新しい会話の開始に失敗しました: {{error}}",
|
||||
"zh-CN": "启动新对话失败: {{error}}",
|
||||
"zh-TW": "啟動新對話失敗: {{error}}",
|
||||
"ko-KR": "새 대화 시작 실패: {{error}}",
|
||||
"no": "Kunne ikke starte ny samtale: {{error}}",
|
||||
"it": "Impossibile avviare una nuova conversazione: {{error}}",
|
||||
"pt": "Falha ao iniciar nova conversa: {{error}}",
|
||||
"es": "Error al iniciar nueva conversación: {{error}}",
|
||||
"ar": "فشل في بدء محادثة جديدة: {{error}}",
|
||||
"fr": "Échec du démarrage d'une nouvelle conversation : {{error}}",
|
||||
"tr": "Yeni konuşma başlatılamadı: {{error}}",
|
||||
"de": "Neue Konversation konnte nicht gestartet werden: {{error}}",
|
||||
"uk": "Не вдалося розпочати нову розмову: {{error}}",
|
||||
"ca": "No s'ha pogut iniciar una nova conversa: {{error}}"
|
||||
},
|
||||
"CONVERSATION$CLEAR_SUCCESS": {
|
||||
"en": "Starting a new conversation in the same sandbox. These conversations share the same runtime.",
|
||||
"ja": "同じサンドボックスで新しい会話を開始します。これらの会話は同じランタイムを共有します。",
|
||||
"zh-CN": "正在同一沙箱中开始新对话。这些对话共享同一运行时。",
|
||||
"zh-TW": "正在同一沙盒中開始新對話。這些對話共享同一執行環境。",
|
||||
"ko-KR": "같은 샌드박스에서 새 대화를 시작합니다. 이 대화들은 같은 런타임을 공유합니다.",
|
||||
"no": "Starter ny samtale i samme sandbox. Disse samtalene deler samme kjøretid.",
|
||||
"it": "Avvio nuova conversazione nello stesso sandbox. Queste conversazioni condividono lo stesso runtime.",
|
||||
"pt": "Iniciando nova conversa no mesmo sandbox. Essas conversas compartilham o mesmo runtime.",
|
||||
"es": "Iniciando nueva conversación en el mismo sandbox. Estas conversaciones comparten el mismo runtime.",
|
||||
"ar": "بدء محادثة جديدة في نفس صندوق الحماية. هذه المحادثات تشارك نفس وقت التشغيل.",
|
||||
"fr": "Démarrage d'une nouvelle conversation dans le même bac à sable. Ces conversations partagent le même environnement d'exécution.",
|
||||
"tr": "Aynı korumalı alanda yeni konuşma başlatılıyor. Bu konuşmalar aynı çalışma ortamını paylaşır.",
|
||||
"de": "Starte neue Konversation in derselben Sandbox. Diese Konversationen teilen dieselbe Laufzeitumgebung.",
|
||||
"uk": "Починаю нову розмову в тому самому захищеному середовищі. Ці розмови використовують одне середовище виконання.",
|
||||
"ca": "S'està iniciant una nova conversa al mateix entorn aïllat. Aquestes converses comparteixen el mateix entorn d'execució."
|
||||
},
|
||||
"CONVERSATION$CLEARING": {
|
||||
"en": "Creating new conversation...",
|
||||
"ja": "新しい会話を作成中...",
|
||||
"zh-CN": "正在创建新对话...",
|
||||
"zh-TW": "正在建立新對話...",
|
||||
"ko-KR": "새 대화를 만드는 중...",
|
||||
"no": "Oppretter ny samtale...",
|
||||
"it": "Creazione nuova conversazione...",
|
||||
"pt": "Criando nova conversa...",
|
||||
"es": "Creando nueva conversación...",
|
||||
"ar": "جارٍ إنشاء محادثة جديدة...",
|
||||
"fr": "Création d'une nouvelle conversation...",
|
||||
"tr": "Yeni konuşma oluşturuluyor...",
|
||||
"de": "Neue Konversation wird erstellt...",
|
||||
"uk": "Створення нової розмови...",
|
||||
"ca": "S'està creant una nova conversa..."
|
||||
},
|
||||
"CTA$ENTERPRISE": {
|
||||
"en": "Enterprise",
|
||||
"ja": "エンタープライズ",
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useCreateStripeCheckoutSession } from "#/hooks/mutation/stripe/use-create-stripe-checkout-session";
|
||||
import { useOrganization } from "#/hooks/query/use-organization";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModalButtonGroup } from "#/components/shared/modals/modal-button-group";
|
||||
import { SettingsInput } from "#/components/features/settings/settings-input";
|
||||
import { useMe } from "#/hooks/query/use-me";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { amountIsValid } from "#/utils/amount-is-valid";
|
||||
import { CreditsChip } from "#/ui/credits-chip";
|
||||
import { InteractiveChip } from "#/ui/interactive-chip";
|
||||
import { usePermission } from "#/hooks/organizations/use-permissions";
|
||||
@@ -16,104 +11,10 @@ import { createPermissionGuard } from "#/utils/org/permission-guard";
|
||||
import { isBillingHidden } from "#/utils/org/billing-visibility";
|
||||
import { DeleteOrgConfirmationModal } from "#/components/features/org/delete-org-confirmation-modal";
|
||||
import { ChangeOrgNameModal } from "#/components/features/org/change-org-name-modal";
|
||||
import { AddCreditsModal } from "#/components/features/org/add-credits-modal";
|
||||
import { useBalance } from "#/hooks/query/use-balance";
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
interface AddCreditsModalProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
function AddCreditsModal({ onClose }: AddCreditsModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { mutate: addBalance } = useCreateStripeCheckoutSession();
|
||||
|
||||
const [inputValue, setInputValue] = React.useState("");
|
||||
const [errorMessage, setErrorMessage] = React.useState<string | null>(null);
|
||||
|
||||
const getErrorMessage = (value: string): string | null => {
|
||||
if (!value.trim()) return null;
|
||||
|
||||
const numValue = parseInt(value, 10);
|
||||
if (Number.isNaN(numValue)) {
|
||||
return t(I18nKey.PAYMENT$ERROR_INVALID_NUMBER);
|
||||
}
|
||||
if (numValue < 0) {
|
||||
return t(I18nKey.PAYMENT$ERROR_NEGATIVE_AMOUNT);
|
||||
}
|
||||
if (numValue < 10) {
|
||||
return t(I18nKey.PAYMENT$ERROR_MINIMUM_AMOUNT);
|
||||
}
|
||||
if (numValue > 25000) {
|
||||
return t(I18nKey.PAYMENT$ERROR_MAXIMUM_AMOUNT);
|
||||
}
|
||||
if (numValue !== parseFloat(value)) {
|
||||
return t(I18nKey.PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const formAction = (formData: FormData) => {
|
||||
const amount = formData.get("amount")?.toString();
|
||||
|
||||
if (amount?.trim()) {
|
||||
if (!amountIsValid(amount)) {
|
||||
const error = getErrorMessage(amount);
|
||||
setErrorMessage(error || "Invalid amount");
|
||||
return;
|
||||
}
|
||||
|
||||
const intValue = parseInt(amount, 10);
|
||||
|
||||
addBalance({ amount: intValue }, { onSuccess: onClose });
|
||||
|
||||
setErrorMessage(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAmountInputChange = (value: string) => {
|
||||
setInputValue(value);
|
||||
setErrorMessage(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<ModalBackdrop onClose={onClose}>
|
||||
<form
|
||||
data-testid="add-credits-form"
|
||||
action={formAction}
|
||||
noValidate
|
||||
className="w-sm rounded-xl bg-base-secondary flex flex-col p-6 gap-4 border border-tertiary"
|
||||
>
|
||||
<h3 className="text-xl font-bold">{t(I18nKey.ORG$ADD_CREDITS)}</h3>
|
||||
<div className="flex flex-col gap-2">
|
||||
<SettingsInput
|
||||
testId="amount-input"
|
||||
name="amount"
|
||||
label={t(I18nKey.PAYMENT$SPECIFY_AMOUNT_USD)}
|
||||
type="number"
|
||||
min={10}
|
||||
max={25000}
|
||||
step={1}
|
||||
value={inputValue}
|
||||
onChange={(value) => handleAmountInputChange(value)}
|
||||
className="w-full"
|
||||
/>
|
||||
{errorMessage && (
|
||||
<p className="text-red-500 text-sm mt-1" data-testid="amount-error">
|
||||
{errorMessage}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<ModalButtonGroup
|
||||
primaryText={t(I18nKey.ORG$NEXT)}
|
||||
onSecondaryClick={onClose}
|
||||
primaryType="submit"
|
||||
/>
|
||||
</form>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
}
|
||||
|
||||
export const clientLoader = createPermissionGuard("view_billing");
|
||||
|
||||
function ManageOrg() {
|
||||
|
||||
@@ -13,12 +13,14 @@ import { ConfirmationModal } from "#/components/shared/modals/confirmation-modal
|
||||
import { GetSecretsResponse } from "#/api/secrets-service.types";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { createPermissionGuard } from "#/utils/org/permission-guard";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
export const clientLoader = createPermissionGuard("manage_secrets");
|
||||
|
||||
function SecretsSettingsScreen() {
|
||||
const queryClient = useQueryClient();
|
||||
const { t } = useTranslation();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
|
||||
const { data: secrets, isLoading: isLoadingSecrets } = useGetSecrets();
|
||||
const { mutate: deleteSecret } = useDeleteSecret();
|
||||
@@ -34,7 +36,7 @@ function SecretsSettingsScreen() {
|
||||
|
||||
const deleteSecretOptimistically = (secret: string) => {
|
||||
queryClient.setQueryData<GetSecretsResponse["custom_secrets"]>(
|
||||
["secrets"],
|
||||
["secrets", organizationId],
|
||||
(oldSecrets) => {
|
||||
if (!oldSecrets) return [];
|
||||
return oldSecrets.filter((s) => s.name !== secret);
|
||||
@@ -43,7 +45,7 @@ function SecretsSettingsScreen() {
|
||||
};
|
||||
|
||||
const revertOptimisticUpdate = () => {
|
||||
queryClient.invalidateQueries({ queryKey: ["secrets"] });
|
||||
queryClient.invalidateQueries({ queryKey: ["secrets", organizationId] });
|
||||
};
|
||||
|
||||
const handleDeleteSecret = (secret: string) => {
|
||||
|
||||
@@ -30,7 +30,6 @@ const SAAS_ONLY_PATHS = [
|
||||
export const clientLoader = async ({ request }: Route.ClientLoaderArgs) => {
|
||||
const url = new URL(request.url);
|
||||
const { pathname } = url;
|
||||
console.log("clientLoader", { pathname });
|
||||
|
||||
// Step 1: Get config first (needed for all checks, no user data required)
|
||||
let config = queryClient.getQueryData<WebClientConfig>(["web-client-config"]);
|
||||
@@ -51,7 +50,6 @@ export const clientLoader = async ({ request }: Route.ClientLoaderArgs) => {
|
||||
// This handles hide_llm_settings, hide_users_page, hide_billing_page, hide_integrations_page
|
||||
if (isSettingsPageHidden(pathname, featureFlags)) {
|
||||
const fallbackPath = getFirstAvailablePath(isSaas, featureFlags);
|
||||
console.log("fallbackPath", fallbackPath);
|
||||
if (fallbackPath && fallbackPath !== pathname) {
|
||||
return redirect(fallbackPath);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { openHands } from "#/api/open-hands-axios";
|
||||
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
|
||||
import { useEmailVerification } from "#/hooks/use-email-verification";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
|
||||
// Email validation regex pattern
|
||||
const EMAIL_REGEX = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$/;
|
||||
@@ -113,6 +114,7 @@ function VerificationAlert() {
|
||||
function UserSettingsScreen() {
|
||||
const { t } = useTranslation();
|
||||
const { data: settings, isLoading, refetch } = useSettings();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
const [email, setEmail] = useState("");
|
||||
const [originalEmail, setOriginalEmail] = useState("");
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
@@ -144,7 +146,9 @@ function UserSettingsScreen() {
|
||||
// Display toast notification instead of setting state
|
||||
displaySuccessToast(t("SETTINGS$EMAIL_VERIFIED_SUCCESSFULLY"));
|
||||
setTimeout(() => {
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["settings", organizationId],
|
||||
});
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
@@ -162,7 +166,7 @@ function UserSettingsScreen() {
|
||||
pollingIntervalRef.current = null;
|
||||
}
|
||||
};
|
||||
}, [settings?.email_verified, refetch, queryClient, t]);
|
||||
}, [settings?.email_verified, refetch, queryClient, t, organizationId]);
|
||||
|
||||
const handleEmailChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const newEmail = e.target.value;
|
||||
@@ -178,7 +182,9 @@ function UserSettingsScreen() {
|
||||
setOriginalEmail(email);
|
||||
// Display toast notification instead of setting state
|
||||
displaySuccessToast(t("SETTINGS$EMAIL_SAVED_SUCCESSFULLY"));
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["settings", organizationId],
|
||||
});
|
||||
} catch (error) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.error(t("SETTINGS$FAILED_TO_SAVE_EMAIL"), error);
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state";
|
||||
import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { RUNTIME_STARTING_STATES } from "#/types/agent-state";
|
||||
import { VSCODE_IN_NEW_TAB } from "#/utils/feature-flags";
|
||||
import { WaitingForRuntimeMessage } from "#/components/features/chat/waiting-for-runtime-message";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
|
||||
function VSCodeTab() {
|
||||
const { t } = useTranslation();
|
||||
const { data, isLoading, error } = useUnifiedVSCodeUrl();
|
||||
const { curAgentState } = useAgentState();
|
||||
const isRuntimeInactive = RUNTIME_INACTIVE_STATES.includes(curAgentState);
|
||||
const isRuntimeStarting = RUNTIME_STARTING_STATES.includes(curAgentState);
|
||||
const iframeRef = React.useRef<HTMLIFrameElement>(null);
|
||||
const [isCrossProtocol, setIsCrossProtocol] = useState(false);
|
||||
const [iframeError, setIframeError] = useState<string | null>(null);
|
||||
@@ -39,7 +39,7 @@ function VSCodeTab() {
|
||||
}
|
||||
};
|
||||
|
||||
if (isRuntimeInactive) {
|
||||
if (isRuntimeStarting) {
|
||||
return <WaitingForRuntimeMessage />;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,9 +14,10 @@ export enum AgentState {
|
||||
USER_REJECTED = "user_rejected",
|
||||
}
|
||||
|
||||
export const RUNTIME_STARTING_STATES = [AgentState.INIT, AgentState.LOADING];
|
||||
|
||||
export const RUNTIME_INACTIVE_STATES = [
|
||||
AgentState.INIT,
|
||||
AgentState.LOADING,
|
||||
...RUNTIME_STARTING_STATES,
|
||||
// Removed AgentState.STOPPED to allow tabs to remain visible when agent is stopped
|
||||
AgentState.ERROR,
|
||||
];
|
||||
|
||||
@@ -1,17 +1,29 @@
|
||||
/**
|
||||
* Get the git repository path for a conversation
|
||||
* If a repository is selected, returns /workspace/project/{repo-name}
|
||||
* Otherwise, returns /workspace/project
|
||||
*
|
||||
* When sandbox grouping is enabled (strategy != NO_GROUPING), each conversation
|
||||
* gets its own subdirectory: /workspace/project/{conversationId}[/{repoName}]
|
||||
*
|
||||
* When sandbox grouping is disabled (NO_GROUPING), the path is simply:
|
||||
* /workspace/project[/{repoName}]
|
||||
*
|
||||
* @param conversationId The conversation ID
|
||||
* @param selectedRepository The selected repository (e.g., "OpenHands/OpenHands", "owner/repo", or "group/subgroup/repo")
|
||||
* @param useSandboxGrouping Whether sandbox grouping is enabled (strategy != NO_GROUPING)
|
||||
* @returns The git path to use
|
||||
*/
|
||||
export function getGitPath(
|
||||
conversationId: string,
|
||||
selectedRepository: string | null | undefined,
|
||||
useSandboxGrouping: boolean = false,
|
||||
): string {
|
||||
// Base path depends on sandbox grouping strategy
|
||||
const basePath = useSandboxGrouping
|
||||
? `/workspace/project/${conversationId}`
|
||||
: "/workspace/project";
|
||||
|
||||
if (!selectedRepository) {
|
||||
return `/workspace/project/${conversationId}`;
|
||||
return basePath;
|
||||
}
|
||||
|
||||
// Extract the repository name from the path
|
||||
@@ -19,5 +31,5 @@ export function getGitPath(
|
||||
const parts = selectedRepository.split("/");
|
||||
const repoName = parts[parts.length - 1];
|
||||
|
||||
return `/workspace/project/${conversationId}/${repoName}`;
|
||||
return `${basePath}/${repoName}`;
|
||||
}
|
||||
|
||||
@@ -838,7 +838,7 @@ interface GetStatusTextArgs {
|
||||
* isStartingStatus: false,
|
||||
* isStopStatus: false,
|
||||
* curAgentState: AgentState.RUNNING
|
||||
* }) // Returns "Waiting For Sandbox"
|
||||
* }) // Returns "Waiting for sandbox"
|
||||
*/
|
||||
export function getStatusText({
|
||||
isPausing = false,
|
||||
@@ -866,13 +866,13 @@ export function getStatusText({
|
||||
return t(I18nKey.CONVERSATION$READY);
|
||||
}
|
||||
|
||||
// Format status text: "WAITING_FOR_SANDBOX" -> "Waiting for sandbox"
|
||||
// Format status text with sentence case: "WAITING_FOR_SANDBOX" -> "Waiting for sandbox"
|
||||
return (
|
||||
taskDetail ||
|
||||
taskStatus
|
||||
.toLowerCase()
|
||||
.replace(/_/g, " ")
|
||||
.replace(/\b\w/g, (c) => c.toUpperCase())
|
||||
.replace(/^\w/, (c) => c.toUpperCase())
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,19 @@ export function extractBaseHost(
|
||||
if (conversationUrl && !conversationUrl.startsWith("/")) {
|
||||
try {
|
||||
const url = new URL(conversationUrl);
|
||||
// If the conversation URL points to localhost but we're accessing from external,
|
||||
// use the browser's hostname with the conversation URL's port
|
||||
const urlHostname = url.hostname;
|
||||
const browserHostname =
|
||||
window.location.hostname ?? window.location.host?.split(":")[0];
|
||||
if (
|
||||
browserHostname &&
|
||||
(urlHostname === "localhost" || urlHostname === "127.0.0.1") &&
|
||||
browserHostname !== "localhost" &&
|
||||
browserHostname !== "127.0.0.1"
|
||||
) {
|
||||
return `${browserHostname}:${url.port}`;
|
||||
}
|
||||
return url.host; // e.g., "localhost:3000"
|
||||
} catch {
|
||||
return window.location.host;
|
||||
|
||||
@@ -36,6 +36,15 @@ vi.mock("#/hooks/use-is-on-intermediate-page", () => ({
|
||||
useIsOnIntermediatePage: () => false,
|
||||
}));
|
||||
|
||||
// Mock useRevalidator from react-router to allow direct store manipulation
|
||||
// in tests instead of mocking useSelectedOrganizationId hook
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useRevalidator: () => ({
|
||||
revalidate: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
// Import the Zustand mock to enable automatic store resets
|
||||
vi.mock("zustand");
|
||||
|
||||
|
||||
@@ -84,6 +84,14 @@ class AppConversationInfoService(ABC):
|
||||
List of sub-conversation IDs
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def count_conversations_by_sandbox_id(self, sandbox_id: str) -> int:
|
||||
"""Count V1 conversations that reference the given sandbox.
|
||||
|
||||
Used to decide whether a sandbox can be safely deleted when a
|
||||
conversation is removed (only delete if count is 0).
|
||||
"""
|
||||
|
||||
# Mutators
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -115,7 +115,7 @@ async def _get_agent_server_context(
|
||||
app_conversation_service: AppConversationService,
|
||||
sandbox_service: SandboxService,
|
||||
sandbox_spec_service: SandboxSpecService,
|
||||
) -> AgentServerContext | JSONResponse:
|
||||
) -> AgentServerContext | JSONResponse | None:
|
||||
"""Get the agent server context for a conversation.
|
||||
|
||||
This helper retrieves all necessary information to communicate with the
|
||||
@@ -129,7 +129,8 @@ async def _get_agent_server_context(
|
||||
sandbox_spec_service: Service for sandbox spec operations
|
||||
|
||||
Returns:
|
||||
AgentServerContext if successful, or JSONResponse with error details.
|
||||
AgentServerContext if successful, JSONResponse(404) if conversation
|
||||
not found, or None if sandbox is not running (e.g. closed conversation).
|
||||
"""
|
||||
# Get the conversation info
|
||||
conversation = await app_conversation_service.get_app_conversation(conversation_id)
|
||||
@@ -141,12 +142,19 @@ async def _get_agent_server_context(
|
||||
|
||||
# Get the sandbox info
|
||||
sandbox = await sandbox_service.get_sandbox(conversation.sandbox_id)
|
||||
if not sandbox or sandbox.status != SandboxStatus.RUNNING:
|
||||
if not sandbox:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={
|
||||
'error': f'Sandbox not found or not running for conversation {conversation_id}'
|
||||
},
|
||||
content={'error': f'Sandbox not found for conversation {conversation_id}'},
|
||||
)
|
||||
# Return None for paused sandboxes (closed conversation)
|
||||
if sandbox.status == SandboxStatus.PAUSED:
|
||||
return None
|
||||
# Return 404 for other non-running states (STARTING, ERROR, MISSING)
|
||||
if sandbox.status != SandboxStatus.RUNNING:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': f'Sandbox not ready for conversation {conversation_id}'},
|
||||
)
|
||||
|
||||
# Get the sandbox spec to find the working directory
|
||||
@@ -226,7 +234,7 @@ async def search_app_conversations(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
le=100,
|
||||
),
|
||||
] = 100,
|
||||
include_sub_conversations: Annotated[
|
||||
@@ -240,8 +248,6 @@ async def search_app_conversations(
|
||||
),
|
||||
) -> AppConversationPage:
|
||||
"""Search / List sandboxed conversations."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await app_conversation_service.search_app_conversations(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
@@ -414,7 +420,7 @@ async def search_app_conversation_start_tasks(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
le=100,
|
||||
),
|
||||
] = 100,
|
||||
app_conversation_start_task_service: AppConversationStartTaskService = (
|
||||
@@ -422,8 +428,6 @@ async def search_app_conversation_start_tasks(
|
||||
),
|
||||
) -> AppConversationStartTaskPage:
|
||||
"""Search / List conversation start tasks."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return (
|
||||
await app_conversation_start_task_service.search_app_conversation_start_tasks(
|
||||
conversation_id__eq=conversation_id__eq,
|
||||
@@ -464,7 +468,11 @@ async def batch_get_app_conversation_start_tasks(
|
||||
),
|
||||
) -> list[AppConversationStartTask | None]:
|
||||
"""Get a batch of start app conversation tasks given their ids. Return None for any missing."""
|
||||
assert len(ids) < 100
|
||||
if len(ids) > 100:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Cannot request more than 100 start tasks at once, got {len(ids)}',
|
||||
)
|
||||
start_tasks = await app_conversation_start_task_service.batch_get_app_conversation_start_tasks(
|
||||
ids
|
||||
)
|
||||
@@ -587,6 +595,7 @@ async def get_conversation_skills(
|
||||
|
||||
Returns:
|
||||
JSONResponse: A JSON response containing the list of skills.
|
||||
Returns an empty list if the sandbox is not running.
|
||||
"""
|
||||
try:
|
||||
# Get agent server context (conversation, sandbox, sandbox_spec, agent_server_url)
|
||||
@@ -598,6 +607,8 @@ async def get_conversation_skills(
|
||||
)
|
||||
if isinstance(ctx, JSONResponse):
|
||||
return ctx
|
||||
if ctx is None:
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={'skills': []})
|
||||
|
||||
# Load skills from all sources
|
||||
logger.info(f'Loading skills for conversation {conversation_id}')
|
||||
@@ -685,6 +696,7 @@ async def get_conversation_hooks(
|
||||
|
||||
Returns:
|
||||
JSONResponse: A JSON response containing the list of hook event types.
|
||||
Returns an empty list if the sandbox is not running.
|
||||
"""
|
||||
try:
|
||||
# Get agent server context (conversation, sandbox, sandbox_spec, agent_server_url)
|
||||
@@ -696,6 +708,8 @@ async def get_conversation_hooks(
|
||||
)
|
||||
if isinstance(ctx, JSONResponse):
|
||||
return ctx
|
||||
if ctx is None:
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={'hooks': []})
|
||||
|
||||
from openhands.app_server.app_conversation.hook_loader import (
|
||||
fetch_hooks_from_agent_server,
|
||||
|
||||
@@ -77,8 +77,20 @@ class AppConversationService(ABC):
|
||||
id, starting a conversation, attaching a callback, and then running the
|
||||
conversation.
|
||||
|
||||
Yields an instance of AppConversationStartTask as updates occur, which can be used to determine
|
||||
the progress of the task.
|
||||
This method returns an async iterator that yields the same
|
||||
AppConversationStartTask repeatedly as status updates occur. Callers
|
||||
should iterate until the task reaches a terminal status::
|
||||
|
||||
async for task in service.start_app_conversation(request):
|
||||
if task.status in (
|
||||
AppConversationStartTaskStatus.READY,
|
||||
AppConversationStartTaskStatus.ERROR,
|
||||
):
|
||||
break
|
||||
|
||||
Status progression: WORKING → WAITING_FOR_SANDBOX → PREPARING_REPOSITORY
|
||||
→ RUNNING_SETUP_SCRIPT → SETTING_UP_GIT_HOOKS → SETTING_UP_SKILLS
|
||||
→ STARTING_CONVERSATION → READY (or ERROR at any point).
|
||||
"""
|
||||
# This is an abstract method - concrete implementations should provide real values
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
@@ -111,15 +123,21 @@ class AppConversationService(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||
async def delete_app_conversation(
|
||||
self, conversation_id: UUID, skip_agent_server_delete: bool = False
|
||||
) -> bool:
|
||||
"""Delete a V1 conversation and all its associated data.
|
||||
|
||||
Args:
|
||||
conversation_id: The UUID of the conversation to delete.
|
||||
skip_agent_server_delete: If True, skip the agent server DELETE call.
|
||||
This should be set when the sandbox is shared with other
|
||||
conversations (e.g. created via /new) to avoid destabilizing
|
||||
the shared runtime.
|
||||
|
||||
This method should:
|
||||
1. Delete the conversation from the database
|
||||
2. Call the agent server to delete the conversation
|
||||
2. Call the agent server to delete the conversation (unless skipped)
|
||||
3. Clean up any related data
|
||||
|
||||
Returns True if the conversation was deleted successfully, False otherwise.
|
||||
|
||||
@@ -1740,13 +1740,19 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
conversations = await self._build_app_conversations([info])
|
||||
return conversations[0]
|
||||
|
||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||
async def delete_app_conversation(
|
||||
self, conversation_id: UUID, skip_agent_server_delete: bool = False
|
||||
) -> bool:
|
||||
"""Delete a V1 conversation and all its associated data.
|
||||
|
||||
This method will also cascade delete all sub-conversations of the parent.
|
||||
|
||||
Args:
|
||||
conversation_id: The UUID of the conversation to delete.
|
||||
skip_agent_server_delete: If True, skip the agent server DELETE call.
|
||||
This should be set when the sandbox is shared with other
|
||||
conversations (e.g. created via /new) to avoid destabilizing
|
||||
the shared runtime.
|
||||
"""
|
||||
# Check if we have the required SQL implementation for transactional deletion
|
||||
if not isinstance(
|
||||
@@ -1772,8 +1778,9 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
await self._delete_sub_conversations(conversation_id)
|
||||
|
||||
# Now delete the parent conversation
|
||||
# Delete from agent server if sandbox is running
|
||||
await self._delete_from_agent_server(app_conversation)
|
||||
# Delete from agent server if sandbox is running (skip if sandbox is shared)
|
||||
if not skip_agent_server_delete:
|
||||
await self._delete_from_agent_server(app_conversation)
|
||||
|
||||
# Delete from database using the conversation info from app_conversation
|
||||
# AppConversation extends AppConversationInfo, so we can use it directly
|
||||
|
||||
@@ -278,6 +278,14 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
rows = result_set.scalars().all()
|
||||
return [UUID(row.conversation_id) for row in rows]
|
||||
|
||||
async def count_conversations_by_sandbox_id(self, sandbox_id: str) -> int:
|
||||
query = await self._secure_select()
|
||||
query = query.where(StoredConversationMetadata.sandbox_id == sandbox_id)
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
result = await self.db_session.execute(count_query)
|
||||
count = result.scalar()
|
||||
return count or 0
|
||||
|
||||
async def get_app_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> AppConversationInfo | None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user