mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Merge main into openhands/fix-saas-page-title to resolve conflicts
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
commit
a526d2e03d
47
.github/workflows/fe-e2e-tests.yml
vendored
Normal file
47
.github/workflows/fe-e2e-tests.yml
vendored
Normal file
@ -0,0 +1,47 @@
|
||||
# Workflow that runs frontend e2e tests with Playwright
|
||||
name: Run Frontend E2E Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "frontend/**"
|
||||
- ".github/workflows/fe-e2e-tests.yml"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
fe-e2e-test:
|
||||
name: FE E2E Tests
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [22]
|
||||
fail-fast: true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Node.js
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
- name: Install dependencies
|
||||
working-directory: ./frontend
|
||||
run: npm ci
|
||||
- name: Install Playwright browsers
|
||||
working-directory: ./frontend
|
||||
run: npx playwright install --with-deps chromium
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./frontend
|
||||
run: npx playwright test --project=chromium
|
||||
- name: Upload Playwright report
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: frontend/playwright-report/
|
||||
retention-days: 30
|
||||
@ -63,7 +63,7 @@ Frontend:
|
||||
- We use TanStack Query (fka React Query) for data fetching and cache management
|
||||
- Data Access Layer: API client methods are located in `frontend/src/api` and should never be called directly from UI components - they must always be wrapped with TanStack Query
|
||||
- Custom hooks are located in `frontend/src/hooks/query/` and `frontend/src/hooks/mutation/`
|
||||
- Query hooks should follow the pattern use[Resource] (e.g., `useConversationMicroagents`)
|
||||
- Query hooks should follow the pattern use[Resource] (e.g., `useConversationSkills`)
|
||||
- Mutation hooks should follow the pattern use[Action] (e.g., `useDeleteConversation`)
|
||||
- Architecture rule: UI components → TanStack Query hooks → Data Access Layer (`frontend/src/api`) → API endpoints
|
||||
|
||||
|
||||
@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:0.62-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.0-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
|
||||
<div align="center">
|
||||
<a href="https://github.com/OpenHands/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/badge/LICENSE-MIT-20B2AA?style=for-the-badge" alt="MIT License"></a>
|
||||
<a href="https://docs.google.com/spreadsheets/d/1wOUdFCMyY6Nt0AIqF705KN4JKOWgeI4wUGUP60krXXs/edit?gid=811504672#gid=811504672"><img src="https://img.shields.io/badge/SWEBench-72.8-00cc00?logoColor=FFE165&style=for-the-badge" alt="Benchmark Score"></a>
|
||||
<a href="https://docs.google.com/spreadsheets/d/1wOUdFCMyY6Nt0AIqF705KN4JKOWgeI4wUGUP60krXXs/edit?gid=811504672#gid=811504672"><img src="https://img.shields.io/badge/SWEBench-77.6-00cc00?logoColor=FFE165&style=for-the-badge" alt="Benchmark Score"></a>
|
||||
<br/>
|
||||
<a href="https://docs.openhands.dev/sdk"><img src="https://img.shields.io/badge/Documentation-000?logo=googledocs&logoColor=FFE165&style=for-the-badge" alt="Check out the documentation"></a>
|
||||
<a href="https://arxiv.org/abs/2511.03690"><img src="https://img.shields.io/badge/Paper-000?logoColor=FFE165&logo=arxiv&style=for-the-badge" alt="Tech Report"></a>
|
||||
|
||||
@ -12,7 +12,7 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:0.62-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.0-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@ -7,7 +7,7 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:0.62-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.0-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@ -721,6 +721,7 @@
|
||||
"https://$WEB_HOST/oauth/keycloak/callback",
|
||||
"https://$WEB_HOST/oauth/keycloak/offline/callback",
|
||||
"https://$WEB_HOST/slack/keycloak-callback",
|
||||
"https://$WEB_HOST/oauth/device/keycloak-callback",
|
||||
"https://$WEB_HOST/api/email/verified",
|
||||
"/realms/$KEYCLOAK_REALM_NAME/$KEYCLOAK_CLIENT_ID/*"
|
||||
],
|
||||
|
||||
@ -13,6 +13,7 @@ from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
ENABLE_V1_GITHUB_RESOLVER,
|
||||
HOST,
|
||||
HOST_URL,
|
||||
get_oh_labels,
|
||||
@ -95,7 +96,15 @@ async def get_user_v1_enabled_setting(user_id: str) -> bool:
|
||||
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
|
||||
Note:
|
||||
This function checks both the global environment variable kill switch AND
|
||||
the user's individual setting. Both must be true for the function to return true.
|
||||
"""
|
||||
# Check the global environment variable first
|
||||
if not ENABLE_V1_GITHUB_RESOLVER:
|
||||
return False
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
@ -178,6 +187,19 @@ class GithubIssue(ResolverViewInterface):
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
# Create dummy conversationm metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
return ConversationMetadata(
|
||||
conversation_id=uuid4().hex, selected_repository=self.full_repo_name
|
||||
)
|
||||
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
@ -223,7 +245,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
logger.info('[GitHub V1]: Creating V0 conversation')
|
||||
logger.info('[GitHub]: Creating V0 conversation')
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
@ -369,7 +391,18 @@ class GithubPRComment(GithubIssueComment):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
# Create dummy conversationm metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
return ConversationMetadata(
|
||||
conversation_id=uuid4().hex, selected_repository=self.full_repo_name
|
||||
)
|
||||
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
|
||||
@ -2,6 +2,7 @@ from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
@ -44,11 +45,18 @@ class ResolverUserContext(UserContext):
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
return await self.saas_user_auth.get_provider_tokens()
|
||||
|
||||
async def get_secrets(self) -> dict[str, str]:
|
||||
async def get_secrets(self) -> dict[str, SecretSource]:
|
||||
"""Get secrets for the user, including custom secrets."""
|
||||
secrets = await self.saas_user_auth.get_secrets()
|
||||
if secrets:
|
||||
return dict(secrets.custom_secrets)
|
||||
# Convert custom secrets to StaticSecret objects for SDK compatibility
|
||||
# secrets.custom_secrets is of type Mapping[str, CustomSecret]
|
||||
converted_secrets = {}
|
||||
for key, custom_secret in secrets.custom_secrets.items():
|
||||
# Extract the secret value from CustomSecret and convert to StaticSecret
|
||||
secret_value = custom_secret.secret.get_secret_value()
|
||||
converted_secrets[key] = StaticSecret(value=secret_value)
|
||||
return converted_secrets
|
||||
return {}
|
||||
|
||||
async def get_mcp_api_key(self) -> str | None:
|
||||
|
||||
@ -51,6 +51,11 @@ ENABLE_SOLVABILITY_ANALYSIS = (
|
||||
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Toggle for V1 GitHub resolver feature
|
||||
ENABLE_V1_GITHUB_RESOLVER = (
|
||||
os.getenv('ENABLE_V1_GITHUB_RESOLVER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR = 'openhands/integrations/templates/resolver/'
|
||||
jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
|
||||
|
||||
@ -0,0 +1,49 @@
|
||||
"""Create device_codes table for OAuth 2.0 Device Flow
|
||||
|
||||
Revision ID: 084
|
||||
Revises: 083
|
||||
Create Date: 2024-12-10 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '084'
|
||||
down_revision = '083'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Create device_codes table for OAuth 2.0 Device Flow."""
|
||||
op.create_table(
|
||||
'device_codes',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('device_code', sa.String(length=128), nullable=False),
|
||||
sa.Column('user_code', sa.String(length=16), nullable=False),
|
||||
sa.Column('status', sa.String(length=32), nullable=False),
|
||||
sa.Column('keycloak_user_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('authorized_at', sa.DateTime(timezone=True), nullable=True),
|
||||
# Rate limiting fields for RFC 8628 section 3.5 compliance
|
||||
sa.Column('last_poll_time', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('current_interval', sa.Integer(), nullable=False, default=5),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
# Create indexes for efficient lookups
|
||||
op.create_index(
|
||||
'ix_device_codes_device_code', 'device_codes', ['device_code'], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
'ix_device_codes_user_code', 'device_codes', ['user_code'], unique=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Drop device_codes table."""
|
||||
op.drop_index('ix_device_codes_user_code', table_name='device_codes')
|
||||
op.drop_index('ix_device_codes_device_code', table_name='device_codes')
|
||||
op.drop_table('device_codes')
|
||||
39
enterprise/poetry.lock
generated
39
enterprise/poetry.lock
generated
@ -4624,14 +4624,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "lmnr"
|
||||
version = "0.7.20"
|
||||
version = "0.7.24"
|
||||
description = "Python SDK for Laminar"
|
||||
optional = false
|
||||
python-versions = "<4,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "lmnr-0.7.20-py3-none-any.whl", hash = "sha256:5f9fa7444e6f96c25e097f66484ff29e632bdd1de0e9346948bf5595f4a8af38"},
|
||||
{file = "lmnr-0.7.20.tar.gz", hash = "sha256:1f484cd618db2d71af65f90a0b8b36d20d80dc91a5138b811575c8677bf7c4fd"},
|
||||
{file = "lmnr-0.7.24-py3-none-any.whl", hash = "sha256:ad780d4a62ece897048811f3368639c240a9329ab31027da8c96545137a3a08a"},
|
||||
{file = "lmnr-0.7.24.tar.gz", hash = "sha256:aa6973f46fc4ba95c9061c1feceb58afc02eb43c9376c21e32545371ff6123d7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -4654,14 +4654,15 @@ tqdm = ">=4.0"
|
||||
|
||||
[package.extras]
|
||||
alephalpha = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)"]
|
||||
all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"]
|
||||
all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"]
|
||||
bedrock = ["opentelemetry-instrumentation-bedrock (>=0.47.1)"]
|
||||
chromadb = ["opentelemetry-instrumentation-chromadb (>=0.47.1)"]
|
||||
claude-agent-sdk = ["lmnr-claude-code-proxy (>=0.1.0a5)"]
|
||||
cohere = ["opentelemetry-instrumentation-cohere (>=0.47.1)"]
|
||||
crewai = ["opentelemetry-instrumentation-crewai (>=0.47.1)"]
|
||||
haystack = ["opentelemetry-instrumentation-haystack (>=0.47.1)"]
|
||||
lancedb = ["opentelemetry-instrumentation-lancedb (>=0.47.1)"]
|
||||
langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1)"]
|
||||
langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)"]
|
||||
llamaindex = ["opentelemetry-instrumentation-llamaindex (>=0.47.1)"]
|
||||
marqo = ["opentelemetry-instrumentation-marqo (>=0.47.1)"]
|
||||
mcp = ["opentelemetry-instrumentation-mcp (>=0.47.1)"]
|
||||
@ -5835,14 +5836,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.4.1"
|
||||
version = "1.6.0"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.4.1-py3-none-any.whl", hash = "sha256:1e621d15215a48e2398e23c58a791347f06c215c2344053aeb26b562c34a44ee"},
|
||||
{file = "openhands_agent_server-1.4.1.tar.gz", hash = "sha256:03010a5c8d63bbd5b088458eb75308ef16559018140d75a3644ae5bbc3531bbf"},
|
||||
{file = "openhands_agent_server-1.6.0-py3-none-any.whl", hash = "sha256:e6ae865ac3e7a96b234e10a0faad23f6210e025bbf7721cb66bc7a71d160848c"},
|
||||
{file = "openhands_agent_server-1.6.0.tar.gz", hash = "sha256:44ce7694ae2d4bb0666d318ef13e6618bd4dc73022c60354839fe6130e67d02a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -5859,7 +5860,7 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "0.0.0-post.5625+0a98f165e"
|
||||
version = "0.0.0-post.5687+7853b41ad"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
@ -5901,9 +5902,9 @@ memory-profiler = "^0.61.0"
|
||||
numpy = "*"
|
||||
openai = "2.8.0"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.4.1"
|
||||
openhands-sdk = "1.4.1"
|
||||
openhands-tools = "1.4.1"
|
||||
openhands-agent-server = "1.6.0"
|
||||
openhands-sdk = "1.6.0"
|
||||
openhands-tools = "1.6.0"
|
||||
opentelemetry-api = "^1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = "^1.33.1"
|
||||
pathspec = "^0.12.1"
|
||||
@ -5959,14 +5960,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.4.1"
|
||||
version = "1.6.0"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.4.1-py3-none-any.whl", hash = "sha256:70e453eab7f9ab6b705198c2615fdd844b21e14b29d78afaf62724f4a440bcdc"},
|
||||
{file = "openhands_sdk-1.4.1.tar.gz", hash = "sha256:37365de25ed57cf8cc2a8003ab4d7a1fe2a40b49c8e8da84a3f1ea2b522eddf2"},
|
||||
{file = "openhands_sdk-1.6.0-py3-none-any.whl", hash = "sha256:94d2f87fb35406373da6728ae2d88584137f9e9b67fa0e940444c72f2e44e7d3"},
|
||||
{file = "openhands_sdk-1.6.0.tar.gz", hash = "sha256:f45742350e3874a7f5b08befc4a9d5adc7e4454f7ab5f8391c519eee3116090f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -5974,7 +5975,7 @@ deprecation = ">=2.1.0"
|
||||
fastmcp = ">=2.11.3"
|
||||
httpx = ">=0.27.0"
|
||||
litellm = ">=1.80.7"
|
||||
lmnr = ">=0.7.20"
|
||||
lmnr = ">=0.7.24"
|
||||
pydantic = ">=2.11.7"
|
||||
python-frontmatter = ">=1.1.0"
|
||||
python-json-logger = ">=3.3.0"
|
||||
@ -5986,14 +5987,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.4.1"
|
||||
version = "1.6.0"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.4.1-py3-none-any.whl", hash = "sha256:8f40189a08bf80eb4a33219ee9ccc528f9c6c4f2d5c9ab807b06c3f3fe21a612"},
|
||||
{file = "openhands_tools-1.4.1.tar.gz", hash = "sha256:4c0caf87f520a207d9035191c77b7b5c53eeec996350a24ffaf7f740a6566b22"},
|
||||
{file = "openhands_tools-1.6.0-py3-none-any.whl", hash = "sha256:176556d44186536751b23fe052d3505492cc2afb8d52db20fb7a2cc0169cd57a"},
|
||||
{file = "openhands_tools-1.6.0.tar.gz", hash = "sha256:d07ba31050fd4a7891a4c48388aa53ce9f703e17064ddbd59146d6c77e5980b3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@ -34,6 +34,7 @@ from server.routes.integration.jira_dc import jira_dc_integration_router # noqa
|
||||
from server.routes.integration.linear import linear_integration_router # noqa: E402
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
|
||||
@ -60,6 +61,7 @@ base_app.mount('/internal/metrics', metrics_app())
|
||||
base_app.include_router(readiness_router) # Add routes for readiness checks
|
||||
base_app.include_router(api_router) # Add additional route for github auth
|
||||
base_app.include_router(oauth_router) # Add additional route for oauth callback
|
||||
base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes
|
||||
base_app.include_router(saas_user_router) # Add additional route SAAS user calls
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
|
||||
@ -1,331 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import socketio
|
||||
from server.clustered_conversation_manager import ClusteredConversationManager
|
||||
from server.saas_nested_conversation_manager import SaasNestedConversationManager
|
||||
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import wait_all
|
||||
|
||||
_LEGACY_ENTRY_TIMEOUT_SECONDS = 3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegacyCacheEntry:
|
||||
"""Cache entry for legacy mode status."""
|
||||
|
||||
is_legacy: bool
|
||||
timestamp: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegacyConversationManager(ConversationManager):
|
||||
"""
|
||||
Conversation manager for use while migrating - since existing conversations are not nested!
|
||||
Separate class from SaasNestedConversationManager so it can be easliy removed in a few weeks.
|
||||
(As of 2025-07-23)
|
||||
"""
|
||||
|
||||
sio: socketio.AsyncServer
|
||||
config: OpenHandsConfig
|
||||
server_config: ServerConfig
|
||||
file_store: FileStore
|
||||
conversation_manager: SaasNestedConversationManager
|
||||
legacy_conversation_manager: ClusteredConversationManager
|
||||
_legacy_cache: dict[str, LegacyCacheEntry] = field(default_factory=dict)
|
||||
|
||||
async def __aenter__(self):
|
||||
await wait_all(
|
||||
[
|
||||
self.conversation_manager.__aenter__(),
|
||||
self.legacy_conversation_manager.__aenter__(),
|
||||
]
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await wait_all(
|
||||
[
|
||||
self.conversation_manager.__aexit__(exc_type, exc_value, traceback),
|
||||
self.legacy_conversation_manager.__aexit__(
|
||||
exc_type, exc_value, traceback
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def request_llm_completion(
|
||||
self,
|
||||
sid: str,
|
||||
service_id: str,
|
||||
llm_config: LLMConfig,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
session = self.get_agent_session(sid)
|
||||
llm_registry = session.llm_registry
|
||||
return llm_registry.request_extraneous_completion(
|
||||
service_id, llm_config, messages
|
||||
)
|
||||
|
||||
async def attach_to_conversation(
|
||||
self, sid: str, user_id: str | None = None
|
||||
) -> ServerConversation | None:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.attach_to_conversation(
|
||||
sid, user_id
|
||||
)
|
||||
return await self.conversation_manager.attach_to_conversation(sid, user_id)
|
||||
|
||||
async def detach_from_conversation(self, conversation: ServerConversation):
|
||||
if await self.should_start_in_legacy_mode(conversation.sid):
|
||||
return await self.legacy_conversation_manager.detach_from_conversation(
|
||||
conversation
|
||||
)
|
||||
return await self.conversation_manager.detach_from_conversation(conversation)
|
||||
|
||||
async def join_conversation(
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
) -> AgentLoopInfo:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.join_conversation(
|
||||
sid, connection_id, settings, user_id
|
||||
)
|
||||
return await self.conversation_manager.join_conversation(
|
||||
sid, connection_id, settings, user_id
|
||||
)
|
||||
|
||||
def get_agent_session(self, sid: str):
|
||||
session = self.legacy_conversation_manager.get_agent_session(sid)
|
||||
if session is None:
|
||||
session = self.conversation_manager.get_agent_session(sid)
|
||||
return session
|
||||
|
||||
async def get_running_agent_loops(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> set[str]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
|
||||
# Get all running agent loops from both managers
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
self.legacy_conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Combine the results
|
||||
result = set()
|
||||
for sid in legacy_agent_loops:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
result.add(sid)
|
||||
|
||||
for sid in agent_loops:
|
||||
if not await self.should_start_in_legacy_mode(sid):
|
||||
result.add(sid)
|
||||
|
||||
return result
|
||||
|
||||
async def is_agent_loop_running(self, sid: str) -> bool:
|
||||
return bool(await self.get_running_agent_loops(filter_to_sids={sid}))
|
||||
|
||||
async def get_connections(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> dict[str, str]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_connections(user_id, filter_to_sids),
|
||||
self.legacy_conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
legacy_agent_loops.update(agent_loops)
|
||||
return legacy_agent_loops
|
||||
|
||||
async def maybe_start_agent_loop(
|
||||
self,
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
user_id: str, # type: ignore[override]
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
) -> AgentLoopInfo:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.maybe_start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
return await self.conversation_manager.maybe_start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
|
||||
async def send_to_event_stream(self, connection_id: str, data: dict):
|
||||
return await self.legacy_conversation_manager.send_to_event_stream(
|
||||
connection_id, data
|
||||
)
|
||||
|
||||
async def send_event_to_conversation(self, sid: str, data: dict):
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
await self.legacy_conversation_manager.send_event_to_conversation(sid, data)
|
||||
await self.conversation_manager.send_event_to_conversation(sid, data)
|
||||
|
||||
async def disconnect_from_session(self, connection_id: str):
|
||||
return await self.legacy_conversation_manager.disconnect_from_session(
|
||||
connection_id
|
||||
)
|
||||
|
||||
async def close_session(self, sid: str):
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
await self.legacy_conversation_manager.close_session(sid)
|
||||
await self.conversation_manager.close_session(sid)
|
||||
|
||||
async def get_agent_loop_info(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> list[AgentLoopInfo]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_agent_loop_info(user_id, filter_to_sids),
|
||||
self.legacy_conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Combine results
|
||||
result = []
|
||||
legacy_sids = set()
|
||||
|
||||
# Add legacy agent loops
|
||||
for agent_loop in legacy_agent_loops:
|
||||
if await self.should_start_in_legacy_mode(agent_loop.conversation_id):
|
||||
result.append(agent_loop)
|
||||
legacy_sids.add(agent_loop.conversation_id)
|
||||
|
||||
# Add non-legacy agent loops
|
||||
for agent_loop in agent_loops:
|
||||
if (
|
||||
agent_loop.conversation_id not in legacy_sids
|
||||
and not await self.should_start_in_legacy_mode(
|
||||
agent_loop.conversation_id
|
||||
)
|
||||
):
|
||||
result.append(agent_loop)
|
||||
|
||||
return result
|
||||
|
||||
def _cleanup_expired_cache_entries(self):
|
||||
"""Remove expired entries from the local cache."""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key
|
||||
for key, entry in self._legacy_cache.items()
|
||||
if current_time - entry.timestamp > _LEGACY_ENTRY_TIMEOUT_SECONDS
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._legacy_cache[key]
|
||||
|
||||
async def should_start_in_legacy_mode(self, conversation_id: str) -> bool:
|
||||
"""
|
||||
Check if a conversation should run in legacy mode by directly checking the runtime.
|
||||
The /list method does not include stopped conversations even though the PVC for these
|
||||
may not yet have been deleted, so we need to check /sessions/{session_id} directly.
|
||||
"""
|
||||
# Clean up expired entries periodically
|
||||
self._cleanup_expired_cache_entries()
|
||||
|
||||
# First check the local cache
|
||||
if conversation_id in self._legacy_cache:
|
||||
cached_entry = self._legacy_cache[conversation_id]
|
||||
# Check if the cached value is still valid
|
||||
if time.time() - cached_entry.timestamp <= _LEGACY_ENTRY_TIMEOUT_SECONDS:
|
||||
return cached_entry.is_legacy
|
||||
|
||||
# If not in cache or expired, check the runtime directly
|
||||
runtime = await self.conversation_manager._get_runtime(conversation_id)
|
||||
is_legacy = self.is_legacy_runtime(runtime)
|
||||
|
||||
# Cache the result with current timestamp
|
||||
self._legacy_cache[conversation_id] = LegacyCacheEntry(is_legacy, time.time())
|
||||
|
||||
return is_legacy
|
||||
|
||||
def is_legacy_runtime(self, runtime: dict | None) -> bool:
|
||||
"""
|
||||
Determine if a runtime is a legacy runtime based on its command.
|
||||
|
||||
Args:
|
||||
runtime: The runtime dictionary or None if not found
|
||||
|
||||
Returns:
|
||||
bool: True if this is a legacy runtime, False otherwise
|
||||
"""
|
||||
if runtime is None:
|
||||
return False
|
||||
return 'openhands.server' not in runtime['command']
|
||||
|
||||
@classmethod
|
||||
def get_instance(
|
||||
cls,
|
||||
sio: socketio.AsyncServer,
|
||||
config: OpenHandsConfig,
|
||||
file_store: FileStore,
|
||||
server_config: ServerConfig,
|
||||
monitoring_listener: MonitoringListener,
|
||||
) -> ConversationManager:
|
||||
return LegacyConversationManager(
|
||||
sio=sio,
|
||||
config=config,
|
||||
server_config=server_config,
|
||||
file_store=file_store,
|
||||
conversation_manager=SaasNestedConversationManager.get_instance(
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
),
|
||||
legacy_conversation_manager=ClusteredConversationManager.get_instance(
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
),
|
||||
)
|
||||
@ -152,17 +152,22 @@ class SetAuthCookieMiddleware:
|
||||
return False
|
||||
path = request.url.path
|
||||
|
||||
is_api_that_should_attach = path.startswith('/api') and path not in (
|
||||
ignore_paths = (
|
||||
'/api/options/config',
|
||||
'/api/keycloak/callback',
|
||||
'/api/billing/success',
|
||||
'/api/billing/cancel',
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
return is_api_that_should_attach or is_mcp
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
async def _logout(self, request: Request):
|
||||
# Log out of keycloak - this prevents issues where you did not log in with the idp you believe you used
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
@ -58,7 +59,8 @@ async def github_events(
|
||||
)
|
||||
|
||||
try:
|
||||
payload = await request.body()
|
||||
# Add timeout to prevent hanging on slow/stalled clients
|
||||
payload = await asyncio.wait_for(request.body(), timeout=15.0)
|
||||
verify_github_signature(payload, x_hub_signature_256)
|
||||
|
||||
payload_data = await request.json()
|
||||
@ -78,6 +80,12 @@ async def github_events(
|
||||
status_code=200,
|
||||
content={'message': 'GitHub events endpoint reached successfully.'},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('GitHub webhook request timed out waiting for request body')
|
||||
return JSONResponse(
|
||||
status_code=408,
|
||||
content={'error': 'Request timeout - client took too long to send data.'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error processing GitHub event: {e}')
|
||||
return JSONResponse(status_code=400, content={'error': 'Invalid payload.'})
|
||||
|
||||
324
enterprise/server/routes/oauth_device.py
Normal file
324
enterprise/server/routes/oauth_device.py
Normal file
@ -0,0 +1,324 @@
|
||||
"""OAuth 2.0 Device Flow endpoints for CLI authentication."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes
|
||||
DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds
|
||||
|
||||
API_KEY_NAME = 'Device Link Access Key'
|
||||
KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeviceAuthorizationResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
verification_uri_complete: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceTokenResponse(BaseModel):
|
||||
access_token: str # This will be the user's API key
|
||||
token_type: str = 'Bearer'
|
||||
expires_in: Optional[int] = None # API keys may not have expiration
|
||||
|
||||
|
||||
class DeviceTokenErrorResponse(BaseModel):
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
interval: Optional[int] = None # Required for slow_down error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router + stores
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth_device_router = APIRouter(prefix='/oauth/device')
|
||||
device_code_store = DeviceCodeStore(session_maker)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _oauth_error(
|
||||
status_code: int,
|
||||
error: str,
|
||||
description: str,
|
||||
interval: Optional[int] = None,
|
||||
) -> JSONResponse:
|
||||
"""Return a JSON OAuth-style error response."""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=DeviceTokenErrorResponse(
|
||||
error=error,
|
||||
error_description=description,
|
||||
interval=interval,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@oauth_device_router.post('/authorize', response_model=DeviceAuthorizationResponse)
|
||||
async def device_authorization(
|
||||
http_request: Request,
|
||||
) -> DeviceAuthorizationResponse:
|
||||
"""Start device flow by generating device and user codes."""
|
||||
try:
|
||||
device_code_entry = device_code_store.create_device_code(
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
)
|
||||
|
||||
base_url = str(http_request.base_url).rstrip('/')
|
||||
verification_uri = f'{base_url}/oauth/device/verify'
|
||||
verification_uri_complete = (
|
||||
f'{verification_uri}?user_code={device_code_entry.user_code}'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Device authorization initiated',
|
||||
extra={'user_code': device_code_entry.user_code},
|
||||
)
|
||||
|
||||
return DeviceAuthorizationResponse(
|
||||
device_code=device_code_entry.device_code,
|
||||
user_code=device_code_entry.user_code,
|
||||
verification_uri=verification_uri,
|
||||
verification_uri_complete=verification_uri_complete,
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
interval=device_code_entry.current_interval,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('Error in device authorization: %s', str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Internal server error',
|
||||
) from e
|
||||
|
||||
|
||||
@oauth_device_router.post('/token')
|
||||
async def device_token(device_code: str = Form(...)):
|
||||
"""Poll for a token until the user authorizes or the code expires."""
|
||||
try:
|
||||
device_code_entry = device_code_store.get_by_device_code(device_code)
|
||||
|
||||
if not device_code_entry:
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'invalid_grant',
|
||||
'Invalid device code',
|
||||
)
|
||||
|
||||
# Check rate limiting (RFC 8628 section 3.5)
|
||||
is_too_fast, current_interval = device_code_entry.check_rate_limit()
|
||||
if is_too_fast:
|
||||
# Update poll time and increase interval
|
||||
device_code_store.update_poll_time(device_code, increase_interval=True)
|
||||
logger.warning(
|
||||
'Client polling too fast, returning slow_down error',
|
||||
extra={
|
||||
'device_code': device_code[:8] + '...', # Log partial for privacy
|
||||
'new_interval': current_interval,
|
||||
},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'slow_down',
|
||||
f'Polling too frequently. Wait at least {current_interval} seconds between requests.',
|
||||
interval=current_interval,
|
||||
)
|
||||
|
||||
# Update poll time for successful rate limit check
|
||||
device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
|
||||
if device_code_entry.is_expired():
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'expired_token',
|
||||
'Device code has expired',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'denied':
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'access_denied',
|
||||
'User denied the authorization request',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'pending':
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'authorization_pending',
|
||||
'User has not yet completed authorization',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'authorized':
|
||||
# Retrieve the specific API key for this device using the user_code
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
|
||||
device_api_key = api_key_store.retrieve_api_key_by_name(
|
||||
device_code_entry.keycloak_user_id, device_key_name
|
||||
)
|
||||
|
||||
if not device_api_key:
|
||||
logger.error(
|
||||
'No device API key found for authorized device',
|
||||
extra={
|
||||
'user_id': device_code_entry.keycloak_user_id,
|
||||
'user_code': device_code_entry.user_code,
|
||||
},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'API key not found',
|
||||
)
|
||||
|
||||
# Return the API key as access_token
|
||||
return DeviceTokenResponse(
|
||||
access_token=device_api_key,
|
||||
)
|
||||
|
||||
# Fallback for unexpected status values
|
||||
logger.error(
|
||||
'Unknown device code status',
|
||||
extra={'status': device_code_entry.status},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'Unknown device code status',
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception('Error in device token: %s', str(e))
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'Internal server error',
|
||||
)
|
||||
|
||||
|
||||
@oauth_device_router.post('/verify-authenticated')
|
||||
async def device_verification_authenticated(
|
||||
user_code: str = Form(...),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Process device verification for authenticated users (called by frontend)."""
|
||||
try:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Authentication required',
|
||||
)
|
||||
|
||||
# Validate device code
|
||||
device_code_entry = device_code_store.get_by_user_code(user_code)
|
||||
if not device_code_entry:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='The device code is invalid or has expired.',
|
||||
)
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This device code has already been processed.',
|
||||
)
|
||||
|
||||
# First, authorize the device code
|
||||
success = device_code_store.authorize_device_code(
|
||||
user_code=user_code,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(
|
||||
'Failed to authorize device code',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to authorize the device. Please try again.',
|
||||
)
|
||||
|
||||
# Only create API key AFTER successful authorization
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
try:
|
||||
# Create a unique API key for this device using user_code in the name
|
||||
device_key_name = f'{API_KEY_NAME} ({user_code})'
|
||||
api_key_store.create_api_key(
|
||||
user_id,
|
||||
name=device_key_name,
|
||||
expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME,
|
||||
)
|
||||
logger.info(
|
||||
'Created new device API key for user after successful authorization',
|
||||
extra={'user_id': user_id, 'user_code': user_code},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Failed to create device API key after authorization: %s', str(e)
|
||||
)
|
||||
|
||||
# Clean up: revert the device authorization since API key creation failed
|
||||
# This prevents the device from being in an authorized state without an API key
|
||||
try:
|
||||
device_code_store.deny_device_code(user_code)
|
||||
logger.info(
|
||||
'Reverted device authorization due to API key creation failure',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
except Exception as cleanup_error:
|
||||
logger.exception(
|
||||
'Failed to revert device authorization during cleanup: %s',
|
||||
str(cleanup_error),
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key for device access.',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Device code authorized with API key successfully',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Device authorized successfully!'},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception('Error in device verification: %s', str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred. Please try again.',
|
||||
)
|
||||
@ -31,6 +31,7 @@ from openhands.events.event_store import EventStore
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
from openhands.runtime.plugins.vscode import VSCodeRequirement
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.constants import ROOM_KEY
|
||||
@ -71,10 +72,13 @@ RUNTIME_CONVERSATION_URL = RUNTIME_URL_PATTERN + (
|
||||
)
|
||||
|
||||
RUNTIME_USERNAME = os.getenv('RUNTIME_USERNAME')
|
||||
|
||||
SU_TO_USER = os.getenv('SU_TO_USER', 'false')
|
||||
truthy = {'1', 'true', 't', 'yes', 'y', 'on'}
|
||||
SU_TO_USER = str(SU_TO_USER.lower() in truthy).lower()
|
||||
|
||||
DISABLE_VSCODE_PLUGIN = os.getenv('DISABLE_VSCODE_PLUGIN', 'false').lower() == 'true'
|
||||
|
||||
# Time in seconds before a Redis entry is considered expired if not refreshed
|
||||
_REDIS_ENTRY_TIMEOUT_SECONDS = 300
|
||||
|
||||
@ -799,6 +803,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['INIT_GIT_IN_EMPTY_WORKSPACE'] = '1'
|
||||
env_vars['ENABLE_V1'] = '0'
|
||||
env_vars['SU_TO_USER'] = SU_TO_USER
|
||||
env_vars['DISABLE_VSCODE_PLUGIN'] = str(DISABLE_VSCODE_PLUGIN).lower()
|
||||
|
||||
# We need this for LLM traces tracking to identify the source of the LLM calls
|
||||
env_vars['WEB_HOST'] = WEB_HOST
|
||||
@ -814,11 +819,18 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
if self._runtime_container_image:
|
||||
config.sandbox.runtime_container_image = self._runtime_container_image
|
||||
|
||||
plugins = [
|
||||
plugin
|
||||
for plugin in agent.sandbox_plugins
|
||||
if not (DISABLE_VSCODE_PLUGIN and isinstance(plugin, VSCodeRequirement))
|
||||
]
|
||||
logger.info(f'Loaded plugins for runtime {sid}: {plugins}')
|
||||
|
||||
runtime = RemoteRuntime(
|
||||
config=config,
|
||||
event_stream=None, # type: ignore[arg-type]
|
||||
sid=sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
plugins=plugins,
|
||||
# env_vars=env_vars,
|
||||
# status_callback: Callable[..., None] | None = None,
|
||||
attach_to_existing=False,
|
||||
|
||||
@ -17,10 +17,13 @@ from openhands.core.logger import openhands_logger as logger
|
||||
class ApiKeyStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
"""Generate a random API key."""
|
||||
"""Generate a random API key with the sk-oh- prefix."""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{self.API_KEY_PREFIX}{random_part}'
|
||||
|
||||
def create_api_key(
|
||||
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
||||
@ -57,9 +60,15 @@ class ApiKeyStore:
|
||||
return None
|
||||
|
||||
# Check if the key has expired
|
||||
if key_record.expires_at and key_record.expires_at < now:
|
||||
logger.info(f'API key has expired: {key_record.id}')
|
||||
return None
|
||||
if key_record.expires_at:
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
expires_at = key_record.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.info(f'API key has expired: {key_record.id}')
|
||||
return None
|
||||
|
||||
# Update last_used_at timestamp
|
||||
session.execute(
|
||||
@ -125,6 +134,33 @@ class ApiKeyStore:
|
||||
|
||||
return None
|
||||
|
||||
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
||||
"""Retrieve an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
return key_record.key if key_record else None
|
||||
|
||||
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> ApiKeyStore:
|
||||
"""Get an instance of the ApiKeyStore."""
|
||||
|
||||
109
enterprise/storage/device_code.py
Normal file
109
enterprise/storage/device_code.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Device code storage model for OAuth 2.0 Device Flow."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class DeviceCodeStatus(Enum):
|
||||
"""Status of a device code authorization request."""
|
||||
|
||||
PENDING = 'pending'
|
||||
AUTHORIZED = 'authorized'
|
||||
EXPIRED = 'expired'
|
||||
DENIED = 'denied'
|
||||
|
||||
|
||||
class DeviceCode(Base):
|
||||
"""Device code for OAuth 2.0 Device Flow.
|
||||
|
||||
This stores the device codes issued during the device authorization flow,
|
||||
along with their status and associated user information once authorized.
|
||||
"""
|
||||
|
||||
__tablename__ = 'device_codes'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
device_code = Column(String(128), unique=True, nullable=False, index=True)
|
||||
user_code = Column(String(16), unique=True, nullable=False, index=True)
|
||||
status = Column(String(32), nullable=False, default=DeviceCodeStatus.PENDING.value)
|
||||
|
||||
# Keycloak user ID who authorized the device (set during verification)
|
||||
keycloak_user_id = Column(String(255), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
authorized_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Rate limiting fields for RFC 8628 section 3.5 compliance
|
||||
last_poll_time = Column(DateTime(timezone=True), nullable=True)
|
||||
current_interval = Column(Integer, nullable=False, default=5)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<DeviceCode(user_code='{self.user_code}', status='{self.status}')>"
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the device code has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
return now > self.expires_at
|
||||
|
||||
def is_pending(self) -> bool:
|
||||
"""Check if the device code is still pending authorization."""
|
||||
return self.status == DeviceCodeStatus.PENDING.value and not self.is_expired()
|
||||
|
||||
def is_authorized(self) -> bool:
|
||||
"""Check if the device code has been authorized."""
|
||||
return self.status == DeviceCodeStatus.AUTHORIZED.value
|
||||
|
||||
def authorize(self, user_id: str) -> None:
|
||||
"""Mark the device code as authorized."""
|
||||
self.status = DeviceCodeStatus.AUTHORIZED.value
|
||||
self.keycloak_user_id = user_id # Set the Keycloak user ID during authorization
|
||||
self.authorized_at = datetime.now(timezone.utc)
|
||||
|
||||
def deny(self) -> None:
|
||||
"""Mark the device code as denied."""
|
||||
self.status = DeviceCodeStatus.DENIED.value
|
||||
|
||||
def expire(self) -> None:
|
||||
"""Mark the device code as expired."""
|
||||
self.status = DeviceCodeStatus.EXPIRED.value
|
||||
|
||||
def check_rate_limit(self) -> tuple[bool, int]:
|
||||
"""Check if the client is polling too fast.
|
||||
|
||||
Returns:
|
||||
tuple: (is_too_fast, current_interval)
|
||||
- is_too_fast: True if client should receive slow_down error
|
||||
- current_interval: Current polling interval to use
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# If this is the first poll, allow it
|
||||
if self.last_poll_time is None:
|
||||
return False, self.current_interval
|
||||
|
||||
# Calculate time since last poll
|
||||
time_since_last_poll = (now - self.last_poll_time).total_seconds()
|
||||
|
||||
# Check if polling too fast
|
||||
if time_since_last_poll < self.current_interval:
|
||||
# Increase interval for slow_down (RFC 8628 section 3.5)
|
||||
new_interval = min(self.current_interval + 5, 60) # Cap at 60 seconds
|
||||
return True, new_interval
|
||||
|
||||
return False, self.current_interval
|
||||
|
||||
def update_poll_time(self, increase_interval: bool = False) -> None:
|
||||
"""Update the last poll time and optionally increase the interval.
|
||||
|
||||
Args:
|
||||
increase_interval: If True, increase the current interval for slow_down
|
||||
"""
|
||||
self.last_poll_time = datetime.now(timezone.utc)
|
||||
|
||||
if increase_interval:
|
||||
# Increase interval by 5 seconds, cap at 60 seconds (RFC 8628)
|
||||
self.current_interval = min(self.current_interval + 5, 60)
|
||||
167
enterprise/storage/device_code_store.py
Normal file
167
enterprise/storage/device_code_store.py
Normal file
@ -0,0 +1,167 @@
|
||||
"""Device code store for OAuth 2.0 Device Flow."""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
class DeviceCodeStore:
|
||||
"""Store for managing OAuth 2.0 device codes."""
|
||||
|
||||
def __init__(self, session_maker):
|
||||
self.session_maker = session_maker
|
||||
|
||||
def generate_user_code(self) -> str:
|
||||
"""Generate a human-readable user code (8 characters, uppercase letters and digits)."""
|
||||
# Use a mix of uppercase letters and digits, avoiding confusing characters
|
||||
alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ23456789' # No I, O, 0, 1
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(8))
|
||||
|
||||
def generate_device_code(self) -> str:
|
||||
"""Generate a secure device code (128 characters)."""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(128))
|
||||
|
||||
def create_device_code(
|
||||
self,
|
||||
expires_in: int = 600, # 10 minutes default
|
||||
max_attempts: int = 10,
|
||||
) -> DeviceCode:
|
||||
"""Create a new device code entry.
|
||||
|
||||
Uses database constraints to ensure uniqueness, avoiding TOCTOU race conditions.
|
||||
Retries on constraint violations until unique codes are generated.
|
||||
|
||||
Args:
|
||||
expires_in: Expiration time in seconds
|
||||
max_attempts: Maximum number of attempts to generate unique codes
|
||||
|
||||
Returns:
|
||||
The created DeviceCode instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If unable to generate unique codes after max_attempts
|
||||
"""
|
||||
for attempt in range(max_attempts):
|
||||
user_code = self.generate_user_code()
|
||||
device_code = self.generate_device_code()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
||||
|
||||
device_code_entry = DeviceCode(
|
||||
device_code=device_code,
|
||||
user_code=user_code,
|
||||
keycloak_user_id=None, # Will be set during authorization
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
try:
|
||||
with self.session_maker() as session:
|
||||
session.add(device_code_entry)
|
||||
session.commit()
|
||||
session.refresh(device_code_entry)
|
||||
session.expunge(device_code_entry) # Detach from session cleanly
|
||||
return device_code_entry
|
||||
except IntegrityError:
|
||||
# Constraint violation - codes already exist, retry with new codes
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
f'Failed to generate unique device codes after {max_attempts} attempts'
|
||||
)
|
||||
|
||||
def get_by_device_code(self, device_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by device code."""
|
||||
with self.session_maker() as session:
|
||||
result = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
def get_by_user_code(self, user_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by user code."""
|
||||
with self.session_maker() as session:
|
||||
result = session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
def authorize_device_code(self, user_code: str, user_id: str) -> bool:
|
||||
"""Authorize a device code.
|
||||
|
||||
Args:
|
||||
user_code: The user code to authorize
|
||||
user_id: The user ID from Keycloak
|
||||
|
||||
Returns:
|
||||
True if authorization was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
return False
|
||||
|
||||
device_code_entry.authorize(user_id)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def deny_device_code(self, user_code: str) -> bool:
|
||||
"""Deny a device code authorization.
|
||||
|
||||
Args:
|
||||
user_code: The user code to deny
|
||||
|
||||
Returns:
|
||||
True if denial was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
return False
|
||||
|
||||
device_code_entry.deny()
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def update_poll_time(
|
||||
self, device_code: str, increase_interval: bool = False
|
||||
) -> bool:
|
||||
"""Update the poll time for a device code and optionally increase interval.
|
||||
|
||||
Args:
|
||||
device_code: The device code to update
|
||||
increase_interval: If True, increase the polling interval for slow_down
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
device_code_entry.update_poll_time(increase_interval)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
@ -94,6 +94,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
|
||||
@ -12,6 +12,7 @@ from storage.base import Base
|
||||
# Anything not loaded here may not have a table created for it.
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
|
||||
133
enterprise/tests/unit/integrations/test_resolver_context.py
Normal file
133
enterprise/tests/unit/integrations/test_resolver_context.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""Test for ResolverUserContext get_secrets conversion logic.
|
||||
|
||||
This test focuses on testing the actual ResolverUserContext implementation.
|
||||
"""
|
||||
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Import the real classes we want to test
|
||||
from openhands.integrations.provider import CustomSecret
|
||||
|
||||
# Import the SDK types we need for testing
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_saas_user_auth():
|
||||
"""Mock SaasUserAuth for testing."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolver_context(mock_saas_user_auth):
|
||||
"""Create a ResolverUserContext instance for testing."""
|
||||
return ResolverUserContext(saas_user_auth=mock_saas_user_auth)
|
||||
|
||||
|
||||
def create_custom_secret(value: str, description: str = 'Test secret') -> CustomSecret:
|
||||
"""Helper to create CustomSecret instances."""
|
||||
return CustomSecret(secret=SecretStr(value), description=description)
|
||||
|
||||
|
||||
def create_secrets(custom_secrets_dict: dict[str, CustomSecret]) -> Secrets:
|
||||
"""Helper to create Secrets instances."""
|
||||
return Secrets(custom_secrets=MappingProxyType(custom_secrets_dict))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_converts_custom_to_static(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_secrets correctly converts CustomSecret objects to StaticSecret objects."""
|
||||
# Arrange
|
||||
secrets = create_secrets(
|
||||
{
|
||||
'TEST_SECRET_1': create_custom_secret('secret_value_1'),
|
||||
'TEST_SECRET_2': create_custom_secret('secret_value_2'),
|
||||
}
|
||||
)
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(secret, StaticSecret) for secret in result.values())
|
||||
assert result['TEST_SECRET_1'].value.get_secret_value() == 'secret_value_1'
|
||||
assert result['TEST_SECRET_2'].value.get_secret_value() == 'secret_value_2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_with_special_characters(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that secret values with special characters are preserved during conversion."""
|
||||
# Arrange
|
||||
special_value = 'very_secret_password_123!@#$%^&*()'
|
||||
secrets = create_secrets({'SPECIAL_SECRET': create_custom_secret(special_value)})
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert isinstance(result['SPECIAL_SECRET'], StaticSecret)
|
||||
assert result['SPECIAL_SECRET'].value.get_secret_value() == special_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'secrets_input,expected_result',
|
||||
[
|
||||
(None, {}), # No secrets available
|
||||
(create_secrets({}), {}), # Empty custom secrets
|
||||
],
|
||||
)
|
||||
async def test_get_secrets_empty_cases(
|
||||
resolver_context, mock_saas_user_auth, secrets_input, expected_result
|
||||
):
|
||||
"""Test that get_secrets handles empty cases correctly."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets_input
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_static_secret_is_valid_secret_source():
|
||||
"""Test that StaticSecret is a valid SecretSource for SDK validation."""
|
||||
# Arrange & Act
|
||||
static_secret = StaticSecret(value='test_secret_123')
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == 'test_secret_123'
|
||||
|
||||
|
||||
def test_custom_to_static_conversion():
|
||||
"""Test the complete conversion flow from CustomSecret to StaticSecret."""
|
||||
# Arrange
|
||||
secret_value = 'conversion_test_secret'
|
||||
custom_secret = create_custom_secret(secret_value, 'Conversion test')
|
||||
|
||||
# Act - simulate the conversion logic from the actual method
|
||||
extracted_value = custom_secret.secret.get_secret_value()
|
||||
static_secret = StaticSecret(value=extracted_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == secret_value
|
||||
610
enterprise/tests/unit/server/routes/test_oauth_device.py
Normal file
610
enterprise/tests/unit/server/routes/test_oauth_device.py
Normal file
@ -0,0 +1,610 @@
|
||||
"""Unit tests for OAuth2 Device Flow endpoints."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from server.routes.oauth_device import (
|
||||
device_authorization,
|
||||
device_token,
|
||||
device_verification_authenticated,
|
||||
)
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_device_code_store():
|
||||
"""Mock device code store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_key_store():
|
||||
"""Mock API key store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_manager():
|
||||
"""Mock token manager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Mock FastAPI request."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.base_url = 'https://test.example.com/'
|
||||
return request
|
||||
|
||||
|
||||
class TestDeviceAuthorization:
|
||||
"""Test device authorization endpoint."""
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_authorization_success(self, mock_store, mock_request):
|
||||
"""Test successful device authorization."""
|
||||
mock_device = DeviceCode(
|
||||
device_code='test-device-code-123',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=5, # Default interval
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
assert result.device_code == 'test-device-code-123'
|
||||
assert result.user_code == 'ABC12345'
|
||||
assert result.expires_in == 600
|
||||
assert result.interval == 5 # Should match device's current_interval
|
||||
assert 'verify' in result.verification_uri
|
||||
assert 'ABC12345' in result.verification_uri_complete
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_authorization_with_increased_interval(
|
||||
self, mock_store, mock_request
|
||||
):
|
||||
"""Test device authorization returns increased interval from rate limiting."""
|
||||
mock_device = DeviceCode(
|
||||
device_code='test-device-code-456',
|
||||
user_code='XYZ98765',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=15, # Increased interval from previous rate limiting
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
assert result.device_code == 'test-device-code-456'
|
||||
assert result.user_code == 'XYZ98765'
|
||||
assert result.expires_in == 600
|
||||
assert result.interval == 15 # Should match device's increased current_interval
|
||||
assert 'verify' in result.verification_uri
|
||||
assert 'XYZ98765' in result.verification_uri_complete
|
||||
|
||||
|
||||
class TestDeviceToken:
|
||||
"""Test device token endpoint."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,status,expected_error',
|
||||
[
|
||||
(False, None, 'invalid_grant'),
|
||||
(True, 'expired', 'expired_token'),
|
||||
(True, 'denied', 'access_denied'),
|
||||
(True, 'pending', 'authorization_pending'),
|
||||
],
|
||||
)
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_token_error_cases(
|
||||
self, mock_store, device_exists, status, expected_error
|
||||
):
|
||||
"""Test various error cases for device token endpoint."""
|
||||
device_code = 'test-device-code'
|
||||
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_expired.return_value = status == 'expired'
|
||||
mock_device.status = status
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
else:
|
||||
mock_store.get_by_device_code.return_value = None
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
# Check error in response content
|
||||
content = result.body.decode()
|
||||
assert expected_error in content
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_token_success(self, mock_store, mock_api_key_class):
|
||||
"""Test successful device token retrieval."""
|
||||
device_code = 'test-device-code'
|
||||
|
||||
# Mock authorized device
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_expired.return_value = False
|
||||
mock_device.status = 'authorized'
|
||||
mock_device.keycloak_user_id = 'user-123'
|
||||
mock_device.user_code = (
|
||||
'ABC12345' # Add user_code for device-specific API key lookup
|
||||
)
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
# Mock API key retrieval
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key'
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Check that result is a DeviceTokenResponse
|
||||
assert result.access_token == 'test-api-key'
|
||||
assert result.token_type == 'Bearer'
|
||||
|
||||
# Verify that the correct device-specific API key name was used
|
||||
mock_api_key_store.retrieve_api_key_by_name.assert_called_once_with(
|
||||
'user-123', 'Device Link Access Key (ABC12345)'
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceVerificationAuthenticated:
|
||||
"""Test device verification authenticated endpoint."""
|
||||
|
||||
async def test_verification_unauthenticated_user(self):
|
||||
"""Test verification with unauthenticated user."""
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(user_code='ABC12345', user_id=None)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_invalid_device_code(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test verification with invalid device code."""
|
||||
mock_store.get_by_user_code.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
user_code='INVALID', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_already_processed(self, mock_store, mock_api_key_class):
|
||||
"""Test verification with already processed device code."""
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = False
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_success(self, mock_store, mock_api_key_class):
|
||||
"""Test successful device verification."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 200
|
||||
# Should NOT delete existing API keys (multiple devices allowed)
|
||||
mock_api_key_store.delete_api_key_by_name.assert_not_called()
|
||||
# Should create a new API key with device-specific name
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
call_args = mock_api_key_store.create_api_key.call_args
|
||||
assert call_args[1]['name'] == 'Device Link Access Key (ABC12345)'
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_multiple_device_authentication(self, mock_store, mock_api_key_class):
|
||||
"""Test that multiple devices can authenticate simultaneously."""
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Simulate two different devices
|
||||
device1_code = 'ABC12345'
|
||||
device2_code = 'XYZ67890'
|
||||
user_id = 'user-123'
|
||||
|
||||
# Mock device codes
|
||||
mock_device1 = MagicMock()
|
||||
mock_device1.is_pending.return_value = True
|
||||
mock_device2 = MagicMock()
|
||||
mock_device2.is_pending.return_value = True
|
||||
|
||||
# Configure mock store to return appropriate device for each user_code
|
||||
def get_by_user_code_side_effect(user_code):
|
||||
if user_code == device1_code:
|
||||
return mock_device1
|
||||
elif user_code == device2_code:
|
||||
return mock_device2
|
||||
return None
|
||||
|
||||
mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Authenticate first device
|
||||
result1 = await device_verification_authenticated(
|
||||
user_code=device1_code, user_id=user_id
|
||||
)
|
||||
|
||||
# Authenticate second device
|
||||
result2 = await device_verification_authenticated(
|
||||
user_code=device2_code, user_id=user_id
|
||||
)
|
||||
|
||||
# Both should succeed
|
||||
assert isinstance(result1, JSONResponse)
|
||||
assert result1.status_code == 200
|
||||
assert isinstance(result2, JSONResponse)
|
||||
assert result2.status_code == 200
|
||||
|
||||
# Should create two separate API keys with different names
|
||||
assert mock_api_key_store.create_api_key.call_count == 2
|
||||
|
||||
# Check that each device got a unique API key name
|
||||
call_args_list = mock_api_key_store.create_api_key.call_args_list
|
||||
device1_name = call_args_list[0][1]['name']
|
||||
device2_name = call_args_list[1][1]['name']
|
||||
|
||||
assert device1_name == f'Device Link Access Key ({device1_code})'
|
||||
assert device2_name == f'Device Link Access Key ({device2_code})'
|
||||
assert device1_name != device2_name # Ensure they're different
|
||||
|
||||
# Should NOT delete any existing API keys
|
||||
mock_api_key_store.delete_api_key_by_name.assert_not_called()
|
||||
|
||||
|
||||
class TestDeviceTokenRateLimiting:
|
||||
"""Test rate limiting for device token polling (RFC 8628 section 3.5)."""
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_first_poll_allowed(self, mock_store):
|
||||
"""Test that the first poll is always allowed."""
|
||||
# Create a device code with no previous poll time
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=None, # First poll
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return authorization_pending, not slow_down
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'authorization_pending' in content
|
||||
assert 'slow_down' not in content
|
||||
|
||||
# Should update poll time without increasing interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=False
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_normal_polling_allowed(self, mock_store):
|
||||
"""Test that normal polling (respecting interval) is allowed."""
|
||||
# Create a device code with last poll time 6 seconds ago (interval is 5)
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=6)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return authorization_pending, not slow_down
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'authorization_pending' in content
|
||||
assert 'slow_down' not in content
|
||||
|
||||
# Should update poll time without increasing interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=False
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_fast_polling_returns_slow_down(self, mock_store):
|
||||
"""Test that polling too fast returns slow_down error."""
|
||||
# Create a device code with last poll time 2 seconds ago (interval is 5)
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=2)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert 'interval' in content
|
||||
assert '10' in content # New interval should be 5 + 5 = 10
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_interval_increases_with_repeated_fast_polling(self, mock_store):
|
||||
"""Test that interval increases with repeated fast polling."""
|
||||
# Create a device code with higher current interval from previous slow_down
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=5) # 5 seconds ago
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=15, # Already increased from previous slow_down
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error with increased interval
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert '20' in content # New interval should be 15 + 5 = 20
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_interval_caps_at_maximum(self, mock_store):
|
||||
"""Test that interval is capped at maximum value."""
|
||||
# Create a device code with interval near maximum
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=30)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=58, # Near maximum of 60
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error with capped interval
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert '60' in content # Should be capped at 60, not 63
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_rate_limiting_with_authorized_device(self, mock_store):
|
||||
"""Test that rate limiting still applies to authorized devices."""
|
||||
# Create an authorized device code with recent poll
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=2)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='authorized', # Device is authorized
|
||||
keycloak_user_id='user123',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should still return slow_down error even for authorized device
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceVerificationTransactionIntegrity:
|
||||
"""Test transaction integrity for device verification to prevent orphaned API keys."""
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_authorization_failure_prevents_api_key_creation(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that if device authorization fails, no API key is created."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = False # Authorization fails
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to authorization failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to authorize the device' in exc_info.value.detail
|
||||
|
||||
# API key should NOT be created since authorization failed
|
||||
mock_api_key_store.create_api_key.assert_not_called()
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_api_key_creation_failure_reverts_authorization(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that if API key creation fails after authorization, the authorization is reverted."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.return_value = True # Cleanup succeeds
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to API key creation failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to create API key for device access' in exc_info.value.detail
|
||||
|
||||
# Authorization should have been attempted first
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
# API key creation should have been attempted after authorization
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
|
||||
# Authorization should be reverted due to API key creation failure
|
||||
mock_store.deny_device_code.assert_called_once_with('ABC12345')
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_api_key_creation_failure_cleanup_failure_logged(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that cleanup failure is logged but doesn't prevent the main error from being raised."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.side_effect = Exception(
|
||||
'Cleanup failed'
|
||||
) # Cleanup fails
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should still raise HTTPException for the original API key creation failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to create API key for device access' in exc_info.value.detail
|
||||
|
||||
# Both operations should have been attempted
|
||||
mock_store.authorize_device_code.assert_called_once()
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
mock_store.deny_device_code.assert_called_once_with('ABC12345')
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_successful_flow_creates_api_key_after_authorization(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that in the successful flow, API key is created only after authorization."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 200
|
||||
|
||||
# Verify the order: authorization first, then API key creation
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
|
||||
# No cleanup should be needed in successful case
|
||||
mock_store.deny_device_code.assert_not_called()
|
||||
83
enterprise/tests/unit/storage/test_device_code.py
Normal file
83
enterprise/tests/unit/storage/test_device_code.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""Unit tests for DeviceCode model."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from storage.device_code import DeviceCode, DeviceCodeStatus
|
||||
|
||||
|
||||
class TestDeviceCode:
|
||||
"""Test cases for DeviceCode model."""
|
||||
|
||||
@pytest.fixture
|
||||
def device_code(self):
|
||||
"""Create a test device code."""
|
||||
return DeviceCode(
|
||||
device_code='test-device-code-123',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'expires_delta,expected',
|
||||
[
|
||||
(timedelta(minutes=5), False), # Future expiry
|
||||
(timedelta(minutes=-5), True), # Past expiry
|
||||
(timedelta(seconds=1), False), # Just future (not expired)
|
||||
],
|
||||
)
|
||||
def test_is_expired(self, expires_delta, expected):
|
||||
"""Test expiration check with various time deltas."""
|
||||
device_code = DeviceCode(
|
||||
device_code='test-device-code',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(timezone.utc) + expires_delta,
|
||||
)
|
||||
assert device_code.is_expired() == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'status,expired,expected',
|
||||
[
|
||||
(DeviceCodeStatus.PENDING.value, False, True),
|
||||
(DeviceCodeStatus.PENDING.value, True, False),
|
||||
(DeviceCodeStatus.AUTHORIZED.value, False, False),
|
||||
(DeviceCodeStatus.DENIED.value, False, False),
|
||||
],
|
||||
)
|
||||
def test_is_pending(self, status, expired, expected):
|
||||
"""Test pending status check."""
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
if expired
|
||||
else datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
)
|
||||
device_code = DeviceCode(
|
||||
device_code='test-device-code',
|
||||
user_code='ABC12345',
|
||||
status=status,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
assert device_code.is_pending() == expected
|
||||
|
||||
def test_authorize(self, device_code):
|
||||
"""Test device authorization."""
|
||||
user_id = 'test-user-123'
|
||||
|
||||
device_code.authorize(user_id)
|
||||
|
||||
assert device_code.status == DeviceCodeStatus.AUTHORIZED.value
|
||||
assert device_code.keycloak_user_id == user_id
|
||||
assert device_code.authorized_at is not None
|
||||
assert isinstance(device_code.authorized_at, datetime)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'method,expected_status',
|
||||
[
|
||||
('deny', DeviceCodeStatus.DENIED.value),
|
||||
('expire', DeviceCodeStatus.EXPIRED.value),
|
||||
],
|
||||
)
|
||||
def test_status_changes(self, device_code, method, expected_status):
|
||||
"""Test status change methods."""
|
||||
getattr(device_code, method)()
|
||||
assert device_code.status == expected_status
|
||||
193
enterprise/tests/unit/storage/test_device_code_store.py
Normal file
193
enterprise/tests/unit/storage/test_device_code_store.py
Normal file
@ -0,0 +1,193 @@
|
||||
"""Unit tests for DeviceCodeStore."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.device_code import DeviceCode
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device_code_store(mock_session_maker):
|
||||
"""Create DeviceCodeStore instance."""
|
||||
return DeviceCodeStore(mock_session_maker)
|
||||
|
||||
|
||||
class TestDeviceCodeStore:
|
||||
"""Test cases for DeviceCodeStore."""
|
||||
|
||||
def test_generate_user_code(self, device_code_store):
|
||||
"""Test user code generation."""
|
||||
code = device_code_store.generate_user_code()
|
||||
|
||||
assert len(code) == 8
|
||||
assert code.isupper()
|
||||
# Should not contain confusing characters
|
||||
assert not any(char in code for char in 'IO01')
|
||||
|
||||
def test_generate_device_code(self, device_code_store):
|
||||
"""Test device code generation."""
|
||||
code = device_code_store.generate_device_code()
|
||||
|
||||
assert len(code) == 128
|
||||
assert code.isalnum()
|
||||
|
||||
def test_create_device_code_success(self, device_code_store, mock_session):
|
||||
"""Test successful device code creation."""
|
||||
# Mock successful creation (no IntegrityError)
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-123'
|
||||
mock_device_code.user_code = 'TESTCODE'
|
||||
|
||||
# Mock the session to return our mock device code after refresh
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
result = device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_create_device_code_with_retries(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation with constraint violation retries."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# First attempt fails with IntegrityError, second succeeds
|
||||
mock_session.commit.side_effect = [IntegrityError('', '', ''), None]
|
||||
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-456'
|
||||
mock_device_code.user_code = 'TESTCD2'
|
||||
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
result = store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
assert mock_session.add.call_count == 2 # Two attempts
|
||||
assert mock_session.commit.call_count == 2 # Two attempts
|
||||
|
||||
def test_create_device_code_max_attempts_exceeded(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation failure after max attempts."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# All attempts fail with IntegrityError
|
||||
mock_session.commit.side_effect = IntegrityError('', '', '')
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match='Failed to generate unique device codes after 3 attempts',
|
||||
):
|
||||
store.create_device_code(expires_in=600, max_attempts=3)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'lookup_method,lookup_field',
|
||||
[
|
||||
('get_by_device_code', 'device_code'),
|
||||
('get_by_user_code', 'user_code'),
|
||||
],
|
||||
)
|
||||
def test_lookup_methods(
|
||||
self, device_code_store, mock_session, lookup_method, lookup_field
|
||||
):
|
||||
"""Test device code lookup methods."""
|
||||
test_code = 'test-code-123'
|
||||
mock_device_code = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device_code
|
||||
)
|
||||
|
||||
result = getattr(device_code_store, lookup_method)(test_code)
|
||||
|
||||
assert result == mock_device_code
|
||||
mock_session.query.assert_called_once_with(DeviceCode)
|
||||
mock_session.query.return_value.filter_by.assert_called_once_with(
|
||||
**{lookup_field: test_code}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,is_pending,expected_result',
|
||||
[
|
||||
(True, True, True), # Success case
|
||||
(False, True, False), # Device not found
|
||||
(True, False, False), # Device not pending
|
||||
],
|
||||
)
|
||||
def test_authorize_device_code(
|
||||
self,
|
||||
device_code_store,
|
||||
mock_session,
|
||||
device_exists,
|
||||
is_pending,
|
||||
expected_result,
|
||||
):
|
||||
"""Test device code authorization."""
|
||||
user_code = 'ABC12345'
|
||||
user_id = 'test-user-123'
|
||||
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = is_pending
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device
|
||||
else:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = device_code_store.authorize_device_code(user_code, user_id)
|
||||
|
||||
assert result == expected_result
|
||||
if expected_result:
|
||||
mock_device.authorize.assert_called_once_with(user_id)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_deny_device_code(self, device_code_store, mock_session):
|
||||
"""Test device code denial."""
|
||||
user_code = 'ABC12345'
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device
|
||||
)
|
||||
|
||||
result = device_code_store.deny_device_code(user_code)
|
||||
|
||||
assert result is True
|
||||
mock_device.deny.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
@ -25,10 +25,12 @@ def api_key_store(mock_session_maker):
|
||||
|
||||
|
||||
def test_generate_api_key(api_key_store):
|
||||
"""Test that generate_api_key returns a string of the expected length."""
|
||||
"""Test that generate_api_key returns a string with sk-oh- prefix and expected length."""
|
||||
key = api_key_store.generate_api_key(length=32)
|
||||
assert isinstance(key, str)
|
||||
assert len(key) == 32
|
||||
assert key.startswith('sk-oh-')
|
||||
# Total length should be prefix (6 chars) + random part (32 chars) = 38 chars
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
|
||||
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
@ -90,6 +92,50 @@ def test_validate_api_key_expired(api_key_store, mock_session):
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_api_key_expired_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating an expired API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
mock_key_record = MagicMock()
|
||||
# Simulate timezone-naive datetime as returned from database
|
||||
mock_key_record.expires_at = datetime.now() - timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_api_key_valid_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
user_id = 'test-user-123'
|
||||
mock_key_record = MagicMock()
|
||||
mock_key_record.user_id = user_id
|
||||
# Simulate timezone-naive datetime as returned from database (future date)
|
||||
mock_key_record.expires_at = datetime.now() + timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_api_key_not_found(api_key_store, mock_session):
|
||||
"""Test validating a non-existent API key."""
|
||||
# Setup
|
||||
|
||||
132
enterprise/tests/unit/test_get_user_v1_enabled_setting.py
Normal file
132
enterprise/tests/unit/test_get_user_v1_enabled_setting.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""Unit tests for get_user_v1_enabled_setting function."""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import get_user_v1_enabled_setting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_settings():
|
||||
"""Create a mock user settings object."""
|
||||
settings = MagicMock()
|
||||
settings.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_store(mock_user_settings):
|
||||
"""Create a mock settings store."""
|
||||
store = MagicMock()
|
||||
store.get_user_settings_by_keycloak_id = AsyncMock(return_value=mock_user_settings)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock config object."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker():
|
||||
"""Create a mock session maker."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(
|
||||
mock_settings_store, mock_config, mock_session_maker, mock_user_settings
|
||||
):
|
||||
"""Fixture that patches all the common dependencies."""
|
||||
with patch(
|
||||
'integrations.github.github_view.SaasSettingsStore',
|
||||
return_value=mock_settings_store,
|
||||
) as mock_store_class, patch(
|
||||
'integrations.github.github_view.get_config', return_value=mock_config
|
||||
) as mock_get_config, patch(
|
||||
'integrations.github.github_view.session_maker', mock_session_maker
|
||||
), patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
return_value=mock_user_settings,
|
||||
) as mock_call_sync:
|
||||
yield {
|
||||
'store_class': mock_store_class,
|
||||
'get_config': mock_get_config,
|
||||
'session_maker': mock_session_maker,
|
||||
'call_sync': mock_call_sync,
|
||||
'settings_store': mock_settings_store,
|
||||
'user_settings': mock_user_settings,
|
||||
}
|
||||
|
||||
|
||||
class TestGetUserV1EnabledSetting:
|
||||
"""Test cases for get_user_v1_enabled_setting function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_enabled,user_setting_enabled,expected_result',
|
||||
[
|
||||
(False, True, False), # Env var disabled, user enabled -> False
|
||||
(True, False, False), # Env var enabled, user disabled -> False
|
||||
(True, True, True), # Both enabled -> True
|
||||
(False, False, False), # Both disabled -> False
|
||||
],
|
||||
)
|
||||
async def test_v1_enabled_combinations(
|
||||
self, mock_dependencies, env_var_enabled, user_setting_enabled, expected_result
|
||||
):
|
||||
"""Test all combinations of environment variable and user setting values."""
|
||||
mock_dependencies['user_settings'].v1_enabled = user_setting_enabled
|
||||
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_enabled
|
||||
):
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_value,env_var_bool,expected_result',
|
||||
[
|
||||
('false', False, False), # Environment variable 'false' -> False
|
||||
('true', True, True), # Environment variable 'true' -> True
|
||||
],
|
||||
)
|
||||
async def test_environment_variable_integration(
|
||||
self, mock_dependencies, env_var_value, env_var_bool, expected_result
|
||||
):
|
||||
"""Test that the function properly reads the ENABLE_V1_GITHUB_RESOLVER environment variable."""
|
||||
mock_dependencies['user_settings'].v1_enabled = True
|
||||
|
||||
with patch.dict(
|
||||
os.environ, {'ENABLE_V1_GITHUB_RESOLVER': env_var_value}
|
||||
), patch('integrations.utils.os.getenv', return_value=env_var_value), patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_bool
|
||||
):
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calls_correct_methods(self, mock_dependencies):
|
||||
"""Test that the function calls the correct methods with correct parameters."""
|
||||
mock_dependencies['user_settings'].v1_enabled = True
|
||||
|
||||
with patch('integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', True):
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
|
||||
# Verify the result
|
||||
assert result is True
|
||||
|
||||
# Verify correct methods were called with correct parameters
|
||||
mock_dependencies['get_config'].assert_called_once()
|
||||
mock_dependencies['store_class'].assert_called_once_with(
|
||||
user_id='test_user_123',
|
||||
session_maker=mock_dependencies['session_maker'],
|
||||
config=mock_dependencies['get_config'].return_value,
|
||||
)
|
||||
mock_dependencies['call_sync'].assert_called_once_with(
|
||||
mock_dependencies['settings_store'].get_user_settings_by_keycloak_id,
|
||||
'test_user_123',
|
||||
)
|
||||
@ -1,485 +0,0 @@
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.legacy_conversation_manager import (
|
||||
_LEGACY_ENTRY_TIMEOUT_SECONDS,
|
||||
LegacyCacheEntry,
|
||||
LegacyConversationManager,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sio():
|
||||
"""Create a mock SocketIO server."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock OpenHands config."""
|
||||
return MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_config():
|
||||
"""Create a mock server config."""
|
||||
return MagicMock(spec=ServerConfig)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_store():
|
||||
"""Create a mock file store."""
|
||||
return MagicMock(spec=InMemoryFileStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_monitoring_listener():
|
||||
"""Create a mock monitoring listener."""
|
||||
return MagicMock(spec=MonitoringListener)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_manager():
|
||||
"""Create a mock SaasNestedConversationManager."""
|
||||
mock_cm = MagicMock()
|
||||
mock_cm._get_runtime = AsyncMock()
|
||||
return mock_cm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_legacy_conversation_manager():
|
||||
"""Create a mock ClusteredConversationManager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def legacy_manager(
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_server_config,
|
||||
mock_file_store,
|
||||
mock_conversation_manager,
|
||||
mock_legacy_conversation_manager,
|
||||
):
|
||||
"""Create a LegacyConversationManager instance for testing."""
|
||||
return LegacyConversationManager(
|
||||
sio=mock_sio,
|
||||
config=mock_config,
|
||||
server_config=mock_server_config,
|
||||
file_store=mock_file_store,
|
||||
conversation_manager=mock_conversation_manager,
|
||||
legacy_conversation_manager=mock_legacy_conversation_manager,
|
||||
)
|
||||
|
||||
|
||||
class TestLegacyCacheEntry:
|
||||
"""Test the LegacyCacheEntry dataclass."""
|
||||
|
||||
def test_cache_entry_creation(self):
|
||||
"""Test creating a cache entry."""
|
||||
timestamp = time.time()
|
||||
entry = LegacyCacheEntry(is_legacy=True, timestamp=timestamp)
|
||||
|
||||
assert entry.is_legacy is True
|
||||
assert entry.timestamp == timestamp
|
||||
|
||||
def test_cache_entry_false(self):
|
||||
"""Test creating a cache entry with False value."""
|
||||
timestamp = time.time()
|
||||
entry = LegacyCacheEntry(is_legacy=False, timestamp=timestamp)
|
||||
|
||||
assert entry.is_legacy is False
|
||||
assert entry.timestamp == timestamp
|
||||
|
||||
|
||||
class TestLegacyConversationManagerCacheCleanup:
|
||||
"""Test cache cleanup functionality."""
|
||||
|
||||
def test_cleanup_expired_cache_entries_removes_expired(self, legacy_manager):
|
||||
"""Test that expired entries are removed from cache."""
|
||||
current_time = time.time()
|
||||
expired_time = current_time - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
valid_time = current_time - 100 # Well within timeout
|
||||
|
||||
# Add both expired and valid entries
|
||||
legacy_manager._legacy_cache = {
|
||||
'expired_conversation': LegacyCacheEntry(True, expired_time),
|
||||
'valid_conversation': LegacyCacheEntry(False, valid_time),
|
||||
'another_expired': LegacyCacheEntry(True, expired_time - 100),
|
||||
}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# Only valid entry should remain
|
||||
assert len(legacy_manager._legacy_cache) == 1
|
||||
assert 'valid_conversation' in legacy_manager._legacy_cache
|
||||
assert 'expired_conversation' not in legacy_manager._legacy_cache
|
||||
assert 'another_expired' not in legacy_manager._legacy_cache
|
||||
|
||||
def test_cleanup_expired_cache_entries_keeps_valid(self, legacy_manager):
|
||||
"""Test that valid entries are kept during cleanup."""
|
||||
current_time = time.time()
|
||||
valid_time = current_time - 100 # Well within timeout
|
||||
|
||||
legacy_manager._legacy_cache = {
|
||||
'valid_conversation_1': LegacyCacheEntry(True, valid_time),
|
||||
'valid_conversation_2': LegacyCacheEntry(False, valid_time - 50),
|
||||
}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# Both entries should remain
|
||||
assert len(legacy_manager._legacy_cache) == 2
|
||||
assert 'valid_conversation_1' in legacy_manager._legacy_cache
|
||||
assert 'valid_conversation_2' in legacy_manager._legacy_cache
|
||||
|
||||
def test_cleanup_expired_cache_entries_empty_cache(self, legacy_manager):
|
||||
"""Test cleanup with empty cache."""
|
||||
legacy_manager._legacy_cache = {}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
|
||||
|
||||
class TestIsLegacyRuntime:
|
||||
"""Test the is_legacy_runtime method."""
|
||||
|
||||
def test_is_legacy_runtime_none(self, legacy_manager):
|
||||
"""Test with None runtime."""
|
||||
result = legacy_manager.is_legacy_runtime(None)
|
||||
assert result is False
|
||||
|
||||
def test_is_legacy_runtime_legacy_command(self, legacy_manager):
|
||||
"""Test with legacy runtime command."""
|
||||
runtime = {'command': 'some_old_legacy_command'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_new_command(self, legacy_manager):
|
||||
"""Test with new runtime command containing openhands.server."""
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is False
|
||||
|
||||
def test_is_legacy_runtime_partial_match(self, legacy_manager):
|
||||
"""Test with command that partially matches but is still legacy."""
|
||||
runtime = {'command': 'openhands.client.start'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_empty_command(self, legacy_manager):
|
||||
"""Test with empty command."""
|
||||
runtime = {'command': ''}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_missing_command_key(self, legacy_manager):
|
||||
"""Test with runtime missing command key."""
|
||||
runtime = {'other_key': 'value'}
|
||||
# This should raise a KeyError
|
||||
with pytest.raises(KeyError):
|
||||
legacy_manager.is_legacy_runtime(runtime)
|
||||
|
||||
|
||||
class TestShouldStartInLegacyMode:
|
||||
"""Test the should_start_in_legacy_mode method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_valid_entry_legacy(self, legacy_manager):
|
||||
"""Test cache hit with valid legacy entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
current_time = time.time()
|
||||
|
||||
# Add valid cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True, current_time - 100
|
||||
)
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is True
|
||||
# Should not call _get_runtime since we hit cache
|
||||
legacy_manager.conversation_manager._get_runtime.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_valid_entry_non_legacy(self, legacy_manager):
|
||||
"""Test cache hit with valid non-legacy entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
current_time = time.time()
|
||||
|
||||
# Add valid cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
False, current_time - 100
|
||||
)
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should not call _get_runtime since we hit cache
|
||||
legacy_manager.conversation_manager._get_runtime.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss_legacy_runtime(self, legacy_manager):
|
||||
"""Test cache miss with legacy runtime."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'old_command'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is True
|
||||
# Should call _get_runtime
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss_non_legacy_runtime(self, legacy_manager):
|
||||
"""Test cache miss with non-legacy runtime."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should call _get_runtime
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expired_entry(self, legacy_manager):
|
||||
"""Test with expired cache entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
expired_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
# Add expired cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True,
|
||||
expired_time, # This should be considered expired
|
||||
)
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False # Runtime indicates non-legacy
|
||||
# Should call _get_runtime since cache is expired
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should update cache with new result
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exactly_at_timeout(self, legacy_manager):
|
||||
"""Test with cache entry exactly at timeout boundary."""
|
||||
conversation_id = 'test_conversation'
|
||||
timeout_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
# Add cache entry exactly at timeout
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True, timeout_time
|
||||
)
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
# Should treat as expired and fetch from runtime
|
||||
assert result is False
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_returns_none(self, legacy_manager):
|
||||
"""Test when runtime returns None."""
|
||||
conversation_id = 'test_conversation'
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = None
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_called_on_each_invocation(self, legacy_manager):
|
||||
"""Test that cleanup is called on each invocation."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'test'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
# Mock the cleanup method to verify it's called
|
||||
with patch.object(
|
||||
legacy_manager, '_cleanup_expired_cache_entries'
|
||||
) as mock_cleanup:
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_conversations_cached_independently(self, legacy_manager):
|
||||
"""Test that multiple conversations are cached independently."""
|
||||
conv1 = 'conversation_1'
|
||||
conv2 = 'conversation_2'
|
||||
|
||||
runtime1 = {'command': 'old_command'} # Legacy
|
||||
runtime2 = {'command': 'python -m openhands.server.listen'} # Non-legacy
|
||||
|
||||
# Mock to return different runtimes based on conversation_id
|
||||
def mock_get_runtime(conversation_id):
|
||||
if conversation_id == conv1:
|
||||
return runtime1
|
||||
return runtime2
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.side_effect = mock_get_runtime
|
||||
|
||||
result1 = await legacy_manager.should_start_in_legacy_mode(conv1)
|
||||
result2 = await legacy_manager.should_start_in_legacy_mode(conv2)
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is False
|
||||
|
||||
# Both should be cached
|
||||
assert conv1 in legacy_manager._legacy_cache
|
||||
assert conv2 in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conv1].is_legacy is True
|
||||
assert legacy_manager._legacy_cache[conv2].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_timestamp_updated_on_refresh(self, legacy_manager):
|
||||
"""Test that cache timestamp is updated when entry is refreshed."""
|
||||
conversation_id = 'test_conversation'
|
||||
old_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
runtime = {'command': 'test'}
|
||||
|
||||
# Add expired entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(True, old_time)
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
# Record time before call
|
||||
before_call = time.time()
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
after_call = time.time()
|
||||
|
||||
# Timestamp should be updated
|
||||
cached_entry = legacy_manager._legacy_cache[conversation_id]
|
||||
assert cached_entry.timestamp >= before_call
|
||||
assert cached_entry.timestamp <= after_call
|
||||
|
||||
|
||||
class TestLegacyConversationManagerIntegration:
|
||||
"""Integration tests for LegacyConversationManager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_creates_proper_manager(
|
||||
self,
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_file_store,
|
||||
mock_server_config,
|
||||
mock_monitoring_listener,
|
||||
):
|
||||
"""Test that get_instance creates a properly configured manager."""
|
||||
with patch(
|
||||
'server.legacy_conversation_manager.SaasNestedConversationManager'
|
||||
) as mock_saas, patch(
|
||||
'server.legacy_conversation_manager.ClusteredConversationManager'
|
||||
) as mock_clustered:
|
||||
mock_saas.get_instance.return_value = MagicMock()
|
||||
mock_clustered.get_instance.return_value = MagicMock()
|
||||
|
||||
manager = LegacyConversationManager.get_instance(
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_file_store,
|
||||
mock_server_config,
|
||||
mock_monitoring_listener,
|
||||
)
|
||||
|
||||
assert isinstance(manager, LegacyConversationManager)
|
||||
assert manager.sio == mock_sio
|
||||
assert manager.config == mock_config
|
||||
assert manager.file_store == mock_file_store
|
||||
assert manager.server_config == mock_server_config
|
||||
|
||||
# Verify that both nested managers are created
|
||||
mock_saas.get_instance.assert_called_once()
|
||||
mock_clustered.get_instance.assert_called_once()
|
||||
|
||||
def test_legacy_cache_initialized_empty(self, legacy_manager):
|
||||
"""Test that legacy cache is initialized as empty dict."""
|
||||
assert isinstance(legacy_manager._legacy_cache, dict)
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runtime_raises_exception(self, legacy_manager):
|
||||
"""Test behavior when _get_runtime raises an exception."""
|
||||
conversation_id = 'test_conversation'
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.side_effect = Exception(
|
||||
'Runtime error'
|
||||
)
|
||||
|
||||
# Should propagate the exception
|
||||
with pytest.raises(Exception, match='Runtime error'):
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_large_cache(self, legacy_manager):
|
||||
"""Test behavior with a large number of cache entries."""
|
||||
current_time = time.time()
|
||||
|
||||
# Add many cache entries
|
||||
for i in range(1000):
|
||||
legacy_manager._legacy_cache[f'conversation_{i}'] = LegacyCacheEntry(
|
||||
i % 2 == 0, current_time - i
|
||||
)
|
||||
|
||||
# This should work without issues
|
||||
await legacy_manager.should_start_in_legacy_mode('new_conversation')
|
||||
|
||||
# Should have added one more entry
|
||||
assert len(legacy_manager._legacy_cache) == 1001
|
||||
|
||||
def test_cleanup_with_concurrent_modifications(self, legacy_manager):
|
||||
"""Test cleanup behavior when cache is modified during cleanup."""
|
||||
current_time = time.time()
|
||||
expired_time = current_time - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
|
||||
# Add expired entries
|
||||
legacy_manager._legacy_cache = {
|
||||
f'conversation_{i}': LegacyCacheEntry(True, expired_time) for i in range(10)
|
||||
}
|
||||
|
||||
# This should work without raising exceptions
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# All entries should be removed
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
@ -1,5 +1,10 @@
|
||||
# Evaluation
|
||||
|
||||
> [!WARNING]
|
||||
> **This directory is deprecated.** Our new benchmarks are located at [OpenHands/benchmarks](https://github.com/OpenHands/benchmarks).
|
||||
>
|
||||
> If you have already implemented a benchmark in this directory and would like to contribute it, we are happy to have the contribution. However, if you are starting anew, please use the new location.
|
||||
|
||||
This folder contains code and resources to run experiments and evaluations.
|
||||
|
||||
## For Benchmark Users
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
"i18next/no-literal-string": "error",
|
||||
"unused-imports/no-unused-imports": "error",
|
||||
"prettier/prettier": ["error"],
|
||||
// Enforce using optional chaining (?.) instead of && chains for null/undefined checks
|
||||
"@typescript-eslint/prefer-optional-chain": "error",
|
||||
// Resolves https://stackoverflow.com/questions/59265981/typescript-eslint-missing-file-extension-ts-import-extensions/59268871#59268871
|
||||
"import/extensions": [
|
||||
"error",
|
||||
|
||||
@ -1,2 +0,0 @@
|
||||
public-hoist-pattern[]=*@nextui-org/*
|
||||
enable-pre-post-scripts=true
|
||||
@ -30,61 +30,33 @@ vi.mock("react-i18next", async () => {
|
||||
};
|
||||
});
|
||||
|
||||
// Mock Zustand browser store
|
||||
let mockBrowserState = {
|
||||
url: "https://example.com",
|
||||
screenshotSrc: "",
|
||||
setUrl: vi.fn(),
|
||||
setScreenshotSrc: vi.fn(),
|
||||
reset: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mock("#/stores/browser-store", () => ({
|
||||
useBrowserStore: () => mockBrowserState,
|
||||
}));
|
||||
|
||||
// Import the component after all mocks are set up
|
||||
import { BrowserPanel } from "#/components/features/browser/browser";
|
||||
import { useBrowserStore } from "#/stores/browser-store";
|
||||
|
||||
describe("Browser", () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset the mock state
|
||||
mockBrowserState = {
|
||||
url: "https://example.com",
|
||||
screenshotSrc: "",
|
||||
setUrl: vi.fn(),
|
||||
setScreenshotSrc: vi.fn(),
|
||||
reset: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
it("renders a message if no screenshotSrc is provided", () => {
|
||||
// Set the mock state for this test
|
||||
mockBrowserState = {
|
||||
useBrowserStore.setState({
|
||||
url: "https://example.com",
|
||||
screenshotSrc: "",
|
||||
setUrl: vi.fn(),
|
||||
setScreenshotSrc: vi.fn(),
|
||||
reset: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
render(<BrowserPanel />);
|
||||
|
||||
// i18n empty message key
|
||||
expect(screen.getByText("BROWSER$NO_PAGE_LOADED")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders the url and a screenshot", () => {
|
||||
// Set the mock state for this test
|
||||
mockBrowserState = {
|
||||
useBrowserStore.setState({
|
||||
url: "https://example.com",
|
||||
screenshotSrc:
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mN0uGvyHwAFCAJS091fQwAAAABJRU5ErkJggg==",
|
||||
setUrl: vi.fn(),
|
||||
setScreenshotSrc: vi.fn(),
|
||||
reset: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
render(<BrowserPanel />);
|
||||
|
||||
|
||||
@ -25,10 +25,7 @@ import { useUnifiedUploadFiles } from "#/hooks/mutation/use-unified-upload-files
|
||||
import { OpenHandsAction } from "#/types/core/actions";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
|
||||
// Mock the hooks
|
||||
vi.mock("#/context/ws-client-provider");
|
||||
vi.mock("#/stores/error-message-store");
|
||||
vi.mock("#/stores/optimistic-user-message-store");
|
||||
vi.mock("#/hooks/query/use-config");
|
||||
vi.mock("#/hooks/mutation/use-get-trajectory");
|
||||
vi.mock("#/hooks/mutation/use-unified-upload-files");
|
||||
@ -102,24 +99,20 @@ describe("ChatInterface - Chat Suggestions", () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Default mock implementations
|
||||
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
send: vi.fn(),
|
||||
isLoadingMessages: false,
|
||||
parsedEvents: [],
|
||||
});
|
||||
(
|
||||
useOptimisticUserMessageStore as unknown as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue({
|
||||
setOptimisticUserMessage: vi.fn(),
|
||||
getOptimisticUserMessage: vi.fn(() => null),
|
||||
|
||||
useOptimisticUserMessageStore.setState({
|
||||
optimisticUserMessage: null,
|
||||
});
|
||||
(
|
||||
useErrorMessageStore as unknown as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue({
|
||||
setErrorMessage: vi.fn(),
|
||||
removeErrorMessage: vi.fn(),
|
||||
|
||||
useErrorMessageStore.setState({
|
||||
errorMessage: null,
|
||||
});
|
||||
|
||||
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
data: { APP_MODE: "local" },
|
||||
});
|
||||
@ -204,11 +197,8 @@ describe("ChatInterface - Chat Suggestions", () => {
|
||||
});
|
||||
|
||||
test("should hide chat suggestions when there is an optimistic user message", () => {
|
||||
(
|
||||
useOptimisticUserMessageStore as unknown as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue({
|
||||
setOptimisticUserMessage: vi.fn(),
|
||||
getOptimisticUserMessage: vi.fn(() => "Optimistic message"),
|
||||
useOptimisticUserMessageStore.setState({
|
||||
optimisticUserMessage: "Optimistic message",
|
||||
});
|
||||
|
||||
renderWithQueryClient(<ChatInterface />, queryClient);
|
||||
@ -240,24 +230,19 @@ describe("ChatInterface - Empty state", () => {
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks to ensure empty state
|
||||
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
send: sendMock,
|
||||
status: "CONNECTED",
|
||||
isLoadingMessages: false,
|
||||
parsedEvents: [],
|
||||
});
|
||||
(
|
||||
useOptimisticUserMessageStore as unknown as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue({
|
||||
setOptimisticUserMessage: vi.fn(),
|
||||
getOptimisticUserMessage: vi.fn(() => null),
|
||||
|
||||
useOptimisticUserMessageStore.setState({
|
||||
optimisticUserMessage: null,
|
||||
});
|
||||
(
|
||||
useErrorMessageStore as unknown as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue({
|
||||
setErrorMessage: vi.fn(),
|
||||
removeErrorMessage: vi.fn(),
|
||||
|
||||
useErrorMessageStore.setState({
|
||||
errorMessage: null,
|
||||
});
|
||||
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
data: { APP_MODE: "local" },
|
||||
|
||||
@ -61,7 +61,7 @@ describe("ExpandableMessage", () => {
|
||||
expect(icon).toHaveClass("fill-success");
|
||||
});
|
||||
|
||||
it("should render with error icon for failed action messages", () => {
|
||||
it("should render with no icon for failed action messages", () => {
|
||||
renderWithProviders(
|
||||
<ExpandableMessage
|
||||
id="OBSERVATION_MESSAGE$RUN"
|
||||
@ -75,8 +75,7 @@ describe("ExpandableMessage", () => {
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
const icon = screen.getByTestId("status-icon");
|
||||
expect(icon).toHaveClass("fill-danger");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render with neutral border and no icon for action messages without success prop", () => {
|
||||
|
||||
149
frontend/__tests__/components/conversation-tab-title.test.tsx
Normal file
149
frontend/__tests__/components/conversation-tab-title.test.tsx
Normal file
@ -0,0 +1,149 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { ConversationTabTitle } from "#/components/features/conversation/conversation-tabs/conversation-tab-title";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import V1GitService from "#/api/git-service/v1-git-service.api";
|
||||
|
||||
// Mock the services that the hook depends on
|
||||
vi.mock("#/api/git-service/git-service.api");
|
||||
vi.mock("#/api/git-service/v1-git-service.api");
|
||||
|
||||
// Mock the hooks that useUnifiedGetGitChanges depends on
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({
|
||||
conversationId: "test-conversation-id",
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
useActiveConversation: () => ({
|
||||
data: {
|
||||
conversation_version: "V0",
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
selected_repository: null,
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-runtime-is-ready", () => ({
|
||||
useRuntimeIsReady: () => true,
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/get-git-path", () => ({
|
||||
getGitPath: () => "/workspace",
|
||||
}));
|
||||
|
||||
describe("ConversationTabTitle", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Mock GitService methods
|
||||
vi.mocked(GitService.getGitChanges).mockResolvedValue([]);
|
||||
vi.mocked(V1GitService.getGitChanges).mockResolvedValue([]);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
queryClient.clear();
|
||||
});
|
||||
|
||||
const renderWithProviders = (ui: React.ReactElement) => {
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>{ui}</QueryClientProvider>,
|
||||
);
|
||||
};
|
||||
|
||||
describe("Rendering", () => {
|
||||
it("should render the title", () => {
|
||||
// Arrange
|
||||
const title = "Test Title";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="browser" />,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText(title)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show refresh button when conversationKey is 'editor'", () => {
|
||||
// Arrange
|
||||
const title = "Changes";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="editor" />,
|
||||
);
|
||||
|
||||
// Assert
|
||||
const refreshButton = screen.getByRole("button");
|
||||
expect(refreshButton).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show refresh button when conversationKey is not 'editor'", () => {
|
||||
// Arrange
|
||||
const title = "Browser";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="browser" />,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByRole("button")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("User Interactions", () => {
|
||||
it("should call refetch and trigger GitService.getGitChanges when refresh button is clicked", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
const title = "Changes";
|
||||
const mockGitChanges: Array<{
|
||||
path: string;
|
||||
status: "M" | "A" | "D" | "R" | "U";
|
||||
}> = [
|
||||
{ path: "file1.ts", status: "M" },
|
||||
{ path: "file2.ts", status: "A" },
|
||||
];
|
||||
|
||||
vi.mocked(GitService.getGitChanges).mockResolvedValue(mockGitChanges);
|
||||
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="editor" />,
|
||||
);
|
||||
|
||||
const refreshButton = screen.getByRole("button");
|
||||
|
||||
// Wait for initial query to complete
|
||||
await waitFor(() => {
|
||||
expect(GitService.getGitChanges).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Clear the mock to track refetch calls
|
||||
vi.mocked(GitService.getGitChanges).mockClear();
|
||||
|
||||
// Act
|
||||
await user.click(refreshButton);
|
||||
|
||||
// Assert - refetch should trigger another service call
|
||||
await waitFor(() => {
|
||||
expect(GitService.getGitChanges).toHaveBeenCalledWith(
|
||||
"test-conversation-id",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -0,0 +1,71 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { AgentStatus } from "#/components/features/controls/agent-status";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useConversationStore } from "#/state/conversation-store";
|
||||
|
||||
vi.mock("#/hooks/use-agent-state");
|
||||
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({ conversationId: "test-id" }),
|
||||
}));
|
||||
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<MemoryRouter>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
</MemoryRouter>
|
||||
);
|
||||
|
||||
const renderAgentStatus = ({
|
||||
isPausing = false,
|
||||
}: { isPausing?: boolean } = {}) =>
|
||||
render(
|
||||
<AgentStatus
|
||||
handleStop={vi.fn()}
|
||||
handleResumeAgent={vi.fn()}
|
||||
isPausing={isPausing}
|
||||
/>,
|
||||
{ wrapper },
|
||||
);
|
||||
|
||||
describe("AgentStatus - isLoading logic", () => {
|
||||
it("should show loading when curAgentState is INIT", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.INIT,
|
||||
});
|
||||
|
||||
renderAgentStatus();
|
||||
|
||||
expect(screen.getByTestId("agent-loading-spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show loading when isPausing is true, even if shouldShownAgentLoading is false", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
|
||||
renderAgentStatus({ isPausing: true });
|
||||
|
||||
expect(screen.getByTestId("agent-loading-spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should NOT update global shouldShownAgentLoading when only isPausing is true", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
|
||||
renderAgentStatus({ isPausing: true });
|
||||
|
||||
// Loading spinner shows (because isPausing)
|
||||
expect(screen.getByTestId("agent-loading-spinner")).toBeInTheDocument();
|
||||
|
||||
// But global state should be false (because shouldShownAgentLoading is false)
|
||||
const { shouldShownAgentLoading } = useConversationStore.getState();
|
||||
expect(shouldShownAgentLoading).toBe(false);
|
||||
});
|
||||
});
|
||||
@ -42,7 +42,7 @@ vi.mock("react-i18next", async () => {
|
||||
BUTTON$EXPORT_CONVERSATION: "Export Conversation",
|
||||
BUTTON$DOWNLOAD_VIA_VSCODE: "Download via VS Code",
|
||||
BUTTON$SHOW_AGENT_TOOLS_AND_METADATA: "Show Agent Tools",
|
||||
CONVERSATION$SHOW_MICROAGENTS: "Show Microagents",
|
||||
CONVERSATION$SHOW_SKILLS: "Show Skills",
|
||||
BUTTON$DISPLAY_COST: "Display Cost",
|
||||
COMMON$CLOSE_CONVERSATION_STOP_RUNTIME:
|
||||
"Close Conversation (Stop Runtime)",
|
||||
@ -290,7 +290,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
onStop: vi.fn(),
|
||||
onDisplayCost: vi.fn(),
|
||||
onShowAgentTools: vi.fn(),
|
||||
onShowMicroagents: vi.fn(),
|
||||
onShowSkills: vi.fn(),
|
||||
onExportConversation: vi.fn(),
|
||||
onDownloadViaVSCode: vi.fn(),
|
||||
};
|
||||
@ -304,7 +304,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
expect(screen.getByTestId("stop-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("display-cost-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("show-agent-tools-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("show-microagents-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("show-skills-button")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByTestId("export-conversation-button"),
|
||||
).toBeInTheDocument();
|
||||
@ -321,9 +321,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
expect(
|
||||
screen.queryByTestId("show-agent-tools-button"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("show-microagents-button"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId("show-skills-button")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("export-conversation-button"),
|
||||
).not.toBeInTheDocument();
|
||||
@ -410,19 +408,19 @@ describe("ConversationNameContextMenu", () => {
|
||||
|
||||
it("should call show microagents handler when show microagents button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onShowMicroagents = vi.fn();
|
||||
const onShowSkills = vi.fn();
|
||||
|
||||
renderWithProviders(
|
||||
<ConversationNameContextMenu
|
||||
{...defaultProps}
|
||||
onShowMicroagents={onShowMicroagents}
|
||||
onShowSkills={onShowSkills}
|
||||
/>,
|
||||
);
|
||||
|
||||
const showMicroagentsButton = screen.getByTestId("show-microagents-button");
|
||||
const showMicroagentsButton = screen.getByTestId("show-skills-button");
|
||||
await user.click(showMicroagentsButton);
|
||||
|
||||
expect(onShowMicroagents).toHaveBeenCalledTimes(1);
|
||||
expect(onShowSkills).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should call export conversation handler when export conversation button is clicked", async () => {
|
||||
@ -519,7 +517,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
onStop: vi.fn(),
|
||||
onDisplayCost: vi.fn(),
|
||||
onShowAgentTools: vi.fn(),
|
||||
onShowMicroagents: vi.fn(),
|
||||
onShowSkills: vi.fn(),
|
||||
onExportConversation: vi.fn(),
|
||||
onDownloadViaVSCode: vi.fn(),
|
||||
};
|
||||
@ -541,8 +539,8 @@ describe("ConversationNameContextMenu", () => {
|
||||
expect(screen.getByTestId("show-agent-tools-button")).toHaveTextContent(
|
||||
"Show Agent Tools",
|
||||
);
|
||||
expect(screen.getByTestId("show-microagents-button")).toHaveTextContent(
|
||||
"Show Microagents",
|
||||
expect(screen.getByTestId("show-skills-button")).toHaveTextContent(
|
||||
"Show Skills",
|
||||
);
|
||||
expect(screen.getByTestId("export-conversation-button")).toHaveTextContent(
|
||||
"Export Conversation",
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { RecentConversations } from "#/components/features/home/recent-conversations/recent-conversations";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
|
||||
const renderRecentConversations = () => {
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: () => <RecentConversations />,
|
||||
path: "/",
|
||||
},
|
||||
]);
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return render(<RouterStub />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
describe("RecentConversations", () => {
|
||||
const getUserConversationsSpy = vi.spyOn(
|
||||
ConversationService,
|
||||
"getUserConversations",
|
||||
);
|
||||
|
||||
it("should not show empty state when there is an error", async () => {
|
||||
getUserConversationsSpy.mockRejectedValue(
|
||||
new Error("Failed to fetch conversations"),
|
||||
);
|
||||
|
||||
renderRecentConversations();
|
||||
|
||||
// Wait for the error to be displayed
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to fetch conversations"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// The empty state should NOT be displayed when there's an error
|
||||
expect(
|
||||
screen.queryByText("HOME$NO_RECENT_CONVERSATIONS"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@ -2,9 +2,9 @@ import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, vi, beforeEach, it } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { RepositorySelectionForm } from "../../../../src/components/features/home/repo-selection-form";
|
||||
import UserService from "#/api/user-service/user-service.api";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import { useHomeStore } from "#/stores/home-store";
|
||||
|
||||
// Create mock functions
|
||||
const mockUseUserRepositories = vi.fn();
|
||||
@ -97,7 +97,7 @@ vi.mock("#/context/auth-context", () => ({
|
||||
// Mock debounce to simulate proper debounced behavior
|
||||
let debouncedValue = "";
|
||||
vi.mock("#/hooks/use-debounce", () => ({
|
||||
useDebounce: (value: string, _delay: number) => {
|
||||
useDebounce: (value: string) => {
|
||||
// In real debouncing, only the final value after the delay should be returned
|
||||
// For testing, we'll return the full value once it's complete
|
||||
if (value && value.length > 20) {
|
||||
@ -124,28 +124,51 @@ vi.mock("#/hooks/query/use-search-repositories", () => ({
|
||||
}));
|
||||
|
||||
const mockOnRepoSelection = vi.fn();
|
||||
const renderForm = () =>
|
||||
render(<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
|
||||
// Helper function to render with custom store state
|
||||
const renderForm = (
|
||||
storeOverrides: Partial<{
|
||||
recentRepositories: GitRepository[];
|
||||
lastSelectedProvider: 'gitlab' | null;
|
||||
}> = {},
|
||||
) => {
|
||||
// Set up the store state before rendering
|
||||
useHomeStore.setState({
|
||||
recentRepositories: [],
|
||||
lastSelectedProvider: null,
|
||||
...storeOverrides,
|
||||
});
|
||||
|
||||
return render(
|
||||
<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
describe("RepositorySelectionForm", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset the store to initial state
|
||||
useHomeStore.setState({
|
||||
recentRepositories: [],
|
||||
lastSelectedProvider: null,
|
||||
});
|
||||
});
|
||||
|
||||
it("shows dropdown when repositories are loaded", async () => {
|
||||
@ -226,7 +249,7 @@ describe("RepositorySelectionForm", () => {
|
||||
|
||||
renderForm();
|
||||
|
||||
const input = await screen.findByTestId("git-repo-dropdown");
|
||||
await screen.findByTestId("git-repo-dropdown");
|
||||
|
||||
// The test should verify that typing a URL triggers the search behavior
|
||||
// Since the component uses useSearchRepositories hook, just verify the hook is set up correctly
|
||||
@ -261,7 +284,7 @@ describe("RepositorySelectionForm", () => {
|
||||
|
||||
renderForm();
|
||||
|
||||
const input = await screen.findByTestId("git-repo-dropdown");
|
||||
await screen.findByTestId("git-repo-dropdown");
|
||||
|
||||
// Verify that the onRepoSelection callback prop was provided
|
||||
expect(mockOnRepoSelection).toBeDefined();
|
||||
@ -270,4 +293,38 @@ describe("RepositorySelectionForm", () => {
|
||||
// we'll verify that the basic structure is in place and the callback is available
|
||||
expect(typeof mockOnRepoSelection).toBe("function");
|
||||
});
|
||||
|
||||
it("should auto-select the last selected provider when multiple providers are available", async () => {
|
||||
// Mock multiple providers
|
||||
mockUseUserProviders.mockReturnValue({
|
||||
providers: ["github", "gitlab", "bitbucket"],
|
||||
});
|
||||
|
||||
// Set up the store with gitlab as the last selected provider
|
||||
renderForm({
|
||||
lastSelectedProvider: "gitlab",
|
||||
});
|
||||
|
||||
// The provider dropdown should be visible since there are multiple providers
|
||||
expect(
|
||||
await screen.findByTestId("git-provider-dropdown"),
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Verify that the store has the correct last selected provider
|
||||
expect(useHomeStore.getState().lastSelectedProvider).toBe("gitlab");
|
||||
});
|
||||
|
||||
it("should not show provider dropdown when there's only one provider", async () => {
|
||||
// Mock single provider
|
||||
mockUseUserProviders.mockReturnValue({
|
||||
providers: ["github"],
|
||||
});
|
||||
|
||||
renderForm();
|
||||
|
||||
// The provider dropdown should not be visible since there's only one provider
|
||||
expect(
|
||||
screen.queryByTestId("git-provider-dropdown"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { MCPServerForm } from "../mcp-server-form";
|
||||
import { MCPServerForm } from "#/components/features/settings/mcp-settings/mcp-server-form";
|
||||
|
||||
// i18n mock
|
||||
vi.mock("react-i18next", () => ({
|
||||
@ -1,6 +1,6 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { MCPServerList } from "../mcp-server-list";
|
||||
import { MCPServerList } from "#/components/features/settings/mcp-settings/mcp-server-list";
|
||||
|
||||
// Mock react-i18next
|
||||
vi.mock("react-i18next", () => ({
|
||||
@ -8,16 +8,10 @@ import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useConversationStore } from "#/state/conversation-store";
|
||||
|
||||
// Mock the agent state hook
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the conversation store
|
||||
vi.mock("#/state/conversation-store", () => ({
|
||||
useConversationStore: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock React Router hooks
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
@ -58,44 +52,23 @@ vi.mock("#/hooks/use-conversation-name-context-menu", () => ({
|
||||
describe("InteractiveChatBox", () => {
|
||||
const onSubmitMock = vi.fn();
|
||||
|
||||
// Helper function to mock stores
|
||||
const mockStores = (agentState: AgentState = AgentState.INIT) => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: agentState,
|
||||
});
|
||||
|
||||
vi.mocked(useConversationStore).mockReturnValue({
|
||||
useConversationStore.setState({
|
||||
images: [],
|
||||
files: [],
|
||||
addImages: vi.fn(),
|
||||
addFiles: vi.fn(),
|
||||
clearAllFiles: vi.fn(),
|
||||
addFileLoading: vi.fn(),
|
||||
removeFileLoading: vi.fn(),
|
||||
addImageLoading: vi.fn(),
|
||||
removeImageLoading: vi.fn(),
|
||||
submittedMessage: null,
|
||||
setShouldHideSuggestions: vi.fn(),
|
||||
setSubmittedMessage: vi.fn(),
|
||||
isRightPanelShown: true,
|
||||
selectedTab: "editor" as const,
|
||||
loadingFiles: [],
|
||||
loadingImages: [],
|
||||
submittedMessage: null,
|
||||
messageToSend: null,
|
||||
shouldShownAgentLoading: false,
|
||||
shouldHideSuggestions: false,
|
||||
isRightPanelShown: true,
|
||||
selectedTab: "editor" as const,
|
||||
hasRightPanelToggled: true,
|
||||
setIsRightPanelShown: vi.fn(),
|
||||
setSelectedTab: vi.fn(),
|
||||
setShouldShownAgentLoading: vi.fn(),
|
||||
removeImage: vi.fn(),
|
||||
removeFile: vi.fn(),
|
||||
clearImages: vi.fn(),
|
||||
clearFiles: vi.fn(),
|
||||
clearAllLoading: vi.fn(),
|
||||
setMessageToSend: vi.fn(),
|
||||
resetConversationState: vi.fn(),
|
||||
setHasRightPanelToggled: vi.fn(),
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@ -1,91 +0,0 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { MicroagentsModal } from "#/components/features/conversation-panel/microagents-modal";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
|
||||
// Mock the agent state hook
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the conversation ID hook
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({ conversationId: "test-conversation-id" }),
|
||||
}));
|
||||
|
||||
describe("MicroagentsModal - Refresh Button", () => {
|
||||
const mockOnClose = vi.fn();
|
||||
const conversationId = "test-conversation-id";
|
||||
|
||||
const defaultProps = {
|
||||
onClose: mockOnClose,
|
||||
conversationId,
|
||||
};
|
||||
|
||||
const mockMicroagents = [
|
||||
{
|
||||
name: "Test Agent 1",
|
||||
type: "repo" as const,
|
||||
triggers: ["test", "example"],
|
||||
content: "This is test content for agent 1",
|
||||
},
|
||||
{
|
||||
name: "Test Agent 2",
|
||||
type: "knowledge" as const,
|
||||
triggers: ["help", "support"],
|
||||
content: "This is test content for agent 2",
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks before each test
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Setup default mock for getMicroagents
|
||||
vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({
|
||||
microagents: mockMicroagents,
|
||||
});
|
||||
|
||||
// Mock the agent state to return a ready state
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Refresh Button Rendering", () => {
|
||||
it("should render the refresh button with correct text and test ID", async () => {
|
||||
renderWithProviders(<MicroagentsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-microagents");
|
||||
expect(refreshButton).toBeInTheDocument();
|
||||
expect(refreshButton).toHaveTextContent("BUTTON$REFRESH");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Refresh Button Functionality", () => {
|
||||
it("should call refetch when refresh button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
renderWithProviders(<MicroagentsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-microagents");
|
||||
|
||||
refreshSpy.mockClear();
|
||||
|
||||
await user.click(refreshButton);
|
||||
|
||||
expect(refreshSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
394
frontend/__tests__/components/modals/skills/skill-modal.test.tsx
Normal file
394
frontend/__tests__/components/modals/skills/skill-modal.test.tsx
Normal file
@ -0,0 +1,394 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { SkillsModal } from "#/components/features/conversation-panel/skills-modal";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
// Mock the agent state hook
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the conversation ID hook
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({ conversationId: "test-conversation-id" }),
|
||||
}));
|
||||
|
||||
describe("SkillsModal - Refresh Button", () => {
|
||||
const mockOnClose = vi.fn();
|
||||
const conversationId = "test-conversation-id";
|
||||
|
||||
const defaultProps = {
|
||||
onClose: mockOnClose,
|
||||
conversationId,
|
||||
};
|
||||
|
||||
const mockSkills = [
|
||||
{
|
||||
name: "Test Agent 1",
|
||||
type: "repo" as const,
|
||||
triggers: ["test", "example"],
|
||||
content: "This is test content for agent 1",
|
||||
},
|
||||
{
|
||||
name: "Test Agent 2",
|
||||
type: "knowledge" as const,
|
||||
triggers: ["help", "support"],
|
||||
content: "This is test content for agent 2",
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks before each test
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Setup default mock for getMicroagents (V0)
|
||||
vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({
|
||||
microagents: mockSkills,
|
||||
});
|
||||
|
||||
// Mock the agent state to return a ready state
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Refresh Button Rendering", () => {
|
||||
it("should render the refresh button with correct text and test ID", async () => {
|
||||
renderWithProviders(<SkillsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-skills");
|
||||
expect(refreshButton).toBeInTheDocument();
|
||||
expect(refreshButton).toHaveTextContent("BUTTON$REFRESH");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Refresh Button Functionality", () => {
|
||||
it("should call refetch when refresh button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
renderWithProviders(<SkillsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-skills");
|
||||
|
||||
// Clear previous calls to only track the click
|
||||
refreshSpy.mockClear();
|
||||
|
||||
await user.click(refreshButton);
|
||||
|
||||
// Verify the refresh triggered a new API call
|
||||
expect(refreshSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("useConversationSkills - V1 API Integration", () => {
|
||||
const conversationId = "test-conversation-id";
|
||||
|
||||
const mockMicroagents = [
|
||||
{
|
||||
name: "V0 Test Agent",
|
||||
type: "repo" as const,
|
||||
triggers: ["v0"],
|
||||
content: "V0 skill content",
|
||||
},
|
||||
];
|
||||
|
||||
const mockSkills = [
|
||||
{
|
||||
name: "V1 Test Skill",
|
||||
type: "knowledge" as const,
|
||||
triggers: ["v1", "skill"],
|
||||
content: "V1 skill content",
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock agent state
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("V0 API Usage (v1_enabled: false)", () => {
|
||||
it("should call v0 ConversationService.getMicroagents when v1_enabled is false", async () => {
|
||||
// Arrange
|
||||
const getMicroagentsSpy = vi
|
||||
.spyOn(ConversationService, "getMicroagents")
|
||||
.mockResolvedValue({ microagents: mockMicroagents });
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: false,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
await screen.findByText("V0 Test Agent");
|
||||
expect(getMicroagentsSpy).toHaveBeenCalledWith(conversationId);
|
||||
expect(getMicroagentsSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should display v0 skills correctly", async () => {
|
||||
// Arrange
|
||||
vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({
|
||||
microagents: mockMicroagents,
|
||||
});
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: false,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
const agentName = await screen.findByText("V0 Test Agent");
|
||||
expect(agentName).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("V1 API Usage (v1_enabled: true)", () => {
|
||||
it("should call v1 V1ConversationService.getSkills when v1_enabled is true", async () => {
|
||||
// Arrange
|
||||
const getSkillsSpy = vi
|
||||
.spyOn(V1ConversationService, "getSkills")
|
||||
.mockResolvedValue({ skills: mockSkills });
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
await screen.findByText("V1 Test Skill");
|
||||
expect(getSkillsSpy).toHaveBeenCalledWith(conversationId);
|
||||
expect(getSkillsSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should display v1 skills correctly", async () => {
|
||||
// Arrange
|
||||
vi.spyOn(V1ConversationService, "getSkills").mockResolvedValue({
|
||||
skills: mockSkills,
|
||||
});
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
const skillName = await screen.findByText("V1 Test Skill");
|
||||
expect(skillName).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should use v1 API when v1_enabled is true", async () => {
|
||||
// Arrange
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
const getSkillsSpy = vi
|
||||
.spyOn(V1ConversationService, "getSkills")
|
||||
.mockResolvedValue({
|
||||
skills: mockSkills,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
await screen.findByText("V1 Test Skill");
|
||||
// Verify v1 API was called
|
||||
expect(getSkillsSpy).toHaveBeenCalledWith(conversationId);
|
||||
});
|
||||
});
|
||||
|
||||
describe("API Switching on Settings Change", () => {
|
||||
it("should refetch using different API when v1_enabled setting changes", async () => {
|
||||
// Arrange
|
||||
const getMicroagentsSpy = vi
|
||||
.spyOn(ConversationService, "getMicroagents")
|
||||
.mockResolvedValue({ microagents: mockMicroagents });
|
||||
const getSkillsSpy = vi
|
||||
.spyOn(V1ConversationService, "getSkills")
|
||||
.mockResolvedValue({ skills: mockSkills });
|
||||
|
||||
const settingsSpy = vi
|
||||
.spyOn(SettingsService, "getSettings")
|
||||
.mockResolvedValue({
|
||||
v1_enabled: false,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act - Initial render with v1_enabled: false
|
||||
const { rerender } = renderWithProviders(
|
||||
<SkillsModal onClose={vi.fn()} />,
|
||||
);
|
||||
|
||||
// Assert - v0 API called initially
|
||||
await screen.findByText("V0 Test Agent");
|
||||
expect(getMicroagentsSpy).toHaveBeenCalledWith(conversationId);
|
||||
|
||||
// Arrange - Change settings to v1_enabled: true
|
||||
settingsSpy.mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act - Force re-render
|
||||
rerender(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert - v1 API should be called after settings change
|
||||
await screen.findByText("V1 Test Skill");
|
||||
expect(getSkillsSpy).toHaveBeenCalledWith(conversationId);
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -16,7 +16,7 @@ describe("SettingsForm", () => {
|
||||
Component: () => (
|
||||
<SettingsForm
|
||||
settings={DEFAULT_SETTINGS}
|
||||
models={[DEFAULT_SETTINGS.LLM_MODEL]}
|
||||
models={[DEFAULT_SETTINGS.llm_model]}
|
||||
onClose={onCloseMock}
|
||||
/>
|
||||
),
|
||||
@ -33,7 +33,7 @@ describe("SettingsForm", () => {
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
llm_model: DEFAULT_SETTINGS.LLM_MODEL,
|
||||
llm_model: DEFAULT_SETTINGS.llm_model,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { getObservationContent } from "../get-observation-content";
|
||||
import { getObservationContent } from "#/components/v1/chat/event-content-helpers/get-observation-content";
|
||||
import { ObservationEvent } from "#/types/v1/core";
|
||||
import { BrowserObservation } from "#/types/v1/core/base/observation";
|
||||
|
||||
53
frontend/__tests__/hooks/use-settings-nav-items.test.tsx
Normal file
53
frontend/__tests__/hooks/use-settings-nav-items.test.tsx
Normal file
@ -0,0 +1,53 @@
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { SAAS_NAV_ITEMS, OSS_NAV_ITEMS } from "#/constants/settings-nav";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { useSettingsNavItems } from "#/hooks/use-settings-nav-items";
|
||||
|
||||
const queryClient = new QueryClient();
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
|
||||
const mockConfig = (appMode: "saas" | "oss", hideLlmSettings = false) => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
APP_MODE: appMode,
|
||||
FEATURE_FLAGS: { HIDE_LLM_SETTINGS: hideLlmSettings },
|
||||
} as Awaited<ReturnType<typeof OptionService.getConfig>>);
|
||||
};
|
||||
|
||||
describe("useSettingsNavItems", () => {
|
||||
beforeEach(() => {
|
||||
queryClient.clear();
|
||||
});
|
||||
|
||||
it("should return SAAS_NAV_ITEMS when APP_MODE is 'saas'", async () => {
|
||||
mockConfig("saas");
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current).toEqual(SAAS_NAV_ITEMS);
|
||||
});
|
||||
});
|
||||
|
||||
it("should return OSS_NAV_ITEMS when APP_MODE is 'oss'", async () => {
|
||||
mockConfig("oss");
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current).toEqual(OSS_NAV_ITEMS);
|
||||
});
|
||||
});
|
||||
|
||||
it("should filter out '/settings' item when HIDE_LLM_SETTINGS feature flag is enabled", async () => {
|
||||
mockConfig("saas", true);
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings"),
|
||||
).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -1,3 +1,11 @@
|
||||
/**
|
||||
* TODO: Fix flaky WebSocket tests (https://github.com/OpenHands/OpenHands/issues/11944)
|
||||
*
|
||||
* Several tests in this file are skipped because they fail intermittently in CI
|
||||
* but pass locally. The SUSPECTED root cause is that `wsLink.broadcast()` sends messages
|
||||
* to ALL connected clients across all tests, causing cross-test contamination
|
||||
* when tests run in parallel with Vitest v4.
|
||||
*/
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import {
|
||||
describe,
|
||||
@ -52,7 +60,7 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.socket).toBeTruthy();
|
||||
});
|
||||
|
||||
it("should handle incoming messages correctly", async () => {
|
||||
it.skip("should handle incoming messages correctly", async () => {
|
||||
const { result } = renderHook(() => useWebSocket("ws://acme.com/ws"));
|
||||
|
||||
// Wait for connection to be established
|
||||
@ -115,7 +123,7 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.socket).toBeTruthy();
|
||||
});
|
||||
|
||||
it("should close the WebSocket connection on unmount", async () => {
|
||||
it.skip("should close the WebSocket connection on unmount", async () => {
|
||||
const { result, unmount } = renderHook(() =>
|
||||
useWebSocket("ws://acme.com/ws"),
|
||||
);
|
||||
@ -205,7 +213,7 @@ describe("useWebSocket", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should call onMessage handler when WebSocket receives a message", async () => {
|
||||
it.skip("should call onMessage handler when WebSocket receives a message", async () => {
|
||||
const onMessageSpy = vi.fn();
|
||||
const options = { onMessage: onMessageSpy };
|
||||
|
||||
@ -279,7 +287,7 @@ describe("useWebSocket", () => {
|
||||
expect(onErrorSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should provide sendMessage function to send messages to WebSocket", async () => {
|
||||
it.skip("should provide sendMessage function to send messages to WebSocket", async () => {
|
||||
const { result } = renderHook(() => useWebSocket("ws://acme.com/ws"));
|
||||
|
||||
// Wait for connection to be established
|
||||
|
||||
@ -253,6 +253,83 @@ describe("Content", () => {
|
||||
expect(securityAnalyzer).toHaveValue("SETTINGS$SECURITY_ANALYZER_NONE");
|
||||
});
|
||||
});
|
||||
|
||||
it("should omit invariant and custom analyzers when V1 is enabled", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
confirmation_mode: true,
|
||||
security_analyzer: "llm",
|
||||
v1_enabled: true,
|
||||
});
|
||||
|
||||
const getSecurityAnalyzersSpy = vi.spyOn(
|
||||
OptionService,
|
||||
"getSecurityAnalyzers",
|
||||
);
|
||||
getSecurityAnalyzersSpy.mockResolvedValue([
|
||||
"llm",
|
||||
"none",
|
||||
"invariant",
|
||||
"custom",
|
||||
]);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
|
||||
const securityAnalyzer = await screen.findByTestId(
|
||||
"security-analyzer-input",
|
||||
);
|
||||
await userEvent.click(securityAnalyzer);
|
||||
|
||||
// Only llm + none should be available when V1 is enabled
|
||||
screen.getByText("SETTINGS$SECURITY_ANALYZER_LLM_DEFAULT");
|
||||
screen.getByText("SETTINGS$SECURITY_ANALYZER_NONE");
|
||||
expect(
|
||||
screen.queryByText("SETTINGS$SECURITY_ANALYZER_INVARIANT"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("custom")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should include invariant analyzer option when V1 is disabled", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
confirmation_mode: true,
|
||||
security_analyzer: "llm",
|
||||
v1_enabled: false,
|
||||
});
|
||||
|
||||
const getSecurityAnalyzersSpy = vi.spyOn(
|
||||
OptionService,
|
||||
"getSecurityAnalyzers",
|
||||
);
|
||||
getSecurityAnalyzersSpy.mockResolvedValue(["llm", "none", "invariant"]);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
|
||||
const securityAnalyzer = await screen.findByTestId(
|
||||
"security-analyzer-input",
|
||||
);
|
||||
await userEvent.click(securityAnalyzer);
|
||||
|
||||
expect(
|
||||
screen.getByText("SETTINGS$SECURITY_ANALYZER_LLM_DEFAULT"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("SETTINGS$SECURITY_ANALYZER_NONE"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("SETTINGS$SECURITY_ANALYZER_INVARIANT"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it.todo("should render an indicator if the llm api key is set");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { handleStatusMessage } from "../actions";
|
||||
import { handleStatusMessage } from "#/services/actions";
|
||||
import { StatusMessage } from "#/types/message";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
import { useStatusStore } from "#/state/status-store";
|
||||
@ -1,8 +1,8 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import ActionType from "#/types/action-type";
|
||||
import { ActionMessage } from "#/types/message";
|
||||
import { useCommandStore } from "#/state/command-store";
|
||||
|
||||
// Mock the store and actions
|
||||
const mockDispatch = vi.fn();
|
||||
const mockAppendInput = vi.fn();
|
||||
|
||||
@ -12,26 +12,12 @@ vi.mock("#/store", () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/state/command-store", () => ({
|
||||
useCommandStore: {
|
||||
getState: () => ({
|
||||
appendInput: mockAppendInput,
|
||||
}),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/state/metrics-slice", () => ({
|
||||
setMetrics: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/state/security-analyzer-slice", () => ({
|
||||
appendSecurityAnalyzerInput: vi.fn(),
|
||||
}));
|
||||
|
||||
describe("handleActionMessage", () => {
|
||||
beforeEach(() => {
|
||||
// Clear all mocks before each test
|
||||
vi.clearAllMocks();
|
||||
useCommandStore.setState({
|
||||
appendInput: mockAppendInput,
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle RUN actions by adding input to terminal", async () => {
|
||||
|
||||
@ -3,7 +3,7 @@ import toast from "react-hot-toast";
|
||||
import {
|
||||
displaySuccessToast,
|
||||
displayErrorToast,
|
||||
} from "../custom-toast-handlers";
|
||||
} from "#/utils/custom-toast-handlers";
|
||||
|
||||
// Mock react-hot-toast
|
||||
vi.mock("react-hot-toast", () => ({
|
||||
@ -12,20 +12,20 @@ describe("hasAdvancedSettingsSet", () => {
|
||||
});
|
||||
|
||||
describe("should be true if", () => {
|
||||
test("LLM_BASE_URL is set", () => {
|
||||
test("llm_base_url is set", () => {
|
||||
expect(
|
||||
hasAdvancedSettingsSet({
|
||||
...DEFAULT_SETTINGS,
|
||||
LLM_BASE_URL: "test",
|
||||
llm_base_url: "test",
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
test("AGENT is not default value", () => {
|
||||
test("agent is not default value", () => {
|
||||
expect(
|
||||
hasAdvancedSettingsSet({
|
||||
...DEFAULT_SETTINGS,
|
||||
AGENT: "test",
|
||||
agent: "test",
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
@ -13,7 +13,7 @@ describe("Model name case preservation", () => {
|
||||
const settings = extractSettings(formData);
|
||||
|
||||
// Test that model names maintain their original casing
|
||||
expect(settings.LLM_MODEL).toBe("SambaNova/Meta-Llama-3.1-8B-Instruct");
|
||||
expect(settings.llm_model).toBe("SambaNova/Meta-Llama-3.1-8B-Instruct");
|
||||
});
|
||||
|
||||
it("should preserve openai model case", () => {
|
||||
@ -24,7 +24,7 @@ describe("Model name case preservation", () => {
|
||||
formData.set("language", "en");
|
||||
|
||||
const settings = extractSettings(formData);
|
||||
expect(settings.LLM_MODEL).toBe("openai/gpt-4o");
|
||||
expect(settings.llm_model).toBe("openai/gpt-4o");
|
||||
});
|
||||
|
||||
it("should preserve anthropic model case", () => {
|
||||
@ -35,7 +35,7 @@ describe("Model name case preservation", () => {
|
||||
formData.set("language", "en");
|
||||
|
||||
const settings = extractSettings(formData);
|
||||
expect(settings.LLM_MODEL).toBe("anthropic/claude-sonnet-4-20250514");
|
||||
expect(settings.llm_model).toBe("anthropic/claude-sonnet-4-20250514");
|
||||
});
|
||||
|
||||
it("should not automatically lowercase model names", () => {
|
||||
@ -48,7 +48,7 @@ describe("Model name case preservation", () => {
|
||||
const settings = extractSettings(formData);
|
||||
|
||||
// Test that camelCase and PascalCase are preserved
|
||||
expect(settings.LLM_MODEL).not.toBe("sambanova/meta-llama-3.1-8b-instruct");
|
||||
expect(settings.LLM_MODEL).toBe("SambaNova/Meta-Llama-3.1-8B-Instruct");
|
||||
expect(settings.llm_model).not.toBe("sambanova/meta-llama-3.1-8b-instruct");
|
||||
expect(settings.llm_model).toBe("SambaNova/Meta-Llama-3.1-8B-Instruct");
|
||||
});
|
||||
});
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { parseMaxBudgetPerTask, extractSettings } from "../settings-utils";
|
||||
import { parseMaxBudgetPerTask, extractSettings } from "#/utils/settings-utils";
|
||||
|
||||
describe("parseMaxBudgetPerTask", () => {
|
||||
it("should return null for empty string", () => {
|
||||
@ -67,10 +67,10 @@ describe("extractSettings", () => {
|
||||
|
||||
// Verify that the model name case is preserved
|
||||
const expectedModel = `${provider}/${model}`;
|
||||
expect(settings.LLM_MODEL).toBe(expectedModel);
|
||||
expect(settings.llm_model).toBe(expectedModel);
|
||||
// Only test that it's not lowercased if the original has uppercase letters
|
||||
if (expectedModel !== expectedModel.toLowerCase()) {
|
||||
expect(settings.LLM_MODEL).not.toBe(expectedModel.toLowerCase());
|
||||
expect(settings.llm_model).not.toBe(expectedModel.toLowerCase());
|
||||
}
|
||||
});
|
||||
});
|
||||
@ -85,7 +85,7 @@ describe("extractSettings", () => {
|
||||
const settings = extractSettings(formData);
|
||||
|
||||
// Custom model should take precedence and preserve case
|
||||
expect(settings.LLM_MODEL).toBe("Custom-Model-Name");
|
||||
expect(settings.LLM_MODEL).not.toBe("custom-model-name");
|
||||
expect(settings.llm_model).toBe("Custom-Model-Name");
|
||||
expect(settings.llm_model).not.toBe("custom-model-name");
|
||||
});
|
||||
});
|
||||
@ -1,5 +1,5 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { calculateToastDuration } from "../toast-duration";
|
||||
import { calculateToastDuration } from "#/utils/toast-duration";
|
||||
|
||||
describe("calculateToastDuration", () => {
|
||||
it("should return minimum duration for short messages", () => {
|
||||
@ -1,5 +1,5 @@
|
||||
import { describe, it, expect, beforeEach, afterEach } from "vitest";
|
||||
import { transformVSCodeUrl } from "../vscode-url-helper";
|
||||
import { transformVSCodeUrl } from "#/utils/vscode-url-helper";
|
||||
|
||||
describe("transformVSCodeUrl", () => {
|
||||
const originalWindowLocation = window.location;
|
||||
2049
frontend/package-lock.json
generated
2049
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -1,49 +1,39 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "0.62.0",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=22.0.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroui/react": "2.8.5",
|
||||
"@heroui/use-infinite-scroll": "^2.2.12",
|
||||
"@heroui/react": "2.8.6",
|
||||
"@microlink/react-json-view": "^1.26.2",
|
||||
"@monaco-editor/react": "^4.7.0-rc.0",
|
||||
"@posthog/react": "^1.5.2",
|
||||
"@react-router/node": "^7.10.1",
|
||||
"@react-router/serve": "^7.10.1",
|
||||
"@react-types/shared": "^3.32.0",
|
||||
"@stripe/react-stripe-js": "^5.4.1",
|
||||
"@stripe/stripe-js": "^8.5.3",
|
||||
"@tailwindcss/postcss": "^4.1.17",
|
||||
"@tailwindcss/vite": "^4.1.17",
|
||||
"@tailwindcss/vite": "^4.1.18",
|
||||
"@tanstack/react-query": "^5.90.12",
|
||||
"@uidotdev/usehooks": "^2.4.1",
|
||||
"@vitejs/plugin-react": "^5.1.2",
|
||||
"@xterm/addon-fit": "^0.10.0",
|
||||
"@xterm/xterm": "^5.4.0",
|
||||
"axios": "^1.13.2",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"date-fns": "^4.1.0",
|
||||
"downshift": "^9.0.12",
|
||||
"downshift": "^9.0.13",
|
||||
"eslint-config-airbnb-typescript": "^18.0.0",
|
||||
"framer-motion": "^12.23.25",
|
||||
"i18next": "^25.7.2",
|
||||
"i18next": "^25.7.3",
|
||||
"i18next-browser-languagedetector": "^8.2.0",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"isbot": "^5.1.32",
|
||||
"jose": "^6.1.3",
|
||||
"lucide-react": "^0.556.0",
|
||||
"lucide-react": "^0.561.0",
|
||||
"monaco-editor": "^0.55.1",
|
||||
"posthog-js": "^1.302.2",
|
||||
"react": "^19.2.0",
|
||||
"react-dom": "^19.2.0",
|
||||
"react-highlight": "^0.15.0",
|
||||
"posthog-js": "^1.309.0",
|
||||
"react": "^19.2.3",
|
||||
"react-dom": "^19.2.3",
|
||||
"react-hot-toast": "^2.6.0",
|
||||
"react-i18next": "^16.4.0",
|
||||
"react-i18next": "^16.5.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-router": "^7.10.1",
|
||||
@ -54,9 +44,7 @@
|
||||
"socket.io-client": "^4.8.1",
|
||||
"tailwind-merge": "^3.4.0",
|
||||
"tailwind-scrollbar": "^4.0.2",
|
||||
"vite": "^7.2.7",
|
||||
"web-vitals": "^5.1.0",
|
||||
"ws": "^8.18.2",
|
||||
"vite": "^7.3.0",
|
||||
"zustand": "^5.0.9"
|
||||
},
|
||||
"scripts": {
|
||||
@ -92,9 +80,6 @@
|
||||
]
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/parser": "^7.28.3",
|
||||
"@babel/traverse": "^7.28.3",
|
||||
"@babel/types": "^7.28.2",
|
||||
"@mswjs/socket.io-binding": "^0.2.0",
|
||||
"@playwright/test": "^1.57.0",
|
||||
"@react-router/dev": "^7.10.1",
|
||||
@ -102,18 +87,15 @@
|
||||
"@tanstack/eslint-plugin-query": "^5.91.0",
|
||||
"@testing-library/dom": "^10.4.1",
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
"@testing-library/react": "^16.3.1",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
"@types/node": "^24.10.1",
|
||||
"@types/node": "^25.0.3",
|
||||
"@types/react": "^19.2.7",
|
||||
"@types/react-dom": "^19.2.3",
|
||||
"@types/react-highlight": "^0.12.8",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"@types/ws": "^8.18.1",
|
||||
"@typescript-eslint/eslint-plugin": "^7.18.0",
|
||||
"@typescript-eslint/parser": "^7.18.0",
|
||||
"@vitest/coverage-v8": "^4.0.14",
|
||||
"autoprefixer": "^10.4.22",
|
||||
"@vitest/coverage-v8": "^4.0.16",
|
||||
"cross-env": "^10.1.0",
|
||||
"eslint": "^8.57.0",
|
||||
"eslint-config-airbnb": "^19.0.4",
|
||||
@ -131,11 +113,10 @@
|
||||
"lint-staged": "^16.2.7",
|
||||
"msw": "^2.6.6",
|
||||
"prettier": "^3.7.3",
|
||||
"stripe": "^20.0.0",
|
||||
"tailwindcss": "^4.1.8",
|
||||
"typescript": "^5.9.3",
|
||||
"vite-plugin-svgr": "^4.5.0",
|
||||
"vite-tsconfig-paths": "^5.1.4",
|
||||
"vite-tsconfig-paths": "^6.0.2",
|
||||
"vitest": "^4.0.14"
|
||||
},
|
||||
"packageManager": "npm@10.5.0",
|
||||
|
||||
@ -11,6 +11,7 @@ import type {
|
||||
V1AppConversationStartTask,
|
||||
V1AppConversationStartTaskPage,
|
||||
V1AppConversation,
|
||||
GetSkillsResponse,
|
||||
} from "./v1-conversation-service.types";
|
||||
|
||||
class V1ConversationService {
|
||||
@ -315,6 +316,18 @@ class V1ConversationService {
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all skills associated with a V1 conversation
|
||||
* @param conversationId The conversation ID
|
||||
* @returns The available skills associated with the conversation
|
||||
*/
|
||||
static async getSkills(conversationId: string): Promise<GetSkillsResponse> {
|
||||
const { data } = await openHands.get<GetSkillsResponse>(
|
||||
`/api/v1/app-conversations/${conversationId}/skills`,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
}
|
||||
|
||||
export default V1ConversationService;
|
||||
|
||||
@ -99,3 +99,14 @@ export interface V1AppConversation {
|
||||
conversation_url: string | null;
|
||||
session_api_key: string | null;
|
||||
}
|
||||
|
||||
export interface Skill {
|
||||
name: string;
|
||||
type: "repo" | "knowledge";
|
||||
content: string;
|
||||
triggers: string[];
|
||||
}
|
||||
|
||||
export interface GetSkillsResponse {
|
||||
skills: Skill[];
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@ import type {
|
||||
ConfirmationResponseRequest,
|
||||
ConfirmationResponseResponse,
|
||||
} from "./event-service.types";
|
||||
import { openHands } from "../open-hands-axios";
|
||||
|
||||
class EventService {
|
||||
/**
|
||||
@ -38,11 +37,27 @@ class EventService {
|
||||
return data;
|
||||
}
|
||||
|
||||
static async getEventCount(conversationId: string): Promise<number> {
|
||||
const params = new URLSearchParams();
|
||||
params.append("conversation_id__eq", conversationId);
|
||||
const { data } = await openHands.get<number>(
|
||||
`/api/v1/events/count?${params.toString()}`,
|
||||
/**
|
||||
* Get event count for a V1 conversation
|
||||
* @param conversationId The conversation ID
|
||||
* @param conversationUrl The conversation URL (e.g., "http://localhost:54928/api/conversations/...")
|
||||
* @param sessionApiKey Session API key for authentication (required for V1)
|
||||
* @returns The event count
|
||||
*/
|
||||
static async getEventCount(
|
||||
conversationId: string,
|
||||
conversationUrl: string,
|
||||
sessionApiKey?: string | null,
|
||||
): Promise<number> {
|
||||
// Build the runtime URL using the conversation URL
|
||||
const runtimeUrl = buildHttpBaseUrl(conversationUrl);
|
||||
|
||||
// Build session headers for authentication
|
||||
const headers = buildSessionHeaders(sessionApiKey);
|
||||
|
||||
const { data } = await axios.get<number>(
|
||||
`${runtimeUrl}/api/conversations/${conversationId}/events/count`,
|
||||
{ headers },
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { ApiSettings, PostApiSettings } from "./settings.types";
|
||||
import { Settings } from "#/types/settings";
|
||||
|
||||
/**
|
||||
* Settings service for managing application settings
|
||||
@ -8,8 +8,8 @@ class SettingsService {
|
||||
/**
|
||||
* Get the settings from the server or use the default settings if not found
|
||||
*/
|
||||
static async getSettings(): Promise<ApiSettings> {
|
||||
const { data } = await openHands.get<ApiSettings>("/api/settings");
|
||||
static async getSettings(): Promise<Settings> {
|
||||
const { data } = await openHands.get<Settings>("/api/settings");
|
||||
return data;
|
||||
}
|
||||
|
||||
@ -17,9 +17,7 @@ class SettingsService {
|
||||
* Save the settings to the server. Only valid settings are saved.
|
||||
* @param settings - the settings to save
|
||||
*/
|
||||
static async saveSettings(
|
||||
settings: Partial<PostApiSettings>,
|
||||
): Promise<boolean> {
|
||||
static async saveSettings(settings: Partial<Settings>): Promise<boolean> {
|
||||
const data = await openHands.post("/api/settings", settings);
|
||||
return data.status === 200;
|
||||
}
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
import { Provider } from "#/types/settings";
|
||||
|
||||
export type ApiSettings = {
|
||||
llm_model: string;
|
||||
llm_base_url: string;
|
||||
agent: string;
|
||||
language: string;
|
||||
llm_api_key: string | null;
|
||||
llm_api_key_set: boolean;
|
||||
search_api_key_set: boolean;
|
||||
confirmation_mode: boolean;
|
||||
security_analyzer: string | null;
|
||||
remote_runtime_resource_factor: number | null;
|
||||
enable_default_condenser: boolean;
|
||||
// Max size for condenser in backend settings
|
||||
condenser_max_size: number | null;
|
||||
enable_sound_notifications: boolean;
|
||||
enable_proactive_conversation_starters: boolean;
|
||||
enable_solvability_analysis: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
search_api_key?: string;
|
||||
provider_tokens_set: Partial<Record<Provider, string | null>>;
|
||||
max_budget_per_task: number | null;
|
||||
mcp_config?: {
|
||||
sse_servers: (string | { url: string; api_key?: string })[];
|
||||
stdio_servers: {
|
||||
name: string;
|
||||
command: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
}[];
|
||||
shttp_servers: (string | { url: string; api_key?: string })[];
|
||||
};
|
||||
email?: string;
|
||||
email_verified?: boolean;
|
||||
git_user_name?: string;
|
||||
git_user_email?: string;
|
||||
v1_enabled?: boolean;
|
||||
};
|
||||
|
||||
export type PostApiSettings = ApiSettings & {
|
||||
user_consents_to_analytics: boolean | null;
|
||||
search_api_key?: string;
|
||||
mcp_config?: {
|
||||
sse_servers: (string | { url: string; api_key?: string })[];
|
||||
stdio_servers: {
|
||||
name: string;
|
||||
command: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
}[];
|
||||
shttp_servers: (string | { url: string; api_key?: string })[];
|
||||
};
|
||||
};
|
||||
@ -12,10 +12,9 @@ export function BrowserPanel() {
|
||||
reset();
|
||||
}, [conversationId, reset]);
|
||||
|
||||
const imgSrc =
|
||||
screenshotSrc && screenshotSrc.startsWith("data:image/png;base64,")
|
||||
? screenshotSrc
|
||||
: `data:image/png;base64,${screenshotSrc || ""}`;
|
||||
const imgSrc = screenshotSrc?.startsWith("data:image/png;base64,")
|
||||
? screenshotSrc
|
||||
: `data:image/png;base64,${screenshotSrc ?? ""}`;
|
||||
|
||||
return (
|
||||
<div className="h-full w-full flex flex-col text-neutral-400">
|
||||
|
||||
@ -9,7 +9,7 @@ function ConfirmationModeEnabled() {
|
||||
|
||||
const { data: settings } = useSettings();
|
||||
|
||||
if (!settings?.CONFIRMATION_MODE) {
|
||||
if (!settings?.confirmation_mode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@ -140,7 +140,7 @@ const getTaskTrackingObservationContent = (
|
||||
content += "\n\n**Task List:** Empty";
|
||||
}
|
||||
|
||||
if (event.content && event.content.trim()) {
|
||||
if (event.content?.trim()) {
|
||||
content += `\n\n**Result:** ${event.content.trim()}`;
|
||||
}
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
import ArrowDown from "#/icons/angle-down-solid.svg?react";
|
||||
import ArrowUp from "#/icons/angle-up-solid.svg?react";
|
||||
import CheckCircle from "#/icons/check-circle-solid.svg?react";
|
||||
import XCircle from "#/icons/x-circle-solid.svg?react";
|
||||
import { OpenHandsAction } from "#/types/core/actions";
|
||||
import { OpenHandsObservation } from "#/types/core/observations";
|
||||
import { cn } from "#/utils/utils";
|
||||
@ -169,19 +168,12 @@ export function ExpandableMessage({
|
||||
)}
|
||||
</button>
|
||||
</span>
|
||||
{type === "action" && success !== undefined && (
|
||||
{type === "action" && success && (
|
||||
<span className="flex-shrink-0">
|
||||
{success ? (
|
||||
<CheckCircle
|
||||
data-testid="status-icon"
|
||||
className={cn(statusIconClasses, "fill-success")}
|
||||
/>
|
||||
) : (
|
||||
<XCircle
|
||||
data-testid="status-icon"
|
||||
className={cn(statusIconClasses, "fill-danger")}
|
||||
/>
|
||||
)}
|
||||
<CheckCircle
|
||||
data-testid="status-icon"
|
||||
className={cn(statusIconClasses, "fill-success")}
|
||||
/>
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@ -192,8 +192,7 @@ export const Messages: React.FC<MessagesProps> = React.memo(
|
||||
) => {
|
||||
const conversationInstructions = `Target file: ${target}\n\nDescription: ${query}\n\nTriggers: ${triggers.join(", ")}`;
|
||||
if (
|
||||
!conversation ||
|
||||
!conversation.selected_repository ||
|
||||
!conversation?.selected_repository ||
|
||||
!conversation.selected_branch ||
|
||||
!conversation.git_provider ||
|
||||
!selectedEventId
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import { FaClock } from "react-icons/fa";
|
||||
import CheckCircle from "#/icons/check-circle-solid.svg?react";
|
||||
import XCircle from "#/icons/x-circle-solid.svg?react";
|
||||
import { ObservationResultStatus } from "./event-content-helpers/get-observation-result";
|
||||
|
||||
interface SuccessIndicatorProps {
|
||||
@ -17,13 +16,6 @@ export function SuccessIndicator({ status }: SuccessIndicatorProps) {
|
||||
/>
|
||||
)}
|
||||
|
||||
{status === "error" && (
|
||||
<XCircle
|
||||
data-testid="status-icon"
|
||||
className="h-4 w-4 ml-2 inline fill-danger"
|
||||
/>
|
||||
)}
|
||||
|
||||
{status === "timeout" && (
|
||||
<FaClock
|
||||
data-testid="status-icon"
|
||||
|
||||
@ -24,7 +24,7 @@ export function TaskItem({ task }: TaskItemProps) {
|
||||
case "todo":
|
||||
return <CircleIcon className="w-4 h-4 text-[#ffffff]" />;
|
||||
case "in_progress":
|
||||
return <LoadingIcon className="w-4 h-4 text-[#ffffff]" />;
|
||||
return <LoadingIcon className="w-4 h-4 text-[#ffffff] animate-spin" />;
|
||||
case "done":
|
||||
return <CheckCircleIcon className="w-4 h-4 text-[#A3A3A3]" />;
|
||||
default:
|
||||
|
||||
@ -5,11 +5,10 @@ import { ContextMenu } from "#/ui/context-menu";
|
||||
import { ContextMenuListItem } from "./context-menu-list-item";
|
||||
import { Divider } from "#/ui/divider";
|
||||
import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import LogOutIcon from "#/icons/log-out.svg?react";
|
||||
import DocumentIcon from "#/icons/document.svg?react";
|
||||
import { SAAS_NAV_ITEMS, OSS_NAV_ITEMS } from "#/constants/settings-nav";
|
||||
import { useSettingsNavItems } from "#/hooks/use-settings-nav-items";
|
||||
|
||||
interface AccountSettingsContextMenuProps {
|
||||
onLogout: () => void;
|
||||
@ -22,15 +21,8 @@ export function AccountSettingsContextMenu({
|
||||
}: AccountSettingsContextMenuProps) {
|
||||
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
const { t } = useTranslation();
|
||||
const { data: config } = useConfig();
|
||||
|
||||
const isSaas = config?.APP_MODE === "saas";
|
||||
|
||||
// Get navigation items and filter out LLM settings if the feature flag is enabled
|
||||
let items = isSaas ? SAAS_NAV_ITEMS : OSS_NAV_ITEMS;
|
||||
if (config?.FEATURE_FLAGS?.HIDE_LLM_SETTINGS) {
|
||||
items = items.filter((item) => item.to !== "/settings");
|
||||
}
|
||||
const items = useSettingsNavItems();
|
||||
|
||||
const navItems = items.map((item) => ({
|
||||
...item,
|
||||
@ -39,11 +31,7 @@ export function AccountSettingsContextMenu({
|
||||
height: 16,
|
||||
} as React.SVGProps<SVGSVGElement>),
|
||||
}));
|
||||
|
||||
const handleNavigationClick = () => {
|
||||
onClose();
|
||||
// The Link component will handle the actual navigation
|
||||
};
|
||||
const handleNavigationClick = () => onClose();
|
||||
|
||||
return (
|
||||
<ContextMenu
|
||||
@ -55,7 +43,7 @@ export function AccountSettingsContextMenu({
|
||||
{navItems.map(({ to, text, icon }) => (
|
||||
<Link key={to} to={to} className="text-decoration-none">
|
||||
<ContextMenuListItem
|
||||
onClick={() => handleNavigationClick()}
|
||||
onClick={handleNavigationClick}
|
||||
className="flex items-center gap-2 p-2 hover:bg-[#5C5D62] rounded h-[30px]"
|
||||
>
|
||||
{icon}
|
||||
|
||||
@ -59,13 +59,15 @@ export function AgentStatus({
|
||||
);
|
||||
|
||||
const shouldShownAgentLoading =
|
||||
isPausing ||
|
||||
curAgentState === AgentState.INIT ||
|
||||
curAgentState === AgentState.LOADING ||
|
||||
(webSocketStatus === "CONNECTING" && taskStatus !== "ERROR") ||
|
||||
isTaskPolling(taskStatus) ||
|
||||
isTaskPolling(subConversationTaskStatus);
|
||||
|
||||
// For UI rendering - includes pause state
|
||||
const isLoading = shouldShownAgentLoading || isPausing;
|
||||
|
||||
const shouldShownAgentError =
|
||||
curAgentState === AgentState.ERROR ||
|
||||
curAgentState === AgentState.RATE_LIMITED ||
|
||||
@ -93,25 +95,28 @@ export function AgentStatus({
|
||||
<div
|
||||
className={cn(
|
||||
"bg-[#525252] box-border content-stretch flex flex-row gap-[3px] items-center justify-center overflow-clip px-0.5 py-1 relative rounded-[100px] shrink-0 size-6 transition-all duration-200 active:scale-95",
|
||||
!shouldShownAgentLoading &&
|
||||
!isLoading &&
|
||||
(shouldShownAgentStop || shouldShownAgentResume) &&
|
||||
"hover:bg-[#737373] cursor-pointer",
|
||||
)}
|
||||
>
|
||||
{shouldShownAgentLoading && <AgentLoading />}
|
||||
{!shouldShownAgentLoading && shouldShownAgentStop && (
|
||||
{isLoading && <AgentLoading />}
|
||||
{!isLoading && shouldShownAgentStop && (
|
||||
<ChatStopButton handleStop={handleStop} />
|
||||
)}
|
||||
{!shouldShownAgentLoading && shouldShownAgentResume && (
|
||||
{!isLoading && shouldShownAgentResume && (
|
||||
<ChatResumeAgentButton
|
||||
onAgentResumed={handleResumeAgent}
|
||||
disabled={disabled}
|
||||
/>
|
||||
)}
|
||||
{!shouldShownAgentLoading && shouldShownAgentError && (
|
||||
<CircleErrorIcon className="w-4 h-4" />
|
||||
{!isLoading && shouldShownAgentError && (
|
||||
<CircleErrorIcon
|
||||
className="w-4 h-4"
|
||||
data-testid="circle-error-icon"
|
||||
/>
|
||||
)}
|
||||
{!shouldShownAgentLoading &&
|
||||
{!isLoading &&
|
||||
!shouldShownAgentStop &&
|
||||
!shouldShownAgentResume &&
|
||||
!shouldShownAgentError && <ClockIcon className="w-4 h-4" />}
|
||||
|
||||
@ -26,14 +26,14 @@ const contextMenuListItemClassName = cn(
|
||||
|
||||
interface ToolsContextMenuProps {
|
||||
onClose: () => void;
|
||||
onShowMicroagents: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowSkills: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowAgentTools: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
shouldShowAgentTools?: boolean;
|
||||
}
|
||||
|
||||
export function ToolsContextMenu({
|
||||
onClose,
|
||||
onShowMicroagents,
|
||||
onShowSkills,
|
||||
onShowAgentTools,
|
||||
shouldShowAgentTools = true,
|
||||
}: ToolsContextMenuProps) {
|
||||
@ -41,7 +41,6 @@ export function ToolsContextMenu({
|
||||
const { data: conversation } = useActiveConversation();
|
||||
const { providers } = useUserProviders();
|
||||
|
||||
// TODO: Hide microagent menu items for V1 conversations
|
||||
// This is a temporary measure and may be re-enabled in the future
|
||||
const isV1Conversation = conversation?.conversation_version === "V1";
|
||||
|
||||
@ -130,20 +129,17 @@ export function ToolsContextMenu({
|
||||
|
||||
{(!isV1Conversation || shouldShowAgentTools) && <Divider />}
|
||||
|
||||
{/* Show Available Microagents - Hidden for V1 conversations */}
|
||||
{!isV1Conversation && (
|
||||
<ContextMenuListItem
|
||||
testId="show-microagents-button"
|
||||
onClick={onShowMicroagents}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ToolsContextMenuIconText
|
||||
icon={<RobotIcon width={16} height={16} />}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_MICROAGENTS)}
|
||||
className={CONTEXT_MENU_ICON_TEXT_CLASSNAME}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
<ContextMenuListItem
|
||||
testId="show-skills-button"
|
||||
onClick={onShowSkills}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ToolsContextMenuIconText
|
||||
icon={<RobotIcon width={16} height={16} />}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_SKILLS)}
|
||||
className={CONTEXT_MENU_ICON_TEXT_CLASSNAME}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
|
||||
{/* Show Agent Tools and Metadata - Only show if system message is available */}
|
||||
{shouldShowAgentTools && (
|
||||
|
||||
@ -7,7 +7,7 @@ import { ToolsContextMenu } from "./tools-context-menu";
|
||||
import { useConversationNameContextMenu } from "#/hooks/use-conversation-name-context-menu";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
import { SystemMessageModal } from "../conversation-panel/system-message-modal";
|
||||
import { MicroagentsModal } from "../conversation-panel/microagents-modal";
|
||||
import { SkillsModal } from "../conversation-panel/skills-modal";
|
||||
|
||||
export function Tools() {
|
||||
const { t } = useTranslation();
|
||||
@ -17,11 +17,11 @@ export function Tools() {
|
||||
|
||||
const {
|
||||
handleShowAgentTools,
|
||||
handleShowMicroagents,
|
||||
handleShowSkills,
|
||||
systemModalVisible,
|
||||
setSystemModalVisible,
|
||||
microagentsModalVisible,
|
||||
setMicroagentsModalVisible,
|
||||
skillsModalVisible,
|
||||
setSkillsModalVisible,
|
||||
systemMessage,
|
||||
shouldShowAgentTools,
|
||||
} = useConversationNameContextMenu({
|
||||
@ -51,7 +51,7 @@ export function Tools() {
|
||||
{contextMenuOpen && (
|
||||
<ToolsContextMenu
|
||||
onClose={() => setContextMenuOpen(false)}
|
||||
onShowMicroagents={handleShowMicroagents}
|
||||
onShowSkills={handleShowSkills}
|
||||
onShowAgentTools={handleShowAgentTools}
|
||||
shouldShowAgentTools={shouldShowAgentTools}
|
||||
/>
|
||||
@ -64,9 +64,9 @@ export function Tools() {
|
||||
systemMessage={systemMessage ? systemMessage.args : null}
|
||||
/>
|
||||
|
||||
{/* Microagents Modal */}
|
||||
{microagentsModalVisible && (
|
||||
<MicroagentsModal onClose={() => setMicroagentsModalVisible(false)} />
|
||||
{/* Skills Modal */}
|
||||
{skillsModalVisible && (
|
||||
<SkillsModal onClose={() => setSkillsModalVisible(false)} />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
import {
|
||||
Trash,
|
||||
Power,
|
||||
Pencil,
|
||||
Download,
|
||||
Wallet,
|
||||
Wrench,
|
||||
Bot,
|
||||
} from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { ContextMenu } from "#/ui/context-menu";
|
||||
import { ContextMenuListItem } from "../context-menu/context-menu-list-item";
|
||||
import { Divider } from "#/ui/divider";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { ContextMenuIconText } from "../context-menu/context-menu-icon-text";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
|
||||
interface ConversationCardContextMenuProps {
|
||||
onClose: () => void;
|
||||
onDelete?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onStop?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onEdit?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDisplayCost?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowAgentTools?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowMicroagents?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
position?: "top" | "bottom";
|
||||
}
|
||||
|
||||
export function ConversationCardContextMenu({
|
||||
onClose,
|
||||
onDelete,
|
||||
onStop,
|
||||
onEdit,
|
||||
onDisplayCost,
|
||||
onShowAgentTools,
|
||||
onShowMicroagents,
|
||||
onDownloadViaVSCode,
|
||||
position = "bottom",
|
||||
}: ConversationCardContextMenuProps) {
|
||||
const { t } = useTranslation();
|
||||
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
const { data: conversation } = useActiveConversation();
|
||||
|
||||
// TODO: Hide microagent menu items for V1 conversations
|
||||
// This is a temporary measure and may be re-enabled in the future
|
||||
const isV1Conversation = conversation?.conversation_version === "V1";
|
||||
|
||||
const hasEdit = Boolean(onEdit);
|
||||
const hasDownload = Boolean(onDownloadViaVSCode);
|
||||
const hasTools = Boolean(onShowAgentTools || onShowMicroagents);
|
||||
const hasInfo = Boolean(onDisplayCost);
|
||||
const hasControl = Boolean(onStop || onDelete);
|
||||
|
||||
return (
|
||||
<ContextMenu
|
||||
ref={ref}
|
||||
testId="context-menu"
|
||||
className={cn(
|
||||
"right-0 absolute mt-3",
|
||||
position === "top" && "bottom-full",
|
||||
position === "bottom" && "top-full",
|
||||
)}
|
||||
>
|
||||
{onEdit && (
|
||||
<ContextMenuListItem testId="edit-button" onClick={onEdit}>
|
||||
<ContextMenuIconText
|
||||
icon={Pencil}
|
||||
text={t(I18nKey.BUTTON$EDIT_TITLE)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{hasEdit && (hasDownload || hasTools || hasInfo || hasControl) && (
|
||||
<Divider />
|
||||
)}
|
||||
|
||||
{onDownloadViaVSCode && (
|
||||
<ContextMenuListItem
|
||||
testId="download-vscode-button"
|
||||
onClick={onDownloadViaVSCode}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={Download}
|
||||
text={t(I18nKey.BUTTON$DOWNLOAD_VIA_VSCODE)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{hasDownload && (hasTools || hasInfo || hasControl) && <Divider />}
|
||||
|
||||
{onShowAgentTools && (
|
||||
<ContextMenuListItem
|
||||
testId="show-agent-tools-button"
|
||||
onClick={onShowAgentTools}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={Wrench}
|
||||
text={t(I18nKey.BUTTON$SHOW_AGENT_TOOLS_AND_METADATA)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{onShowMicroagents && !isV1Conversation && (
|
||||
<ContextMenuListItem
|
||||
testId="show-microagents-button"
|
||||
onClick={onShowMicroagents}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={Bot}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_MICROAGENTS)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{hasTools && (hasInfo || hasControl) && <Divider />}
|
||||
|
||||
{onDisplayCost && (
|
||||
<ContextMenuListItem
|
||||
testId="display-cost-button"
|
||||
onClick={onDisplayCost}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={Wallet}
|
||||
text={t(I18nKey.BUTTON$DISPLAY_COST)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{hasInfo && hasControl && <Divider />}
|
||||
|
||||
{onStop && (
|
||||
<ContextMenuListItem testId="stop-button" onClick={onStop}>
|
||||
<ContextMenuIconText icon={Power} text={t(I18nKey.BUTTON$PAUSE)} />
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{onDelete && (
|
||||
<ContextMenuListItem testId="delete-button" onClick={onDelete}>
|
||||
<ContextMenuIconText icon={Trash} text={t(I18nKey.BUTTON$DELETE)} />
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
</ContextMenu>
|
||||
);
|
||||
}
|
||||
@ -22,7 +22,7 @@ interface ConversationCardContextMenuProps {
|
||||
onEdit?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDisplayCost?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowAgentTools?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowMicroagents?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowSkills?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
position?: "top" | "bottom";
|
||||
}
|
||||
@ -37,7 +37,7 @@ export function ConversationCardContextMenu({
|
||||
onEdit,
|
||||
onDisplayCost,
|
||||
onShowAgentTools,
|
||||
onShowMicroagents,
|
||||
onShowSkills,
|
||||
onDownloadViaVSCode,
|
||||
position = "bottom",
|
||||
}: ConversationCardContextMenuProps) {
|
||||
@ -96,15 +96,15 @@ export function ConversationCardContextMenu({
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
onShowMicroagents && (
|
||||
onShowSkills && (
|
||||
<ContextMenuListItem
|
||||
testId="show-microagents-button"
|
||||
onClick={onShowMicroagents}
|
||||
testId="show-skills-button"
|
||||
onClick={onShowSkills}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<RobotIcon width={16} height={16} />}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_MICROAGENTS)}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_SKILLS)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
|
||||
@ -20,7 +20,7 @@ export function ConversationPanelWrapper({
|
||||
return ReactDOM.createPortal(
|
||||
<div
|
||||
className={cn(
|
||||
"absolute h-full w-full left-0 top-0 z-[9999] bg-black/80 rounded-xl",
|
||||
"absolute h-full w-full left-0 top-0 z-[100] bg-black/80 rounded-xl",
|
||||
pathname === "/" && "bottom-0 top-0 md:top-3 md:bottom-3 h-auto",
|
||||
)}
|
||||
>
|
||||
|
||||
@ -3,17 +3,17 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { Pre } from "#/ui/pre";
|
||||
|
||||
interface MicroagentContentProps {
|
||||
interface SkillContentProps {
|
||||
content: string;
|
||||
}
|
||||
|
||||
export function MicroagentContent({ content }: MicroagentContentProps) {
|
||||
export function SkillContent({ content }: SkillContentProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<Typography.Text className="text-sm font-semibold text-gray-300 mb-2">
|
||||
{t(I18nKey.MICROAGENTS_MODAL$CONTENT)}
|
||||
{t(I18nKey.COMMON$CONTENT)}
|
||||
</Typography.Text>
|
||||
<Pre
|
||||
size="default"
|
||||
@ -28,7 +28,7 @@ export function MicroagentContent({ content }: MicroagentContentProps) {
|
||||
overflow="auto"
|
||||
className="mt-2"
|
||||
>
|
||||
{content || t(I18nKey.MICROAGENTS_MODAL$NO_CONTENT)}
|
||||
{content || t(I18nKey.SKILLS_MODAL$NO_CONTENT)}
|
||||
</Pre>
|
||||
</div>
|
||||
);
|
||||
@ -1,35 +1,31 @@
|
||||
import { ChevronDown, ChevronRight } from "lucide-react";
|
||||
import { Microagent } from "#/api/open-hands.types";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { MicroagentTriggers } from "./microagent-triggers";
|
||||
import { MicroagentContent } from "./microagent-content";
|
||||
import { SkillTriggers } from "./skill-triggers";
|
||||
import { SkillContent } from "./skill-content";
|
||||
import { Skill } from "#/api/conversation-service/v1-conversation-service.types";
|
||||
|
||||
interface MicroagentItemProps {
|
||||
agent: Microagent;
|
||||
interface SkillItemProps {
|
||||
skill: Skill;
|
||||
isExpanded: boolean;
|
||||
onToggle: (agentName: string) => void;
|
||||
}
|
||||
|
||||
export function MicroagentItem({
|
||||
agent,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
}: MicroagentItemProps) {
|
||||
export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
|
||||
return (
|
||||
<div className="rounded-md overflow-hidden">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => onToggle(agent.name)}
|
||||
onClick={() => onToggle(skill.name)}
|
||||
className="w-full py-3 px-2 text-left flex items-center justify-between hover:bg-gray-700 transition-colors"
|
||||
>
|
||||
<div className="flex items-center">
|
||||
<Typography.Text className="font-bold text-gray-100">
|
||||
{agent.name}
|
||||
{skill.name}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<Typography.Text className="px-2 py-1 text-xs rounded-full bg-gray-800 mr-2">
|
||||
{agent.type === "repo" ? "Repository" : "Knowledge"}
|
||||
{skill.type === "repo" ? "Repository" : "Knowledge"}
|
||||
</Typography.Text>
|
||||
<Typography.Text className="text-gray-300">
|
||||
{isExpanded ? (
|
||||
@ -43,8 +39,8 @@ export function MicroagentItem({
|
||||
|
||||
{isExpanded && (
|
||||
<div className="px-2 pb-3 pt-1">
|
||||
<MicroagentTriggers triggers={agent.triggers} />
|
||||
<MicroagentContent content={agent.content} />
|
||||
<SkillTriggers triggers={skill.triggers} />
|
||||
<SkillContent content={skill.content} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@ -2,11 +2,11 @@ import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { Typography } from "#/ui/typography";
|
||||
|
||||
interface MicroagentTriggersProps {
|
||||
interface SkillTriggersProps {
|
||||
triggers: string[];
|
||||
}
|
||||
|
||||
export function MicroagentTriggers({ triggers }: MicroagentTriggersProps) {
|
||||
export function SkillTriggers({ triggers }: SkillTriggersProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!triggers || triggers.length === 0) {
|
||||
@ -16,7 +16,7 @@ export function MicroagentTriggers({ triggers }: MicroagentTriggersProps) {
|
||||
return (
|
||||
<div className="mt-2 mb-3">
|
||||
<Typography.Text className="text-sm font-semibold text-gray-300 mb-2">
|
||||
{t(I18nKey.MICROAGENTS_MODAL$TRIGGERS)}
|
||||
{t(I18nKey.COMMON$TRIGGERS)}
|
||||
</Typography.Text>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{triggers.map((trigger) => (
|
||||
@ -2,19 +2,19 @@ import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { Typography } from "#/ui/typography";
|
||||
|
||||
interface MicroagentsEmptyStateProps {
|
||||
interface SkillsEmptyStateProps {
|
||||
isError: boolean;
|
||||
}
|
||||
|
||||
export function MicroagentsEmptyState({ isError }: MicroagentsEmptyStateProps) {
|
||||
export function SkillsEmptyState({ isError }: SkillsEmptyStateProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center h-full p-4">
|
||||
<Typography.Text className="text-gray-400">
|
||||
{isError
|
||||
? t(I18nKey.MICROAGENTS_MODAL$FETCH_ERROR)
|
||||
: t(I18nKey.CONVERSATION$NO_MICROAGENTS)}
|
||||
? t(I18nKey.COMMON$FETCH_ERROR)
|
||||
: t(I18nKey.CONVERSATION$NO_SKILLS)}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
);
|
||||
@ -1,4 +1,4 @@
|
||||
export function MicroagentsLoadingState() {
|
||||
export function SkillsLoadingState() {
|
||||
return (
|
||||
<div className="flex justify-center items-center py-8">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-t-2 border-b-2 border-primary" />
|
||||
@ -4,28 +4,28 @@ import { BaseModalTitle } from "#/components/shared/modals/confirmation-modals/b
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { BrandButton } from "../settings/brand-button";
|
||||
|
||||
interface MicroagentsModalHeaderProps {
|
||||
interface SkillsModalHeaderProps {
|
||||
isAgentReady: boolean;
|
||||
isLoading: boolean;
|
||||
isRefetching: boolean;
|
||||
onRefresh: () => void;
|
||||
}
|
||||
|
||||
export function MicroagentsModalHeader({
|
||||
export function SkillsModalHeader({
|
||||
isAgentReady,
|
||||
isLoading,
|
||||
isRefetching,
|
||||
onRefresh,
|
||||
}: MicroagentsModalHeaderProps) {
|
||||
}: SkillsModalHeaderProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6 w-full">
|
||||
<div className="flex items-center justify-between w-full">
|
||||
<BaseModalTitle title={t(I18nKey.MICROAGENTS_MODAL$TITLE)} />
|
||||
<BaseModalTitle title={t(I18nKey.SKILLS_MODAL$TITLE)} />
|
||||
{isAgentReady && (
|
||||
<BrandButton
|
||||
testId="refresh-microagents"
|
||||
testId="refresh-skills"
|
||||
type="button"
|
||||
variant="primary"
|
||||
className="flex items-center gap-2"
|
||||
@ -3,43 +3,32 @@ import { useTranslation } from "react-i18next";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModalBody } from "#/components/shared/modals/modal-body";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useConversationMicroagents } from "#/hooks/query/use-conversation-microagents";
|
||||
import { useConversationSkills } from "#/hooks/query/use-conversation-skills";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { MicroagentsModalHeader } from "./microagents-modal-header";
|
||||
import { MicroagentsLoadingState } from "./microagents-loading-state";
|
||||
import { MicroagentsEmptyState } from "./microagents-empty-state";
|
||||
import { MicroagentItem } from "./microagent-item";
|
||||
import { SkillsModalHeader } from "./skills-modal-header";
|
||||
import { SkillsLoadingState } from "./skills-loading-state";
|
||||
import { SkillsEmptyState } from "./skills-empty-state";
|
||||
import { SkillItem } from "./skill-item";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
|
||||
interface MicroagentsModalProps {
|
||||
interface SkillsModalProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function MicroagentsModal({ onClose }: MicroagentsModalProps) {
|
||||
export function SkillsModal({ onClose }: SkillsModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { curAgentState } = useAgentState();
|
||||
const { data: conversation } = useActiveConversation();
|
||||
const [expandedAgents, setExpandedAgents] = useState<Record<string, boolean>>(
|
||||
{},
|
||||
);
|
||||
const {
|
||||
data: microagents,
|
||||
data: skills,
|
||||
isLoading,
|
||||
isError,
|
||||
refetch,
|
||||
isRefetching,
|
||||
} = useConversationMicroagents();
|
||||
|
||||
// TODO: Hide MicroagentsModal for V1 conversations
|
||||
// This is a temporary measure and may be re-enabled in the future
|
||||
const isV1Conversation = conversation?.conversation_version === "V1";
|
||||
|
||||
// Don't render anything for V1 conversations
|
||||
if (isV1Conversation) {
|
||||
return null;
|
||||
}
|
||||
} = useConversationSkills();
|
||||
|
||||
const toggleAgent = (agentName: string) => {
|
||||
setExpandedAgents((prev) => ({
|
||||
@ -57,9 +46,9 @@ export function MicroagentsModal({ onClose }: MicroagentsModalProps) {
|
||||
<ModalBody
|
||||
width="medium"
|
||||
className="max-h-[80vh] flex flex-col items-start"
|
||||
testID="microagents-modal"
|
||||
testID="skills-modal"
|
||||
>
|
||||
<MicroagentsModalHeader
|
||||
<SkillsModalHeader
|
||||
isAgentReady={isAgentReady}
|
||||
isLoading={isLoading}
|
||||
isRefetching={isRefetching}
|
||||
@ -68,7 +57,7 @@ export function MicroagentsModal({ onClose }: MicroagentsModalProps) {
|
||||
|
||||
{isAgentReady && (
|
||||
<Typography.Text className="text-sm text-gray-400">
|
||||
{t(I18nKey.MICROAGENTS_MODAL$WARNING)}
|
||||
{t(I18nKey.SKILLS_MODAL$WARNING)}
|
||||
</Typography.Text>
|
||||
)}
|
||||
|
||||
@ -81,33 +70,30 @@ export function MicroagentsModal({ onClose }: MicroagentsModalProps) {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isLoading && <MicroagentsLoadingState />}
|
||||
{isLoading && <SkillsLoadingState />}
|
||||
|
||||
{!isLoading &&
|
||||
isAgentReady &&
|
||||
(isError || !microagents || microagents.length === 0) && (
|
||||
<MicroagentsEmptyState isError={isError} />
|
||||
(isError || !skills || skills.length === 0) && (
|
||||
<SkillsEmptyState isError={isError} />
|
||||
)}
|
||||
|
||||
{!isLoading &&
|
||||
isAgentReady &&
|
||||
microagents &&
|
||||
microagents.length > 0 && (
|
||||
<div className="p-2 space-y-3">
|
||||
{microagents.map((agent) => {
|
||||
const isExpanded = expandedAgents[agent.name] || false;
|
||||
{!isLoading && isAgentReady && skills && skills.length > 0 && (
|
||||
<div className="p-2 space-y-3">
|
||||
{skills.map((skill) => {
|
||||
const isExpanded = expandedAgents[skill.name] || false;
|
||||
|
||||
return (
|
||||
<MicroagentItem
|
||||
key={agent.name}
|
||||
agent={agent}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={toggleAgent}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
return (
|
||||
<SkillItem
|
||||
key={skill.name}
|
||||
skill={skill}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={toggleAgent}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ModalBody>
|
||||
</ModalBackdrop>
|
||||
@ -31,7 +31,7 @@ interface ConversationNameContextMenuProps {
|
||||
onStop?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDisplayCost?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowAgentTools?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowMicroagents?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowSkills?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onExportConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
position?: "top" | "bottom";
|
||||
@ -44,7 +44,7 @@ export function ConversationNameContextMenu({
|
||||
onStop,
|
||||
onDisplayCost,
|
||||
onShowAgentTools,
|
||||
onShowMicroagents,
|
||||
onShowSkills,
|
||||
onExportConversation,
|
||||
onDownloadViaVSCode,
|
||||
position = "bottom",
|
||||
@ -55,13 +55,12 @@ export function ConversationNameContextMenu({
|
||||
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
const { data: conversation } = useActiveConversation();
|
||||
|
||||
// TODO: Hide microagent menu items for V1 conversations
|
||||
// This is a temporary measure and may be re-enabled in the future
|
||||
const isV1Conversation = conversation?.conversation_version === "V1";
|
||||
|
||||
const hasDownload = Boolean(onDownloadViaVSCode);
|
||||
const hasExport = Boolean(onExportConversation);
|
||||
const hasTools = Boolean(onShowAgentTools || onShowMicroagents);
|
||||
const hasTools = Boolean(onShowAgentTools || onShowSkills);
|
||||
const hasInfo = Boolean(onDisplayCost);
|
||||
const hasControl = Boolean(onStop || onDelete);
|
||||
|
||||
@ -91,15 +90,15 @@ export function ConversationNameContextMenu({
|
||||
|
||||
{hasTools && <Divider testId="separator-tools" />}
|
||||
|
||||
{onShowMicroagents && !isV1Conversation && (
|
||||
{onShowSkills && (
|
||||
<ContextMenuListItem
|
||||
testId="show-microagents-button"
|
||||
onClick={onShowMicroagents}
|
||||
testId="show-skills-button"
|
||||
onClick={onShowSkills}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<RobotIcon width={16} height={16} />}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_MICROAGENTS)}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_SKILLS)}
|
||||
className={CONTEXT_MENU_ICON_TEXT_CLASSNAME}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
|
||||
@ -9,7 +9,7 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
import { EllipsisButton } from "../conversation-panel/ellipsis-button";
|
||||
import { ConversationNameContextMenu } from "./conversation-name-context-menu";
|
||||
import { SystemMessageModal } from "../conversation-panel/system-message-modal";
|
||||
import { MicroagentsModal } from "../conversation-panel/microagents-modal";
|
||||
import { SkillsModal } from "../conversation-panel/skills-modal";
|
||||
import { ConfirmDeleteModal } from "../conversation-panel/confirm-delete-modal";
|
||||
import { ConfirmStopModal } from "../conversation-panel/confirm-stop-modal";
|
||||
import { MetricsModal } from "./metrics-modal/metrics-modal";
|
||||
@ -32,7 +32,7 @@ export function ConversationName() {
|
||||
handleDownloadViaVSCode,
|
||||
handleDisplayCost,
|
||||
handleShowAgentTools,
|
||||
handleShowMicroagents,
|
||||
handleShowSkills,
|
||||
handleExportConversation,
|
||||
handleConfirmDelete,
|
||||
handleConfirmStop,
|
||||
@ -40,8 +40,8 @@ export function ConversationName() {
|
||||
setMetricsModalVisible,
|
||||
systemModalVisible,
|
||||
setSystemModalVisible,
|
||||
microagentsModalVisible,
|
||||
setMicroagentsModalVisible,
|
||||
skillsModalVisible,
|
||||
setSkillsModalVisible,
|
||||
confirmDeleteModalVisible,
|
||||
setConfirmDeleteModalVisible,
|
||||
confirmStopModalVisible,
|
||||
@ -52,7 +52,7 @@ export function ConversationName() {
|
||||
shouldShowExport,
|
||||
shouldShowDisplayCost,
|
||||
shouldShowAgentTools,
|
||||
shouldShowMicroagents,
|
||||
shouldShowSkills,
|
||||
} = useConversationNameContextMenu({
|
||||
conversationId,
|
||||
conversationStatus: conversation?.status,
|
||||
@ -170,9 +170,7 @@ export function ConversationName() {
|
||||
onShowAgentTools={
|
||||
shouldShowAgentTools ? handleShowAgentTools : undefined
|
||||
}
|
||||
onShowMicroagents={
|
||||
shouldShowMicroagents ? handleShowMicroagents : undefined
|
||||
}
|
||||
onShowSkills={shouldShowSkills ? handleShowSkills : undefined}
|
||||
onExportConversation={
|
||||
shouldShowExport ? handleExportConversation : undefined
|
||||
}
|
||||
@ -199,9 +197,9 @@ export function ConversationName() {
|
||||
systemMessage={systemMessage ? systemMessage.args : null}
|
||||
/>
|
||||
|
||||
{/* Microagents Modal */}
|
||||
{microagentsModalVisible && (
|
||||
<MicroagentsModal onClose={() => setMicroagentsModalVisible(false)} />
|
||||
{/* Skills Modal */}
|
||||
{skillsModalVisible && (
|
||||
<SkillsModal onClose={() => setSkillsModalVisible(false)} />
|
||||
)}
|
||||
|
||||
{/* Confirm Delete Modal */}
|
||||
|
||||
@ -82,13 +82,45 @@ export function ConversationTabContent() {
|
||||
isPlannerActive,
|
||||
]);
|
||||
|
||||
const conversationKey = useMemo(() => {
|
||||
if (isEditorActive) {
|
||||
return "editor";
|
||||
}
|
||||
if (isBrowserActive) {
|
||||
return "browser";
|
||||
}
|
||||
if (isServedActive) {
|
||||
return "served";
|
||||
}
|
||||
if (isVSCodeActive) {
|
||||
return "vscode";
|
||||
}
|
||||
if (isTerminalActive) {
|
||||
return "terminal";
|
||||
}
|
||||
if (isPlannerActive) {
|
||||
return "planner";
|
||||
}
|
||||
return "";
|
||||
}, [
|
||||
isEditorActive,
|
||||
isBrowserActive,
|
||||
isServedActive,
|
||||
isVSCodeActive,
|
||||
isTerminalActive,
|
||||
isPlannerActive,
|
||||
]);
|
||||
|
||||
if (shouldShownAgentLoading) {
|
||||
return <ConversationLoading />;
|
||||
}
|
||||
|
||||
return (
|
||||
<TabContainer>
|
||||
<ConversationTabTitle title={conversationTabTitle} />
|
||||
<ConversationTabTitle
|
||||
title={conversationTabTitle}
|
||||
conversationKey={conversationKey}
|
||||
/>
|
||||
<TabContentArea>
|
||||
{tabs.map(({ key, component: Component, isActive }) => (
|
||||
<TabWrapper
|
||||
|
||||
@ -1,11 +1,33 @@
|
||||
import RefreshIcon from "#/icons/u-refresh.svg?react";
|
||||
import { useUnifiedGetGitChanges } from "#/hooks/query/use-unified-get-git-changes";
|
||||
|
||||
type ConversationTabTitleProps = {
|
||||
title: string;
|
||||
conversationKey: string;
|
||||
};
|
||||
|
||||
export function ConversationTabTitle({ title }: ConversationTabTitleProps) {
|
||||
export function ConversationTabTitle({
|
||||
title,
|
||||
conversationKey,
|
||||
}: ConversationTabTitleProps) {
|
||||
const { refetch } = useUnifiedGetGitChanges();
|
||||
|
||||
const handleRefresh = () => {
|
||||
refetch();
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-row items-center justify-between border-b border-[#474A54] py-2 px-3">
|
||||
<span className="text-xs font-medium text-white">{title}</span>
|
||||
{conversationKey === "editor" && (
|
||||
<button
|
||||
type="button"
|
||||
className="flex w-[26px] py-1 justify-center items-center gap-[10px] rounded-[7px] hover:bg-[#474A54] cursor-pointer"
|
||||
onClick={handleRefresh}
|
||||
>
|
||||
<RefreshIcon width={12.75} height={15} color="#ffffff" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -20,13 +20,13 @@ export function EmailVerificationGuard({
|
||||
if (isLoading) return;
|
||||
|
||||
// If EMAIL_VERIFIED is explicitly false (not undefined or null)
|
||||
if (settings?.EMAIL_VERIFIED === false) {
|
||||
if (settings?.email_verified === false) {
|
||||
// Allow access to /settings/user but redirect from any other page
|
||||
if (pathname !== "/settings/user") {
|
||||
navigate("/settings/user", { replace: true });
|
||||
}
|
||||
}
|
||||
}, [settings?.EMAIL_VERIFIED, pathname, navigate, isLoading]);
|
||||
}, [settings?.email_verified, pathname, navigate, isLoading]);
|
||||
|
||||
return children;
|
||||
}
|
||||
|
||||
@ -75,7 +75,7 @@ export function GitProviderDropdown({
|
||||
}
|
||||
|
||||
// If no input value, show all providers
|
||||
if (!inputValue || !inputValue.trim()) {
|
||||
if (!inputValue?.trim()) {
|
||||
return providers;
|
||||
}
|
||||
|
||||
|
||||
@ -99,7 +99,7 @@ export function GitRepoDropdown({
|
||||
);
|
||||
|
||||
// If no input value, return all recent repos for this provider
|
||||
if (!inputValue || !inputValue.trim()) {
|
||||
if (!inputValue?.trim()) {
|
||||
return providerFilteredRepos;
|
||||
}
|
||||
|
||||
@ -139,7 +139,7 @@ export function GitRepoDropdown({
|
||||
baseRepositories = repositories;
|
||||
}
|
||||
// If no input value, show all repositories
|
||||
else if (!inputValue || !inputValue.trim()) {
|
||||
else if (!inputValue?.trim()) {
|
||||
baseRepositories = repositories;
|
||||
}
|
||||
// For URL inputs, use the processed search input for filtering
|
||||
@ -246,8 +246,7 @@ export function GitRepoDropdown({
|
||||
// Create sticky footer item for GitHub provider
|
||||
const stickyFooterItem = useMemo(() => {
|
||||
if (
|
||||
!config ||
|
||||
!config.APP_SLUG ||
|
||||
!config?.APP_SLUG ||
|
||||
provider !== ProviderOptions.github ||
|
||||
config.APP_MODE !== "saas"
|
||||
)
|
||||
|
||||
@ -78,7 +78,7 @@ export function RecentConversations() {
|
||||
)}
|
||||
</div>
|
||||
|
||||
{!isInitialLoading && displayedConversations?.length === 0 && (
|
||||
{!isInitialLoading && !error && displayedConversations?.length === 0 && (
|
||||
<span className="text-xs leading-4 text-white font-medium pl-4">
|
||||
{t(I18nKey.HOME$NO_RECENT_CONVERSATIONS)}
|
||||
</span>
|
||||
|
||||
@ -35,7 +35,11 @@ export function RepositorySelectionForm({
|
||||
React.useState<Provider | null>(null);
|
||||
|
||||
const { providers } = useUserProviders();
|
||||
const { addRecentRepository } = useHomeStore();
|
||||
const {
|
||||
addRecentRepository,
|
||||
setLastSelectedProvider,
|
||||
getLastSelectedProvider,
|
||||
} = useHomeStore();
|
||||
const {
|
||||
mutate: createConversation,
|
||||
isPending,
|
||||
@ -46,12 +50,24 @@ export function RepositorySelectionForm({
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
// Auto-select provider if there's only one
|
||||
// Auto-select provider logic
|
||||
React.useEffect(() => {
|
||||
if (providers.length === 0) return;
|
||||
|
||||
// If there's only one provider, auto-select it
|
||||
if (providers.length === 1 && !selectedProvider) {
|
||||
setSelectedProvider(providers[0]);
|
||||
return;
|
||||
}
|
||||
}, [providers, selectedProvider]);
|
||||
|
||||
// If there are multiple providers and none is selected, try to use the last selected one
|
||||
if (providers.length > 1 && !selectedProvider) {
|
||||
const lastSelected = getLastSelectedProvider();
|
||||
if (lastSelected && providers.includes(lastSelected)) {
|
||||
setSelectedProvider(lastSelected);
|
||||
}
|
||||
}
|
||||
}, [providers, selectedProvider, getLastSelectedProvider]);
|
||||
|
||||
// We check for isSuccess because the app might require time to render
|
||||
// into the new conversation screen after the conversation is created.
|
||||
@ -66,6 +82,7 @@ export function RepositorySelectionForm({
|
||||
}
|
||||
|
||||
setSelectedProvider(provider);
|
||||
setLastSelectedProvider(provider); // Store the selected provider
|
||||
setSelectedRepository(null); // Reset repository selection when provider changes
|
||||
setSelectedBranch(null); // Reset branch selection when provider changes
|
||||
onRepoSelection(null); // Reset parent component's selected repo
|
||||
|
||||
@ -45,7 +45,7 @@ export function DropdownItem<T>({
|
||||
// eslint-disable-next-line react/jsx-props-no-spreading
|
||||
<li key={getItemKey(item)} {...itemProps}>
|
||||
<div className="flex items-center gap-2">
|
||||
{renderIcon && renderIcon(item)}
|
||||
{renderIcon?.(item)}
|
||||
<span className="font-medium">{getDisplayText(item)}</span>
|
||||
</div>
|
||||
</li>
|
||||
|
||||
@ -1,24 +1,14 @@
|
||||
import { useMutation } from "@tanstack/react-query";
|
||||
import { Trans, useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import OpenHandsLogo from "#/assets/branding/openhands-logo.svg?react";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModalBody } from "#/components/shared/modals/modal-body";
|
||||
import BillingService from "#/api/billing-service/billing-service.api";
|
||||
import { BrandButton } from "../settings/brand-button";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
import { useCreateBillingSession } from "#/hooks/mutation/use-create-billing-session";
|
||||
|
||||
export function SetupPaymentModal() {
|
||||
const { t } = useTranslation();
|
||||
const { mutate, isPending } = useMutation({
|
||||
mutationFn: BillingService.createBillingSessionResponse,
|
||||
onSuccess: (data) => {
|
||||
window.location.href = data;
|
||||
},
|
||||
onError: () => {
|
||||
displayErrorToast(t(I18nKey.BILLING$ERROR_WHILE_CREATING_SESSION));
|
||||
},
|
||||
});
|
||||
const { mutate, isPending } = useCreateBillingSession();
|
||||
|
||||
return (
|
||||
<ModalBackdrop>
|
||||
|
||||
@ -13,10 +13,8 @@ import { CreateApiKeyModal } from "./create-api-key-modal";
|
||||
import { DeleteApiKeyModal } from "./delete-api-key-modal";
|
||||
import { NewApiKeyModal } from "./new-api-key-modal";
|
||||
import { useApiKeys } from "#/hooks/query/use-api-keys";
|
||||
import {
|
||||
useLlmApiKey,
|
||||
useRefreshLlmApiKey,
|
||||
} from "#/hooks/query/use-llm-api-key";
|
||||
import { useLlmApiKey } from "#/hooks/query/use-llm-api-key";
|
||||
import { useRefreshLlmApiKey } from "#/hooks/mutation/use-refresh-llm-api-key";
|
||||
|
||||
interface LlmApiKeyManagerProps {
|
||||
llmApiKey: { key: string | null } | undefined;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user