From 7853b41adda0a6069274d236902b1e57a8c213c4 Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Tue, 16 Dec 2025 11:54:01 -0500 Subject: [PATCH 01/14] Add OAuth 2.0 Device Flow backend for CLI authentication (#11984) Co-authored-by: openhands Co-authored-by: Xingyao Wang --- .../allhands-realm-github-provider.json.tmpl | 1 + .../versions/084_create_device_codes_table.py | 49 ++ enterprise/saas_server.py | 2 + enterprise/server/middleware.py | 9 +- enterprise/server/routes/oauth_device.py | 324 ++++++++++ enterprise/storage/api_key_store.py | 39 +- enterprise/storage/device_code.py | 109 ++++ enterprise/storage/device_code_store.py | 167 +++++ enterprise/tests/unit/conftest.py | 1 + .../unit/server/routes/test_oauth_device.py | 610 ++++++++++++++++++ .../tests/unit/storage/test_device_code.py | 83 +++ .../unit/storage/test_device_code_store.py | 193 ++++++ enterprise/tests/unit/test_api_key_store.py | 44 ++ frontend/src/routes.ts | 1 + frontend/src/routes/device-verify.tsx | 274 ++++++++ 15 files changed, 1901 insertions(+), 5 deletions(-) create mode 100644 enterprise/migrations/versions/084_create_device_codes_table.py create mode 100644 enterprise/server/routes/oauth_device.py create mode 100644 enterprise/storage/device_code.py create mode 100644 enterprise/storage/device_code_store.py create mode 100644 enterprise/tests/unit/server/routes/test_oauth_device.py create mode 100644 enterprise/tests/unit/storage/test_device_code.py create mode 100644 enterprise/tests/unit/storage/test_device_code_store.py create mode 100644 frontend/src/routes/device-verify.tsx diff --git a/enterprise/allhands-realm-github-provider.json.tmpl b/enterprise/allhands-realm-github-provider.json.tmpl index 6cdaa34383..35ff5f0afc 100644 --- a/enterprise/allhands-realm-github-provider.json.tmpl +++ b/enterprise/allhands-realm-github-provider.json.tmpl @@ -721,6 +721,7 @@ "https://$WEB_HOST/oauth/keycloak/callback", "https://$WEB_HOST/oauth/keycloak/offline/callback", "https://$WEB_HOST/slack/keycloak-callback", + "https://$WEB_HOST/oauth/device/keycloak-callback", "https://$WEB_HOST/api/email/verified", "/realms/$KEYCLOAK_REALM_NAME/$KEYCLOAK_CLIENT_ID/*" ], diff --git a/enterprise/migrations/versions/084_create_device_codes_table.py b/enterprise/migrations/versions/084_create_device_codes_table.py new file mode 100644 index 0000000000..0898e09ef5 --- /dev/null +++ b/enterprise/migrations/versions/084_create_device_codes_table.py @@ -0,0 +1,49 @@ +"""Create device_codes table for OAuth 2.0 Device Flow + +Revision ID: 084 +Revises: 083 +Create Date: 2024-12-10 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '084' +down_revision = '083' +branch_labels = None +depends_on = None + + +def upgrade(): + """Create device_codes table for OAuth 2.0 Device Flow.""" + op.create_table( + 'device_codes', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('device_code', sa.String(length=128), nullable=False), + sa.Column('user_code', sa.String(length=16), nullable=False), + sa.Column('status', sa.String(length=32), nullable=False), + sa.Column('keycloak_user_id', sa.String(length=255), nullable=True), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('authorized_at', sa.DateTime(timezone=True), nullable=True), + # Rate limiting fields for RFC 8628 section 3.5 compliance + sa.Column('last_poll_time', sa.DateTime(timezone=True), nullable=True), + sa.Column('current_interval', sa.Integer(), nullable=False, default=5), + sa.PrimaryKeyConstraint('id'), + ) + + # Create indexes for efficient lookups + op.create_index( + 'ix_device_codes_device_code', 'device_codes', ['device_code'], unique=True + ) + op.create_index( + 'ix_device_codes_user_code', 'device_codes', ['user_code'], unique=True + ) + + +def downgrade(): + """Drop device_codes table.""" + op.drop_index('ix_device_codes_user_code', table_name='device_codes') + op.drop_index('ix_device_codes_device_code', table_name='device_codes') + op.drop_table('device_codes') diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index 4c3c7c49ba..96e19a9815 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -34,6 +34,7 @@ from server.routes.integration.jira_dc import jira_dc_integration_router # noqa from server.routes.integration.linear import linear_integration_router # noqa: E402 from server.routes.integration.slack import slack_router # noqa: E402 from server.routes.mcp_patch import patch_mcp_server # noqa: E402 +from server.routes.oauth_device import oauth_device_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402 @@ -60,6 +61,7 @@ base_app.mount('/internal/metrics', metrics_app()) base_app.include_router(readiness_router) # Add routes for readiness checks base_app.include_router(api_router) # Add additional route for github auth base_app.include_router(oauth_router) # Add additional route for oauth callback +base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes base_app.include_router(saas_user_router) # Add additional route SAAS user calls base_app.include_router( billing_router diff --git a/enterprise/server/middleware.py b/enterprise/server/middleware.py index 2972c1ec38..54e3319595 100644 --- a/enterprise/server/middleware.py +++ b/enterprise/server/middleware.py @@ -152,17 +152,22 @@ class SetAuthCookieMiddleware: return False path = request.url.path - is_api_that_should_attach = path.startswith('/api') and path not in ( + ignore_paths = ( '/api/options/config', '/api/keycloak/callback', '/api/billing/success', '/api/billing/cancel', '/api/billing/customer-setup-success', '/api/billing/stripe-webhook', + '/oauth/device/authorize', + '/oauth/device/token', ) + if path in ignore_paths: + return False is_mcp = path.startswith('/mcp') - return is_api_that_should_attach or is_mcp + is_api_route = path.startswith('/api') + return is_api_route or is_mcp async def _logout(self, request: Request): # Log out of keycloak - this prevents issues where you did not log in with the idp you believe you used diff --git a/enterprise/server/routes/oauth_device.py b/enterprise/server/routes/oauth_device.py new file mode 100644 index 0000000000..39ff9a4081 --- /dev/null +++ b/enterprise/server/routes/oauth_device.py @@ -0,0 +1,324 @@ +"""OAuth 2.0 Device Flow endpoints for CLI authentication.""" + +from datetime import UTC, datetime, timedelta +from typing import Optional + +from fastapi import APIRouter, Depends, Form, HTTPException, Request, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from storage.api_key_store import ApiKeyStore +from storage.database import session_maker +from storage.device_code_store import DeviceCodeStore + +from openhands.core.logger import openhands_logger as logger +from openhands.server.user_auth import get_user_id + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes +DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds + +API_KEY_NAME = 'Device Link Access Key' +KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + + +class DeviceAuthorizationResponse(BaseModel): + device_code: str + user_code: str + verification_uri: str + verification_uri_complete: str + expires_in: int + interval: int + + +class DeviceTokenResponse(BaseModel): + access_token: str # This will be the user's API key + token_type: str = 'Bearer' + expires_in: Optional[int] = None # API keys may not have expiration + + +class DeviceTokenErrorResponse(BaseModel): + error: str + error_description: Optional[str] = None + interval: Optional[int] = None # Required for slow_down error + + +# --------------------------------------------------------------------------- +# Router + stores +# --------------------------------------------------------------------------- + +oauth_device_router = APIRouter(prefix='/oauth/device') +device_code_store = DeviceCodeStore(session_maker) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _oauth_error( + status_code: int, + error: str, + description: str, + interval: Optional[int] = None, +) -> JSONResponse: + """Return a JSON OAuth-style error response.""" + return JSONResponse( + status_code=status_code, + content=DeviceTokenErrorResponse( + error=error, + error_description=description, + interval=interval, + ).model_dump(), + ) + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@oauth_device_router.post('/authorize', response_model=DeviceAuthorizationResponse) +async def device_authorization( + http_request: Request, +) -> DeviceAuthorizationResponse: + """Start device flow by generating device and user codes.""" + try: + device_code_entry = device_code_store.create_device_code( + expires_in=DEVICE_CODE_EXPIRES_IN, + ) + + base_url = str(http_request.base_url).rstrip('/') + verification_uri = f'{base_url}/oauth/device/verify' + verification_uri_complete = ( + f'{verification_uri}?user_code={device_code_entry.user_code}' + ) + + logger.info( + 'Device authorization initiated', + extra={'user_code': device_code_entry.user_code}, + ) + + return DeviceAuthorizationResponse( + device_code=device_code_entry.device_code, + user_code=device_code_entry.user_code, + verification_uri=verification_uri, + verification_uri_complete=verification_uri_complete, + expires_in=DEVICE_CODE_EXPIRES_IN, + interval=device_code_entry.current_interval, + ) + except Exception as e: + logger.exception('Error in device authorization: %s', str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Internal server error', + ) from e + + +@oauth_device_router.post('/token') +async def device_token(device_code: str = Form(...)): + """Poll for a token until the user authorizes or the code expires.""" + try: + device_code_entry = device_code_store.get_by_device_code(device_code) + + if not device_code_entry: + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'invalid_grant', + 'Invalid device code', + ) + + # Check rate limiting (RFC 8628 section 3.5) + is_too_fast, current_interval = device_code_entry.check_rate_limit() + if is_too_fast: + # Update poll time and increase interval + device_code_store.update_poll_time(device_code, increase_interval=True) + logger.warning( + 'Client polling too fast, returning slow_down error', + extra={ + 'device_code': device_code[:8] + '...', # Log partial for privacy + 'new_interval': current_interval, + }, + ) + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'slow_down', + f'Polling too frequently. Wait at least {current_interval} seconds between requests.', + interval=current_interval, + ) + + # Update poll time for successful rate limit check + device_code_store.update_poll_time(device_code, increase_interval=False) + + if device_code_entry.is_expired(): + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'expired_token', + 'Device code has expired', + ) + + if device_code_entry.status == 'denied': + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'access_denied', + 'User denied the authorization request', + ) + + if device_code_entry.status == 'pending': + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'authorization_pending', + 'User has not yet completed authorization', + ) + + if device_code_entry.status == 'authorized': + # Retrieve the specific API key for this device using the user_code + api_key_store = ApiKeyStore.get_instance() + device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})' + device_api_key = api_key_store.retrieve_api_key_by_name( + device_code_entry.keycloak_user_id, device_key_name + ) + + if not device_api_key: + logger.error( + 'No device API key found for authorized device', + extra={ + 'user_id': device_code_entry.keycloak_user_id, + 'user_code': device_code_entry.user_code, + }, + ) + return _oauth_error( + status.HTTP_500_INTERNAL_SERVER_ERROR, + 'server_error', + 'API key not found', + ) + + # Return the API key as access_token + return DeviceTokenResponse( + access_token=device_api_key, + ) + + # Fallback for unexpected status values + logger.error( + 'Unknown device code status', + extra={'status': device_code_entry.status}, + ) + return _oauth_error( + status.HTTP_500_INTERNAL_SERVER_ERROR, + 'server_error', + 'Unknown device code status', + ) + + except Exception as e: + logger.exception('Error in device token: %s', str(e)) + return _oauth_error( + status.HTTP_500_INTERNAL_SERVER_ERROR, + 'server_error', + 'Internal server error', + ) + + +@oauth_device_router.post('/verify-authenticated') +async def device_verification_authenticated( + user_code: str = Form(...), + user_id: str = Depends(get_user_id), +): + """Process device verification for authenticated users (called by frontend).""" + try: + if not user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Authentication required', + ) + + # Validate device code + device_code_entry = device_code_store.get_by_user_code(user_code) + if not device_code_entry: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='The device code is invalid or has expired.', + ) + + if not device_code_entry.is_pending(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='This device code has already been processed.', + ) + + # First, authorize the device code + success = device_code_store.authorize_device_code( + user_code=user_code, + user_id=user_id, + ) + + if not success: + logger.error( + 'Failed to authorize device code', + extra={'user_code': user_code, 'user_id': user_id}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to authorize the device. Please try again.', + ) + + # Only create API key AFTER successful authorization + api_key_store = ApiKeyStore.get_instance() + try: + # Create a unique API key for this device using user_code in the name + device_key_name = f'{API_KEY_NAME} ({user_code})' + api_key_store.create_api_key( + user_id, + name=device_key_name, + expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME, + ) + logger.info( + 'Created new device API key for user after successful authorization', + extra={'user_id': user_id, 'user_code': user_code}, + ) + except Exception as e: + logger.exception( + 'Failed to create device API key after authorization: %s', str(e) + ) + + # Clean up: revert the device authorization since API key creation failed + # This prevents the device from being in an authorized state without an API key + try: + device_code_store.deny_device_code(user_code) + logger.info( + 'Reverted device authorization due to API key creation failure', + extra={'user_code': user_code, 'user_id': user_id}, + ) + except Exception as cleanup_error: + logger.exception( + 'Failed to revert device authorization during cleanup: %s', + str(cleanup_error), + ) + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to create API key for device access.', + ) + + logger.info( + 'Device code authorized with API key successfully', + extra={'user_code': user_code, 'user_id': user_id}, + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Device authorized successfully!'}, + ) + + except HTTPException: + raise + except Exception as e: + logger.exception('Error in device verification: %s', str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='An unexpected error occurred. Please try again.', + ) diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index 162ed415c1..693bfdb321 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -57,9 +57,15 @@ class ApiKeyStore: return None # Check if the key has expired - if key_record.expires_at and key_record.expires_at < now: - logger.info(f'API key has expired: {key_record.id}') - return None + if key_record.expires_at: + # Handle timezone-naive datetime from database by assuming it's UTC + expires_at = key_record.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + + if expires_at < now: + logger.info(f'API key has expired: {key_record.id}') + return None # Update last_used_at timestamp session.execute( @@ -125,6 +131,33 @@ class ApiKeyStore: return None + def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None: + """Retrieve an API key by name for a specific user.""" + with self.session_maker() as session: + key_record = ( + session.query(ApiKey) + .filter(ApiKey.user_id == user_id, ApiKey.name == name) + .first() + ) + return key_record.key if key_record else None + + def delete_api_key_by_name(self, user_id: str, name: str) -> bool: + """Delete an API key by name for a specific user.""" + with self.session_maker() as session: + key_record = ( + session.query(ApiKey) + .filter(ApiKey.user_id == user_id, ApiKey.name == name) + .first() + ) + + if not key_record: + return False + + session.delete(key_record) + session.commit() + + return True + @classmethod def get_instance(cls) -> ApiKeyStore: """Get an instance of the ApiKeyStore.""" diff --git a/enterprise/storage/device_code.py b/enterprise/storage/device_code.py new file mode 100644 index 0000000000..47e18b51bc --- /dev/null +++ b/enterprise/storage/device_code.py @@ -0,0 +1,109 @@ +"""Device code storage model for OAuth 2.0 Device Flow.""" + +from datetime import datetime, timezone +from enum import Enum + +from sqlalchemy import Column, DateTime, Integer, String +from storage.base import Base + + +class DeviceCodeStatus(Enum): + """Status of a device code authorization request.""" + + PENDING = 'pending' + AUTHORIZED = 'authorized' + EXPIRED = 'expired' + DENIED = 'denied' + + +class DeviceCode(Base): + """Device code for OAuth 2.0 Device Flow. + + This stores the device codes issued during the device authorization flow, + along with their status and associated user information once authorized. + """ + + __tablename__ = 'device_codes' + + id = Column(Integer, primary_key=True, autoincrement=True) + device_code = Column(String(128), unique=True, nullable=False, index=True) + user_code = Column(String(16), unique=True, nullable=False, index=True) + status = Column(String(32), nullable=False, default=DeviceCodeStatus.PENDING.value) + + # Keycloak user ID who authorized the device (set during verification) + keycloak_user_id = Column(String(255), nullable=True) + + # Timestamps + expires_at = Column(DateTime(timezone=True), nullable=False) + authorized_at = Column(DateTime(timezone=True), nullable=True) + + # Rate limiting fields for RFC 8628 section 3.5 compliance + last_poll_time = Column(DateTime(timezone=True), nullable=True) + current_interval = Column(Integer, nullable=False, default=5) + + def __repr__(self) -> str: + return f"" + + def is_expired(self) -> bool: + """Check if the device code has expired.""" + now = datetime.now(timezone.utc) + return now > self.expires_at + + def is_pending(self) -> bool: + """Check if the device code is still pending authorization.""" + return self.status == DeviceCodeStatus.PENDING.value and not self.is_expired() + + def is_authorized(self) -> bool: + """Check if the device code has been authorized.""" + return self.status == DeviceCodeStatus.AUTHORIZED.value + + def authorize(self, user_id: str) -> None: + """Mark the device code as authorized.""" + self.status = DeviceCodeStatus.AUTHORIZED.value + self.keycloak_user_id = user_id # Set the Keycloak user ID during authorization + self.authorized_at = datetime.now(timezone.utc) + + def deny(self) -> None: + """Mark the device code as denied.""" + self.status = DeviceCodeStatus.DENIED.value + + def expire(self) -> None: + """Mark the device code as expired.""" + self.status = DeviceCodeStatus.EXPIRED.value + + def check_rate_limit(self) -> tuple[bool, int]: + """Check if the client is polling too fast. + + Returns: + tuple: (is_too_fast, current_interval) + - is_too_fast: True if client should receive slow_down error + - current_interval: Current polling interval to use + """ + now = datetime.now(timezone.utc) + + # If this is the first poll, allow it + if self.last_poll_time is None: + return False, self.current_interval + + # Calculate time since last poll + time_since_last_poll = (now - self.last_poll_time).total_seconds() + + # Check if polling too fast + if time_since_last_poll < self.current_interval: + # Increase interval for slow_down (RFC 8628 section 3.5) + new_interval = min(self.current_interval + 5, 60) # Cap at 60 seconds + return True, new_interval + + return False, self.current_interval + + def update_poll_time(self, increase_interval: bool = False) -> None: + """Update the last poll time and optionally increase the interval. + + Args: + increase_interval: If True, increase the current interval for slow_down + """ + self.last_poll_time = datetime.now(timezone.utc) + + if increase_interval: + # Increase interval by 5 seconds, cap at 60 seconds (RFC 8628) + self.current_interval = min(self.current_interval + 5, 60) diff --git a/enterprise/storage/device_code_store.py b/enterprise/storage/device_code_store.py new file mode 100644 index 0000000000..de2fe29cc4 --- /dev/null +++ b/enterprise/storage/device_code_store.py @@ -0,0 +1,167 @@ +"""Device code store for OAuth 2.0 Device Flow.""" + +import secrets +import string +from datetime import datetime, timedelta, timezone + +from sqlalchemy.exc import IntegrityError +from storage.device_code import DeviceCode + + +class DeviceCodeStore: + """Store for managing OAuth 2.0 device codes.""" + + def __init__(self, session_maker): + self.session_maker = session_maker + + def generate_user_code(self) -> str: + """Generate a human-readable user code (8 characters, uppercase letters and digits).""" + # Use a mix of uppercase letters and digits, avoiding confusing characters + alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ23456789' # No I, O, 0, 1 + return ''.join(secrets.choice(alphabet) for _ in range(8)) + + def generate_device_code(self) -> str: + """Generate a secure device code (128 characters).""" + alphabet = string.ascii_letters + string.digits + return ''.join(secrets.choice(alphabet) for _ in range(128)) + + def create_device_code( + self, + expires_in: int = 600, # 10 minutes default + max_attempts: int = 10, + ) -> DeviceCode: + """Create a new device code entry. + + Uses database constraints to ensure uniqueness, avoiding TOCTOU race conditions. + Retries on constraint violations until unique codes are generated. + + Args: + expires_in: Expiration time in seconds + max_attempts: Maximum number of attempts to generate unique codes + + Returns: + The created DeviceCode instance + + Raises: + RuntimeError: If unable to generate unique codes after max_attempts + """ + for attempt in range(max_attempts): + user_code = self.generate_user_code() + device_code = self.generate_device_code() + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + device_code_entry = DeviceCode( + device_code=device_code, + user_code=user_code, + keycloak_user_id=None, # Will be set during authorization + expires_at=expires_at, + ) + + try: + with self.session_maker() as session: + session.add(device_code_entry) + session.commit() + session.refresh(device_code_entry) + session.expunge(device_code_entry) # Detach from session cleanly + return device_code_entry + except IntegrityError: + # Constraint violation - codes already exist, retry with new codes + continue + + raise RuntimeError( + f'Failed to generate unique device codes after {max_attempts} attempts' + ) + + def get_by_device_code(self, device_code: str) -> DeviceCode | None: + """Get device code entry by device code.""" + with self.session_maker() as session: + result = ( + session.query(DeviceCode).filter_by(device_code=device_code).first() + ) + if result: + session.expunge(result) # Detach from session cleanly + return result + + def get_by_user_code(self, user_code: str) -> DeviceCode | None: + """Get device code entry by user code.""" + with self.session_maker() as session: + result = session.query(DeviceCode).filter_by(user_code=user_code).first() + if result: + session.expunge(result) # Detach from session cleanly + return result + + def authorize_device_code(self, user_code: str, user_id: str) -> bool: + """Authorize a device code. + + Args: + user_code: The user code to authorize + user_id: The user ID from Keycloak + + Returns: + True if authorization was successful, False otherwise + """ + with self.session_maker() as session: + device_code_entry = ( + session.query(DeviceCode).filter_by(user_code=user_code).first() + ) + + if not device_code_entry: + return False + + if not device_code_entry.is_pending(): + return False + + device_code_entry.authorize(user_id) + session.commit() + + return True + + def deny_device_code(self, user_code: str) -> bool: + """Deny a device code authorization. + + Args: + user_code: The user code to deny + + Returns: + True if denial was successful, False otherwise + """ + with self.session_maker() as session: + device_code_entry = ( + session.query(DeviceCode).filter_by(user_code=user_code).first() + ) + + if not device_code_entry: + return False + + if not device_code_entry.is_pending(): + return False + + device_code_entry.deny() + session.commit() + + return True + + def update_poll_time( + self, device_code: str, increase_interval: bool = False + ) -> bool: + """Update the poll time for a device code and optionally increase interval. + + Args: + device_code: The device code to update + increase_interval: If True, increase the polling interval for slow_down + + Returns: + True if update was successful, False otherwise + """ + with self.session_maker() as session: + device_code_entry = ( + session.query(DeviceCode).filter_by(device_code=device_code).first() + ) + + if not device_code_entry: + return False + + device_code_entry.update_poll_time(increase_interval) + session.commit() + + return True diff --git a/enterprise/tests/unit/conftest.py b/enterprise/tests/unit/conftest.py index 08516fd813..873f7b775f 100644 --- a/enterprise/tests/unit/conftest.py +++ b/enterprise/tests/unit/conftest.py @@ -12,6 +12,7 @@ from storage.base import Base # Anything not loaded here may not have a table created for it. from storage.billing_session import BillingSession from storage.conversation_work import ConversationWork +from storage.device_code import DeviceCode # noqa: F401 from storage.feedback import Feedback from storage.github_app_installation import GithubAppInstallation from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus diff --git a/enterprise/tests/unit/server/routes/test_oauth_device.py b/enterprise/tests/unit/server/routes/test_oauth_device.py new file mode 100644 index 0000000000..53682e65f0 --- /dev/null +++ b/enterprise/tests/unit/server/routes/test_oauth_device.py @@ -0,0 +1,610 @@ +"""Unit tests for OAuth2 Device Flow endpoints.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException, Request +from fastapi.responses import JSONResponse +from server.routes.oauth_device import ( + device_authorization, + device_token, + device_verification_authenticated, +) +from storage.device_code import DeviceCode + + +@pytest.fixture +def mock_device_code_store(): + """Mock device code store.""" + return MagicMock() + + +@pytest.fixture +def mock_api_key_store(): + """Mock API key store.""" + return MagicMock() + + +@pytest.fixture +def mock_token_manager(): + """Mock token manager.""" + return MagicMock() + + +@pytest.fixture +def mock_request(): + """Mock FastAPI request.""" + request = MagicMock(spec=Request) + request.base_url = 'https://test.example.com/' + return request + + +class TestDeviceAuthorization: + """Test device authorization endpoint.""" + + @patch('server.routes.oauth_device.device_code_store') + async def test_device_authorization_success(self, mock_store, mock_request): + """Test successful device authorization.""" + mock_device = DeviceCode( + device_code='test-device-code-123', + user_code='ABC12345', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + current_interval=5, # Default interval + ) + mock_store.create_device_code.return_value = mock_device + + result = await device_authorization(mock_request) + + assert result.device_code == 'test-device-code-123' + assert result.user_code == 'ABC12345' + assert result.expires_in == 600 + assert result.interval == 5 # Should match device's current_interval + assert 'verify' in result.verification_uri + assert 'ABC12345' in result.verification_uri_complete + + @patch('server.routes.oauth_device.device_code_store') + async def test_device_authorization_with_increased_interval( + self, mock_store, mock_request + ): + """Test device authorization returns increased interval from rate limiting.""" + mock_device = DeviceCode( + device_code='test-device-code-456', + user_code='XYZ98765', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + current_interval=15, # Increased interval from previous rate limiting + ) + mock_store.create_device_code.return_value = mock_device + + result = await device_authorization(mock_request) + + assert result.device_code == 'test-device-code-456' + assert result.user_code == 'XYZ98765' + assert result.expires_in == 600 + assert result.interval == 15 # Should match device's increased current_interval + assert 'verify' in result.verification_uri + assert 'XYZ98765' in result.verification_uri_complete + + +class TestDeviceToken: + """Test device token endpoint.""" + + @pytest.mark.parametrize( + 'device_exists,status,expected_error', + [ + (False, None, 'invalid_grant'), + (True, 'expired', 'expired_token'), + (True, 'denied', 'access_denied'), + (True, 'pending', 'authorization_pending'), + ], + ) + @patch('server.routes.oauth_device.device_code_store') + async def test_device_token_error_cases( + self, mock_store, device_exists, status, expected_error + ): + """Test various error cases for device token endpoint.""" + device_code = 'test-device-code' + + if device_exists: + mock_device = MagicMock() + mock_device.is_expired.return_value = status == 'expired' + mock_device.status = status + # Mock rate limiting - return False (not too fast) and default interval + mock_device.check_rate_limit.return_value = (False, 5) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + else: + mock_store.get_by_device_code.return_value = None + + result = await device_token(device_code=device_code) + + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + # Check error in response content + content = result.body.decode() + assert expected_error in content + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_device_token_success(self, mock_store, mock_api_key_class): + """Test successful device token retrieval.""" + device_code = 'test-device-code' + + # Mock authorized device + mock_device = MagicMock() + mock_device.is_expired.return_value = False + mock_device.status = 'authorized' + mock_device.keycloak_user_id = 'user-123' + mock_device.user_code = ( + 'ABC12345' # Add user_code for device-specific API key lookup + ) + # Mock rate limiting - return False (not too fast) and default interval + mock_device.check_rate_limit.return_value = (False, 5) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + # Mock API key retrieval + mock_api_key_store = MagicMock() + mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key' + mock_api_key_class.get_instance.return_value = mock_api_key_store + + result = await device_token(device_code=device_code) + + # Check that result is a DeviceTokenResponse + assert result.access_token == 'test-api-key' + assert result.token_type == 'Bearer' + + # Verify that the correct device-specific API key name was used + mock_api_key_store.retrieve_api_key_by_name.assert_called_once_with( + 'user-123', 'Device Link Access Key (ABC12345)' + ) + + +class TestDeviceVerificationAuthenticated: + """Test device verification authenticated endpoint.""" + + async def test_verification_unauthenticated_user(self): + """Test verification with unauthenticated user.""" + with pytest.raises(HTTPException): + await device_verification_authenticated(user_code='ABC12345', user_id=None) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_verification_invalid_device_code( + self, mock_store, mock_api_key_class + ): + """Test verification with invalid device code.""" + mock_store.get_by_user_code.return_value = None + + with pytest.raises(HTTPException): + await device_verification_authenticated( + user_code='INVALID', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_verification_already_processed(self, mock_store, mock_api_key_class): + """Test verification with already processed device code.""" + mock_device = MagicMock() + mock_device.is_pending.return_value = False + mock_store.get_by_user_code.return_value = mock_device + + with pytest.raises(HTTPException): + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_verification_success(self, mock_store, mock_api_key_class): + """Test successful device verification.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True + + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + result = await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert isinstance(result, JSONResponse) + assert result.status_code == 200 + # Should NOT delete existing API keys (multiple devices allowed) + mock_api_key_store.delete_api_key_by_name.assert_not_called() + # Should create a new API key with device-specific name + mock_api_key_store.create_api_key.assert_called_once() + call_args = mock_api_key_store.create_api_key.call_args + assert call_args[1]['name'] == 'Device Link Access Key (ABC12345)' + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_multiple_device_authentication(self, mock_store, mock_api_key_class): + """Test that multiple devices can authenticate simultaneously.""" + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Simulate two different devices + device1_code = 'ABC12345' + device2_code = 'XYZ67890' + user_id = 'user-123' + + # Mock device codes + mock_device1 = MagicMock() + mock_device1.is_pending.return_value = True + mock_device2 = MagicMock() + mock_device2.is_pending.return_value = True + + # Configure mock store to return appropriate device for each user_code + def get_by_user_code_side_effect(user_code): + if user_code == device1_code: + return mock_device1 + elif user_code == device2_code: + return mock_device2 + return None + + mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect + mock_store.authorize_device_code.return_value = True + + # Authenticate first device + result1 = await device_verification_authenticated( + user_code=device1_code, user_id=user_id + ) + + # Authenticate second device + result2 = await device_verification_authenticated( + user_code=device2_code, user_id=user_id + ) + + # Both should succeed + assert isinstance(result1, JSONResponse) + assert result1.status_code == 200 + assert isinstance(result2, JSONResponse) + assert result2.status_code == 200 + + # Should create two separate API keys with different names + assert mock_api_key_store.create_api_key.call_count == 2 + + # Check that each device got a unique API key name + call_args_list = mock_api_key_store.create_api_key.call_args_list + device1_name = call_args_list[0][1]['name'] + device2_name = call_args_list[1][1]['name'] + + assert device1_name == f'Device Link Access Key ({device1_code})' + assert device2_name == f'Device Link Access Key ({device2_code})' + assert device1_name != device2_name # Ensure they're different + + # Should NOT delete any existing API keys + mock_api_key_store.delete_api_key_by_name.assert_not_called() + + +class TestDeviceTokenRateLimiting: + """Test rate limiting for device token polling (RFC 8628 section 3.5).""" + + @patch('server.routes.oauth_device.device_code_store') + async def test_first_poll_allowed(self, mock_store): + """Test that the first poll is always allowed.""" + # Create a device code with no previous poll time + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=None, # First poll + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return authorization_pending, not slow_down + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'authorization_pending' in content + assert 'slow_down' not in content + + # Should update poll time without increasing interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=False + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_normal_polling_allowed(self, mock_store): + """Test that normal polling (respecting interval) is allowed.""" + # Create a device code with last poll time 6 seconds ago (interval is 5) + last_poll = datetime.now(UTC) - timedelta(seconds=6) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return authorization_pending, not slow_down + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'authorization_pending' in content + assert 'slow_down' not in content + + # Should update poll time without increasing interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=False + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_fast_polling_returns_slow_down(self, mock_store): + """Test that polling too fast returns slow_down error.""" + # Create a device code with last poll time 2 seconds ago (interval is 5) + last_poll = datetime.now(UTC) - timedelta(seconds=2) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return slow_down error + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + assert 'interval' in content + assert '10' in content # New interval should be 5 + 5 = 10 + + # Should update poll time and increase interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=True + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_interval_increases_with_repeated_fast_polling(self, mock_store): + """Test that interval increases with repeated fast polling.""" + # Create a device code with higher current interval from previous slow_down + last_poll = datetime.now(UTC) - timedelta(seconds=5) # 5 seconds ago + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=15, # Already increased from previous slow_down + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return slow_down error with increased interval + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + assert '20' in content # New interval should be 15 + 5 = 20 + + # Should update poll time and increase interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=True + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_interval_caps_at_maximum(self, mock_store): + """Test that interval is capped at maximum value.""" + # Create a device code with interval near maximum + last_poll = datetime.now(UTC) - timedelta(seconds=30) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=58, # Near maximum of 60 + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return slow_down error with capped interval + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + assert '60' in content # Should be capped at 60, not 63 + + @patch('server.routes.oauth_device.device_code_store') + async def test_rate_limiting_with_authorized_device(self, mock_store): + """Test that rate limiting still applies to authorized devices.""" + # Create an authorized device code with recent poll + last_poll = datetime.now(UTC) - timedelta(seconds=2) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='authorized', # Device is authorized + keycloak_user_id='user123', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should still return slow_down error even for authorized device + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + + # Should update poll time and increase interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=True + ) + + +class TestDeviceVerificationTransactionIntegrity: + """Test transaction integrity for device verification to prevent orphaned API keys.""" + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_authorization_failure_prevents_api_key_creation( + self, mock_store, mock_api_key_class + ): + """Test that if device authorization fails, no API key is created.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = False # Authorization fails + + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Should raise HTTPException due to authorization failure + with pytest.raises(HTTPException) as exc_info: + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to authorize the device' in exc_info.value.detail + + # API key should NOT be created since authorization failed + mock_api_key_store.create_api_key.assert_not_called() + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_api_key_creation_failure_reverts_authorization( + self, mock_store, mock_api_key_class + ): + """Test that if API key creation fails after authorization, the authorization is reverted.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True # Authorization succeeds + mock_store.deny_device_code.return_value = True # Cleanup succeeds + + # Mock API key store to fail on creation + mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key.side_effect = Exception('Database error') + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Should raise HTTPException due to API key creation failure + with pytest.raises(HTTPException) as exc_info: + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to create API key for device access' in exc_info.value.detail + + # Authorization should have been attempted first + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + + # API key creation should have been attempted after authorization + mock_api_key_store.create_api_key.assert_called_once() + + # Authorization should be reverted due to API key creation failure + mock_store.deny_device_code.assert_called_once_with('ABC12345') + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_api_key_creation_failure_cleanup_failure_logged( + self, mock_store, mock_api_key_class + ): + """Test that cleanup failure is logged but doesn't prevent the main error from being raised.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True # Authorization succeeds + mock_store.deny_device_code.side_effect = Exception( + 'Cleanup failed' + ) # Cleanup fails + + # Mock API key store to fail on creation + mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key.side_effect = Exception('Database error') + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Should still raise HTTPException for the original API key creation failure + with pytest.raises(HTTPException) as exc_info: + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to create API key for device access' in exc_info.value.detail + + # Both operations should have been attempted + mock_store.authorize_device_code.assert_called_once() + mock_api_key_store.create_api_key.assert_called_once() + mock_store.deny_device_code.assert_called_once_with('ABC12345') + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_successful_flow_creates_api_key_after_authorization( + self, mock_store, mock_api_key_class + ): + """Test that in the successful flow, API key is created only after authorization.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True # Authorization succeeds + + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + result = await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert isinstance(result, JSONResponse) + assert result.status_code == 200 + + # Verify the order: authorization first, then API key creation + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + mock_api_key_store.create_api_key.assert_called_once() + + # No cleanup should be needed in successful case + mock_store.deny_device_code.assert_not_called() diff --git a/enterprise/tests/unit/storage/test_device_code.py b/enterprise/tests/unit/storage/test_device_code.py new file mode 100644 index 0000000000..0d2193075b --- /dev/null +++ b/enterprise/tests/unit/storage/test_device_code.py @@ -0,0 +1,83 @@ +"""Unit tests for DeviceCode model.""" + +from datetime import datetime, timedelta, timezone + +import pytest +from storage.device_code import DeviceCode, DeviceCodeStatus + + +class TestDeviceCode: + """Test cases for DeviceCode model.""" + + @pytest.fixture + def device_code(self): + """Create a test device code.""" + return DeviceCode( + device_code='test-device-code-123', + user_code='ABC12345', + expires_at=datetime.now(timezone.utc) + timedelta(minutes=10), + ) + + @pytest.mark.parametrize( + 'expires_delta,expected', + [ + (timedelta(minutes=5), False), # Future expiry + (timedelta(minutes=-5), True), # Past expiry + (timedelta(seconds=1), False), # Just future (not expired) + ], + ) + def test_is_expired(self, expires_delta, expected): + """Test expiration check with various time deltas.""" + device_code = DeviceCode( + device_code='test-device-code', + user_code='ABC12345', + expires_at=datetime.now(timezone.utc) + expires_delta, + ) + assert device_code.is_expired() == expected + + @pytest.mark.parametrize( + 'status,expired,expected', + [ + (DeviceCodeStatus.PENDING.value, False, True), + (DeviceCodeStatus.PENDING.value, True, False), + (DeviceCodeStatus.AUTHORIZED.value, False, False), + (DeviceCodeStatus.DENIED.value, False, False), + ], + ) + def test_is_pending(self, status, expired, expected): + """Test pending status check.""" + expires_at = ( + datetime.now(timezone.utc) - timedelta(minutes=1) + if expired + else datetime.now(timezone.utc) + timedelta(minutes=10) + ) + device_code = DeviceCode( + device_code='test-device-code', + user_code='ABC12345', + status=status, + expires_at=expires_at, + ) + assert device_code.is_pending() == expected + + def test_authorize(self, device_code): + """Test device authorization.""" + user_id = 'test-user-123' + + device_code.authorize(user_id) + + assert device_code.status == DeviceCodeStatus.AUTHORIZED.value + assert device_code.keycloak_user_id == user_id + assert device_code.authorized_at is not None + assert isinstance(device_code.authorized_at, datetime) + + @pytest.mark.parametrize( + 'method,expected_status', + [ + ('deny', DeviceCodeStatus.DENIED.value), + ('expire', DeviceCodeStatus.EXPIRED.value), + ], + ) + def test_status_changes(self, device_code, method, expected_status): + """Test status change methods.""" + getattr(device_code, method)() + assert device_code.status == expected_status diff --git a/enterprise/tests/unit/storage/test_device_code_store.py b/enterprise/tests/unit/storage/test_device_code_store.py new file mode 100644 index 0000000000..65a58cda8a --- /dev/null +++ b/enterprise/tests/unit/storage/test_device_code_store.py @@ -0,0 +1,193 @@ +"""Unit tests for DeviceCodeStore.""" + +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.exc import IntegrityError +from storage.device_code import DeviceCode +from storage.device_code_store import DeviceCodeStore + + +@pytest.fixture +def mock_session(): + """Mock database session.""" + session = MagicMock() + return session + + +@pytest.fixture +def mock_session_maker(mock_session): + """Mock session maker.""" + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = mock_session + session_maker.return_value.__exit__.return_value = None + return session_maker + + +@pytest.fixture +def device_code_store(mock_session_maker): + """Create DeviceCodeStore instance.""" + return DeviceCodeStore(mock_session_maker) + + +class TestDeviceCodeStore: + """Test cases for DeviceCodeStore.""" + + def test_generate_user_code(self, device_code_store): + """Test user code generation.""" + code = device_code_store.generate_user_code() + + assert len(code) == 8 + assert code.isupper() + # Should not contain confusing characters + assert not any(char in code for char in 'IO01') + + def test_generate_device_code(self, device_code_store): + """Test device code generation.""" + code = device_code_store.generate_device_code() + + assert len(code) == 128 + assert code.isalnum() + + def test_create_device_code_success(self, device_code_store, mock_session): + """Test successful device code creation.""" + # Mock successful creation (no IntegrityError) + mock_device_code = MagicMock(spec=DeviceCode) + mock_device_code.device_code = 'test-device-code-123' + mock_device_code.user_code = 'TESTCODE' + + # Mock the session to return our mock device code after refresh + def mock_refresh(obj): + obj.device_code = mock_device_code.device_code + obj.user_code = mock_device_code.user_code + + mock_session.refresh.side_effect = mock_refresh + + result = device_code_store.create_device_code(expires_in=600) + + assert isinstance(result, DeviceCode) + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + mock_session.expunge.assert_called_once() + + def test_create_device_code_with_retries( + self, device_code_store, mock_session_maker + ): + """Test device code creation with constraint violation retries.""" + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_session_maker.return_value.__exit__.return_value = None + + # First attempt fails with IntegrityError, second succeeds + mock_session.commit.side_effect = [IntegrityError('', '', ''), None] + + mock_device_code = MagicMock(spec=DeviceCode) + mock_device_code.device_code = 'test-device-code-456' + mock_device_code.user_code = 'TESTCD2' + + def mock_refresh(obj): + obj.device_code = mock_device_code.device_code + obj.user_code = mock_device_code.user_code + + mock_session.refresh.side_effect = mock_refresh + + store = DeviceCodeStore(mock_session_maker) + result = store.create_device_code(expires_in=600) + + assert isinstance(result, DeviceCode) + assert mock_session.add.call_count == 2 # Two attempts + assert mock_session.commit.call_count == 2 # Two attempts + + def test_create_device_code_max_attempts_exceeded( + self, device_code_store, mock_session_maker + ): + """Test device code creation failure after max attempts.""" + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_session_maker.return_value.__exit__.return_value = None + + # All attempts fail with IntegrityError + mock_session.commit.side_effect = IntegrityError('', '', '') + + store = DeviceCodeStore(mock_session_maker) + + with pytest.raises( + RuntimeError, + match='Failed to generate unique device codes after 3 attempts', + ): + store.create_device_code(expires_in=600, max_attempts=3) + + @pytest.mark.parametrize( + 'lookup_method,lookup_field', + [ + ('get_by_device_code', 'device_code'), + ('get_by_user_code', 'user_code'), + ], + ) + def test_lookup_methods( + self, device_code_store, mock_session, lookup_method, lookup_field + ): + """Test device code lookup methods.""" + test_code = 'test-code-123' + mock_device_code = MagicMock() + mock_session.query.return_value.filter_by.return_value.first.return_value = ( + mock_device_code + ) + + result = getattr(device_code_store, lookup_method)(test_code) + + assert result == mock_device_code + mock_session.query.assert_called_once_with(DeviceCode) + mock_session.query.return_value.filter_by.assert_called_once_with( + **{lookup_field: test_code} + ) + + @pytest.mark.parametrize( + 'device_exists,is_pending,expected_result', + [ + (True, True, True), # Success case + (False, True, False), # Device not found + (True, False, False), # Device not pending + ], + ) + def test_authorize_device_code( + self, + device_code_store, + mock_session, + device_exists, + is_pending, + expected_result, + ): + """Test device code authorization.""" + user_code = 'ABC12345' + user_id = 'test-user-123' + + if device_exists: + mock_device = MagicMock() + mock_device.is_pending.return_value = is_pending + mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device + else: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + result = device_code_store.authorize_device_code(user_code, user_id) + + assert result == expected_result + if expected_result: + mock_device.authorize.assert_called_once_with(user_id) + mock_session.commit.assert_called_once() + + def test_deny_device_code(self, device_code_store, mock_session): + """Test device code denial.""" + user_code = 'ABC12345' + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_session.query.return_value.filter_by.return_value.first.return_value = ( + mock_device + ) + + result = device_code_store.deny_device_code(user_code) + + assert result is True + mock_device.deny.assert_called_once() + mock_session.commit.assert_called_once() diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index ea386cb69c..c1c6a98f3d 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -90,6 +90,50 @@ def test_validate_api_key_expired(api_key_store, mock_session): mock_session.commit.assert_not_called() +def test_validate_api_key_expired_timezone_naive(api_key_store, mock_session): + """Test validating an expired API key with timezone-naive datetime from database.""" + # Setup + api_key = 'test-api-key' + mock_key_record = MagicMock() + # Simulate timezone-naive datetime as returned from database + mock_key_record.expires_at = datetime.now() - timedelta(days=1) # No UTC timezone + mock_key_record.id = 1 + mock_session.query.return_value.filter.return_value.first.return_value = ( + mock_key_record + ) + + # Execute + result = api_key_store.validate_api_key(api_key) + + # Verify + assert result is None + mock_session.execute.assert_not_called() + mock_session.commit.assert_not_called() + + +def test_validate_api_key_valid_timezone_naive(api_key_store, mock_session): + """Test validating a valid API key with timezone-naive datetime from database.""" + # Setup + api_key = 'test-api-key' + user_id = 'test-user-123' + mock_key_record = MagicMock() + mock_key_record.user_id = user_id + # Simulate timezone-naive datetime as returned from database (future date) + mock_key_record.expires_at = datetime.now() + timedelta(days=1) # No UTC timezone + mock_key_record.id = 1 + mock_session.query.return_value.filter.return_value.first.return_value = ( + mock_key_record + ) + + # Execute + result = api_key_store.validate_api_key(api_key) + + # Verify + assert result == user_id + mock_session.execute.assert_called_once() + mock_session.commit.assert_called_once() + + def test_validate_api_key_not_found(api_key_store, mock_session): """Test validating a non-existent API key.""" # Setup diff --git a/frontend/src/routes.ts b/frontend/src/routes.ts index 4c3c48adc5..ecee511688 100644 --- a/frontend/src/routes.ts +++ b/frontend/src/routes.ts @@ -21,5 +21,6 @@ export default [ ]), route("conversations/:conversationId", "routes/conversation.tsx"), route("microagent-management", "routes/microagent-management.tsx"), + route("oauth/device/verify", "routes/device-verify.tsx"), ]), ] satisfies RouteConfig; diff --git a/frontend/src/routes/device-verify.tsx b/frontend/src/routes/device-verify.tsx new file mode 100644 index 0000000000..f306d660a5 --- /dev/null +++ b/frontend/src/routes/device-verify.tsx @@ -0,0 +1,274 @@ +/* eslint-disable i18next/no-literal-string */ +import React, { useState } from "react"; +import { useSearchParams } from "react-router"; +import { useIsAuthed } from "#/hooks/query/use-is-authed"; + +export default function DeviceVerify() { + const [searchParams] = useSearchParams(); + const { data: isAuthed, isLoading: isAuthLoading } = useIsAuthed(); + const [verificationResult, setVerificationResult] = useState<{ + success: boolean; + message: string; + } | null>(null); + const [isProcessing, setIsProcessing] = useState(false); + + // Get user_code from URL parameters + const userCode = searchParams.get("user_code"); + + const processDeviceVerification = async (code: string) => { + try { + setIsProcessing(true); + + // Call the backend API endpoint to process device verification + const response = await fetch("/oauth/device/verify-authenticated", { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: `user_code=${encodeURIComponent(code)}`, + credentials: "include", // Include cookies for authentication + }); + + if (response.ok) { + // Show success message + setVerificationResult({ + success: true, + message: + "Device authorized successfully! You can now return to your CLI and close this window.", + }); + } else { + const errorText = await response.text(); + setVerificationResult({ + success: false, + message: errorText || "Failed to authorize device. Please try again.", + }); + } + } catch (error) { + setVerificationResult({ + success: false, + message: + "An error occurred while authorizing the device. Please try again.", + }); + } finally { + setIsProcessing(false); + } + }; + + // Remove automatic verification - require explicit user consent + + const handleManualSubmit = (event: React.FormEvent) => { + event.preventDefault(); + const formData = new FormData(event.currentTarget); + const code = formData.get("user_code") as string; + if (code && isAuthed) { + processDeviceVerification(code); + } + }; + + // Show verification result if we have one + if (verificationResult) { + return ( +
+
+
+
+ {verificationResult.success ? ( + + + + ) : ( + + + + )} +
+

+ {verificationResult.success ? "Success!" : "Error"} +

+

+ {verificationResult.message} +

+ {!verificationResult.success && ( + + )} +
+
+
+ ); + } + + // Show processing state + if (isProcessing) { + return ( +
+
+
+
+

+ Processing device verification... +

+
+
+
+ ); + } + + // Show device authorization confirmation if user is authenticated and code is provided + if (isAuthed && userCode) { + return ( +
+
+

+ Device Authorization Request +

+
+

Device Code:

+

+ {userCode} +

+
+
+
+ + + +
+

+ Security Notice +

+

+ Only authorize this device if you initiated this request from + your CLI or application. +

+
+
+
+

+ Do you want to authorize this device to access your OpenHands + account? +

+
+ + +
+
+
+ ); + } + + // Show manual code entry form if no code in URL but user is authenticated + if (isAuthed && !userCode) { + return ( +
+
+

+ Device Authorization +

+

+ Enter the code displayed on your device: +

+
+
+ + +
+ +
+
+
+ ); + } + + // Show loading state while checking authentication + if (isAuthLoading) { + return ( +
+
+
+

+ Processing device verification... +

+
+
+ ); + } + + // Show authentication required message (this will trigger the auth modal via root layout) + return ( +
+
+

Authentication Required

+

+ Please sign in to authorize your device. +

+
+
+ ); +} From 281ac91540ab5e93b9b2075baf865d6912f8a244 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Tue, 16 Dec 2025 14:53:15 -0700 Subject: [PATCH 02/14] Bump sdk 1.6.0 (#12067) --- enterprise/poetry.lock | 41 ++++++++++--------- .../sandbox/sandbox_spec_service.py | 2 +- poetry.lock | 35 ++++++++-------- pyproject.toml | 6 +-- 4 files changed, 43 insertions(+), 41 deletions(-) diff --git a/enterprise/poetry.lock b/enterprise/poetry.lock index 13645253c5..bd2c55c317 100644 --- a/enterprise/poetry.lock +++ b/enterprise/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -4624,14 +4624,14 @@ files = [ [[package]] name = "lmnr" -version = "0.7.20" +version = "0.7.24" description = "Python SDK for Laminar" optional = false python-versions = "<4,>=3.10" groups = ["main"] files = [ - {file = "lmnr-0.7.20-py3-none-any.whl", hash = "sha256:5f9fa7444e6f96c25e097f66484ff29e632bdd1de0e9346948bf5595f4a8af38"}, - {file = "lmnr-0.7.20.tar.gz", hash = "sha256:1f484cd618db2d71af65f90a0b8b36d20d80dc91a5138b811575c8677bf7c4fd"}, + {file = "lmnr-0.7.24-py3-none-any.whl", hash = "sha256:ad780d4a62ece897048811f3368639c240a9329ab31027da8c96545137a3a08a"}, + {file = "lmnr-0.7.24.tar.gz", hash = "sha256:aa6973f46fc4ba95c9061c1feceb58afc02eb43c9376c21e32545371ff6123d7"}, ] [package.dependencies] @@ -4654,14 +4654,15 @@ tqdm = ">=4.0" [package.extras] alephalpha = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)"] -all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"] +all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"] bedrock = ["opentelemetry-instrumentation-bedrock (>=0.47.1)"] chromadb = ["opentelemetry-instrumentation-chromadb (>=0.47.1)"] +claude-agent-sdk = ["lmnr-claude-code-proxy (>=0.1.0a5)"] cohere = ["opentelemetry-instrumentation-cohere (>=0.47.1)"] crewai = ["opentelemetry-instrumentation-crewai (>=0.47.1)"] haystack = ["opentelemetry-instrumentation-haystack (>=0.47.1)"] lancedb = ["opentelemetry-instrumentation-lancedb (>=0.47.1)"] -langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1)"] +langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)"] llamaindex = ["opentelemetry-instrumentation-llamaindex (>=0.47.1)"] marqo = ["opentelemetry-instrumentation-marqo (>=0.47.1)"] mcp = ["opentelemetry-instrumentation-mcp (>=0.47.1)"] @@ -5835,14 +5836,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0 [[package]] name = "openhands-agent-server" -version = "1.5.2" +version = "1.6.0" description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent" optional = false python-versions = ">=3.12" groups = ["main"] files = [ - {file = "openhands_agent_server-1.5.2-py3-none-any.whl", hash = "sha256:7a368f61036f85446f566b9f6f9d6c7318684776cf2293daa5bce3ee19ac077d"}, - {file = "openhands_agent_server-1.5.2.tar.gz", hash = "sha256:dfaf5583dd71dae933643a8f8160156ce6fa7ed20db5cc3c45465b079bc576cd"}, + {file = "openhands_agent_server-1.6.0-py3-none-any.whl", hash = "sha256:e6ae865ac3e7a96b234e10a0faad23f6210e025bbf7721cb66bc7a71d160848c"}, + {file = "openhands_agent_server-1.6.0.tar.gz", hash = "sha256:44ce7694ae2d4bb0666d318ef13e6618bd4dc73022c60354839fe6130e67d02a"}, ] [package.dependencies] @@ -5859,7 +5860,7 @@ wsproto = ">=1.2.0" [[package]] name = "openhands-ai" -version = "0.62.0" +version = "0.0.0-post.5687+7853b41ad" description = "OpenHands: Code Less, Make More" optional = false python-versions = "^3.12,<3.14" @@ -5901,9 +5902,9 @@ memory-profiler = "^0.61.0" numpy = "*" openai = "2.8.0" openhands-aci = "0.3.2" -openhands-agent-server = "1.5.2" -openhands-sdk = "1.5.2" -openhands-tools = "1.5.2" +openhands-agent-server = "1.6.0" +openhands-sdk = "1.6.0" +openhands-tools = "1.6.0" opentelemetry-api = "^1.33.1" opentelemetry-exporter-otlp-proto-grpc = "^1.33.1" pathspec = "^0.12.1" @@ -5959,14 +5960,14 @@ url = ".." [[package]] name = "openhands-sdk" -version = "1.5.2" +version = "1.6.0" description = "OpenHands SDK - Core functionality for building AI agents" optional = false python-versions = ">=3.12" groups = ["main"] files = [ - {file = "openhands_sdk-1.5.2-py3-none-any.whl", hash = "sha256:593430e9c8729e345fce3fca7e9a9a7ef084a08222d6ba42113e6ba5f6e9f15d"}, - {file = "openhands_sdk-1.5.2.tar.gz", hash = "sha256:798aa8f8ccd84b15deb418c4301d00f33da288bc1a8d41efa5cc47c10aaf3fd6"}, + {file = "openhands_sdk-1.6.0-py3-none-any.whl", hash = "sha256:94d2f87fb35406373da6728ae2d88584137f9e9b67fa0e940444c72f2e44e7d3"}, + {file = "openhands_sdk-1.6.0.tar.gz", hash = "sha256:f45742350e3874a7f5b08befc4a9d5adc7e4454f7ab5f8391c519eee3116090f"}, ] [package.dependencies] @@ -5974,7 +5975,7 @@ deprecation = ">=2.1.0" fastmcp = ">=2.11.3" httpx = ">=0.27.0" litellm = ">=1.80.7" -lmnr = ">=0.7.20" +lmnr = ">=0.7.24" pydantic = ">=2.11.7" python-frontmatter = ">=1.1.0" python-json-logger = ">=3.3.0" @@ -5986,14 +5987,14 @@ boto3 = ["boto3 (>=1.35.0)"] [[package]] name = "openhands-tools" -version = "1.5.2" +version = "1.6.0" description = "OpenHands Tools - Runtime tools for AI agents" optional = false python-versions = ">=3.12" groups = ["main"] files = [ - {file = "openhands_tools-1.5.2-py3-none-any.whl", hash = "sha256:33e9c2af65aaa7b6b9a10b42d2fb11137e6b35e7ac02a4b9269ef37b5c79cc01"}, - {file = "openhands_tools-1.5.2.tar.gz", hash = "sha256:4644a24144fbdf630fb0edc303526b4add61b3fbe7a7434da73f231312c34846"}, + {file = "openhands_tools-1.6.0-py3-none-any.whl", hash = "sha256:176556d44186536751b23fe052d3505492cc2afb8d52db20fb7a2cc0169cd57a"}, + {file = "openhands_tools-1.6.0.tar.gz", hash = "sha256:d07ba31050fd4a7891a4c48388aa53ce9f703e17064ddbd59146d6c77e5980b3"}, ] [package.dependencies] diff --git a/openhands/app_server/sandbox/sandbox_spec_service.py b/openhands/app_server/sandbox/sandbox_spec_service.py index edaecc1b76..fe9d1653a9 100644 --- a/openhands/app_server/sandbox/sandbox_spec_service.py +++ b/openhands/app_server/sandbox/sandbox_spec_service.py @@ -12,7 +12,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin # The version of the agent server to use for deployments. # Typically this will be the same as the values from the pyproject.toml -AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:8f90b92-python' +AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:97652be-python' class SandboxSpecService(ABC): diff --git a/poetry.lock b/poetry.lock index 04831cc890..23789d3285 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -5675,14 +5675,14 @@ utils = ["numpydoc"] [[package]] name = "lmnr" -version = "0.7.20" +version = "0.7.24" description = "Python SDK for Laminar" optional = false python-versions = "<4,>=3.10" groups = ["main"] files = [ - {file = "lmnr-0.7.20-py3-none-any.whl", hash = "sha256:5f9fa7444e6f96c25e097f66484ff29e632bdd1de0e9346948bf5595f4a8af38"}, - {file = "lmnr-0.7.20.tar.gz", hash = "sha256:1f484cd618db2d71af65f90a0b8b36d20d80dc91a5138b811575c8677bf7c4fd"}, + {file = "lmnr-0.7.24-py3-none-any.whl", hash = "sha256:ad780d4a62ece897048811f3368639c240a9329ab31027da8c96545137a3a08a"}, + {file = "lmnr-0.7.24.tar.gz", hash = "sha256:aa6973f46fc4ba95c9061c1feceb58afc02eb43c9376c21e32545371ff6123d7"}, ] [package.dependencies] @@ -5705,14 +5705,15 @@ tqdm = ">=4.0" [package.extras] alephalpha = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)"] -all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"] +all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"] bedrock = ["opentelemetry-instrumentation-bedrock (>=0.47.1)"] chromadb = ["opentelemetry-instrumentation-chromadb (>=0.47.1)"] +claude-agent-sdk = ["lmnr-claude-code-proxy (>=0.1.0a5)"] cohere = ["opentelemetry-instrumentation-cohere (>=0.47.1)"] crewai = ["opentelemetry-instrumentation-crewai (>=0.47.1)"] haystack = ["opentelemetry-instrumentation-haystack (>=0.47.1)"] lancedb = ["opentelemetry-instrumentation-lancedb (>=0.47.1)"] -langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1)"] +langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)"] llamaindex = ["opentelemetry-instrumentation-llamaindex (>=0.47.1)"] marqo = ["opentelemetry-instrumentation-marqo (>=0.47.1)"] mcp = ["opentelemetry-instrumentation-mcp (>=0.47.1)"] @@ -7379,14 +7380,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0 [[package]] name = "openhands-agent-server" -version = "1.5.2" +version = "1.6.0" description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent" optional = false python-versions = ">=3.12" groups = ["main"] files = [ - {file = "openhands_agent_server-1.5.2-py3-none-any.whl", hash = "sha256:7a368f61036f85446f566b9f6f9d6c7318684776cf2293daa5bce3ee19ac077d"}, - {file = "openhands_agent_server-1.5.2.tar.gz", hash = "sha256:dfaf5583dd71dae933643a8f8160156ce6fa7ed20db5cc3c45465b079bc576cd"}, + {file = "openhands_agent_server-1.6.0-py3-none-any.whl", hash = "sha256:e6ae865ac3e7a96b234e10a0faad23f6210e025bbf7721cb66bc7a71d160848c"}, + {file = "openhands_agent_server-1.6.0.tar.gz", hash = "sha256:44ce7694ae2d4bb0666d318ef13e6618bd4dc73022c60354839fe6130e67d02a"}, ] [package.dependencies] @@ -7403,14 +7404,14 @@ wsproto = ">=1.2.0" [[package]] name = "openhands-sdk" -version = "1.5.2" +version = "1.6.0" description = "OpenHands SDK - Core functionality for building AI agents" optional = false python-versions = ">=3.12" groups = ["main"] files = [ - {file = "openhands_sdk-1.5.2-py3-none-any.whl", hash = "sha256:593430e9c8729e345fce3fca7e9a9a7ef084a08222d6ba42113e6ba5f6e9f15d"}, - {file = "openhands_sdk-1.5.2.tar.gz", hash = "sha256:798aa8f8ccd84b15deb418c4301d00f33da288bc1a8d41efa5cc47c10aaf3fd6"}, + {file = "openhands_sdk-1.6.0-py3-none-any.whl", hash = "sha256:94d2f87fb35406373da6728ae2d88584137f9e9b67fa0e940444c72f2e44e7d3"}, + {file = "openhands_sdk-1.6.0.tar.gz", hash = "sha256:f45742350e3874a7f5b08befc4a9d5adc7e4454f7ab5f8391c519eee3116090f"}, ] [package.dependencies] @@ -7418,7 +7419,7 @@ deprecation = ">=2.1.0" fastmcp = ">=2.11.3" httpx = ">=0.27.0" litellm = ">=1.80.7" -lmnr = ">=0.7.20" +lmnr = ">=0.7.24" pydantic = ">=2.11.7" python-frontmatter = ">=1.1.0" python-json-logger = ">=3.3.0" @@ -7430,14 +7431,14 @@ boto3 = ["boto3 (>=1.35.0)"] [[package]] name = "openhands-tools" -version = "1.5.2" +version = "1.6.0" description = "OpenHands Tools - Runtime tools for AI agents" optional = false python-versions = ">=3.12" groups = ["main"] files = [ - {file = "openhands_tools-1.5.2-py3-none-any.whl", hash = "sha256:33e9c2af65aaa7b6b9a10b42d2fb11137e6b35e7ac02a4b9269ef37b5c79cc01"}, - {file = "openhands_tools-1.5.2.tar.gz", hash = "sha256:4644a24144fbdf630fb0edc303526b4add61b3fbe7a7434da73f231312c34846"}, + {file = "openhands_tools-1.6.0-py3-none-any.whl", hash = "sha256:176556d44186536751b23fe052d3505492cc2afb8d52db20fb7a2cc0169cd57a"}, + {file = "openhands_tools-1.6.0.tar.gz", hash = "sha256:d07ba31050fd4a7891a4c48388aa53ce9f703e17064ddbd59146d6c77e5980b3"}, ] [package.dependencies] @@ -16822,4 +16823,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api [metadata] lock-version = "2.1" python-versions = "^3.12,<3.14" -content-hash = "9ec48649a3b54d1c19d2aae9af77c640e9eadbc6a368ef437a5655f14fc2a37a" +content-hash = "9764f3b69ec8ed35feebd78a826bbc6bfa4ac6d5b56bc999be8bc738b644e538" diff --git a/pyproject.toml b/pyproject.toml index dc2c52a112..c70c110dcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,9 +116,9 @@ pybase62 = "^1.0.0" #openhands-agent-server = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-agent-server", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" } #openhands-sdk = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-sdk", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" } #openhands-tools = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-tools", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" } -openhands-sdk = "1.5.2" -openhands-agent-server = "1.5.2" -openhands-tools = "1.5.2" +openhands-sdk = "1.6.0" +openhands-agent-server = "1.6.0" +openhands-tools = "1.6.0" python-jose = { version = ">=3.3", extras = [ "cryptography" ] } sqlalchemy = { extras = [ "asyncio" ], version = "^2.0.40" } pg8000 = "^1.31.5" From dc14624480db3f427f9844cb3e26b6ca696e9f7e Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Tue, 16 Dec 2025 20:35:46 -0700 Subject: [PATCH 03/14] Fix for frontend stall (#12069) --- frontend/src/utils/parse-terminal-output.ts | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/frontend/src/utils/parse-terminal-output.ts b/frontend/src/utils/parse-terminal-output.ts index a6ccc73cfc..1cd54eb858 100644 --- a/frontend/src/utils/parse-terminal-output.ts +++ b/frontend/src/utils/parse-terminal-output.ts @@ -1,3 +1,5 @@ +const START = "[Python Interpreter: "; + /** * Parses the raw output from the terminal into the command and symbol * @param raw The raw output to be displayed in the terminal @@ -13,9 +15,14 @@ * console.log(parsed.symbol); // openhands@659478cb008c:/workspace $ */ export const parseTerminalOutput = (raw: string) => { - const envRegex = /(.*)\[Python Interpreter: (.*)\]/s; - const match = raw.match(envRegex); - - if (!match) return raw; - return match[1]?.trim() || ""; + const start = raw.indexOf(START); + if (start < 0) { + return raw; + } + const offset = start + START.length; + const end = raw.indexOf("]", offset); + if (end <= offset) { + return raw; + } + return raw.substring(0, start).trim(); }; From 435e53769329578e3e14d6fcbfd6fc660e63340b Mon Sep 17 00:00:00 2001 From: Nhan Nguyen Date: Wed, 17 Dec 2025 07:05:10 -0500 Subject: [PATCH 04/14] fix: Prevent old instructions from being re-executed after conversation condensation (#11982) --- .../agenthub/codeact_agent/codeact_agent.py | 15 +- openhands/memory/conversation_memory.py | 41 ++++- openhands/memory/view.py | 3 + tests/unit/agenthub/test_agents.py | 8 +- tests/unit/agenthub/test_prompt_caching.py | 4 +- tests/unit/memory/test_conversation_memory.py | 144 +++++++++++++++++- 6 files changed, 197 insertions(+), 18 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 85e5f88cbc..9dd814e9cf 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -194,9 +194,12 @@ class CodeActAgent(Agent): # event we'll just return that instead of an action. The controller will # immediately ask the agent to step again with the new view. condensed_history: list[Event] = [] + # Track which event IDs have been forgotten/condensed + forgotten_event_ids: set[int] = set() match self.condenser.condensed_history(state): - case View(events=events): + case View(events=events, forgotten_event_ids=forgotten_ids): condensed_history = events + forgotten_event_ids = forgotten_ids case Condensation(action=condensation_action): return condensation_action @@ -206,7 +209,9 @@ class CodeActAgent(Agent): ) initial_user_message = self._get_initial_user_message(state.history) - messages = self._get_messages(condensed_history, initial_user_message) + messages = self._get_messages( + condensed_history, initial_user_message, forgotten_event_ids + ) params: dict = { 'messages': messages, } @@ -245,7 +250,10 @@ class CodeActAgent(Agent): return initial_user_message def _get_messages( - self, events: list[Event], initial_user_message: MessageAction + self, + events: list[Event], + initial_user_message: MessageAction, + forgotten_event_ids: set[int], ) -> list[Message]: """Constructs the message history for the LLM conversation. @@ -284,6 +292,7 @@ class CodeActAgent(Agent): messages = self.conversation_memory.process_events( condensed_history=events, initial_user_action=initial_user_message, + forgotten_event_ids=forgotten_event_ids, max_message_chars=self.llm.config.max_message_chars, vision_is_active=self.llm.vision_is_active(), ) diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index 5ff6ec7e58..5ae1a2cd71 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -76,6 +76,7 @@ class ConversationMemory: self, condensed_history: list[Event], initial_user_action: MessageAction, + forgotten_event_ids: set[int] | None = None, max_message_chars: int | None = None, vision_is_active: bool = False, ) -> list[Message]: @@ -85,16 +86,23 @@ class ConversationMemory: Args: condensed_history: The condensed history of events to convert + initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly. + forgotten_event_ids: Set of event IDs that have been forgotten/condensed. If the initial user action's ID + is in this set, it will not be re-inserted to prevent re-execution of old instructions. max_message_chars: The maximum number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated. vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included. - initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly. """ events = condensed_history + # Default to empty set if not provided + if forgotten_event_ids is None: + forgotten_event_ids = set() # Ensure the event list starts with SystemMessageAction, then MessageAction(source='user') self._ensure_system_message(events) - self._ensure_initial_user_message(events, initial_user_action) + self._ensure_initial_user_message( + events, initial_user_action, forgotten_event_ids + ) # log visual browsing status logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}') @@ -827,9 +835,23 @@ class ConversationMemory: ) def _ensure_initial_user_message( - self, events: list[Event], initial_user_action: MessageAction + self, + events: list[Event], + initial_user_action: MessageAction, + forgotten_event_ids: set[int], ) -> None: - """Checks if the second event is a user MessageAction and inserts the provided one if needed.""" + """Checks if the second event is a user MessageAction and inserts the provided one if needed. + + IMPORTANT: If the initial user action has been condensed (its ID is in forgotten_event_ids), + we do NOT re-insert it. This prevents old instructions from being re-executed after + conversation condensation. The condensation summary already contains the context of + what was requested and completed. + + Args: + events: The list of events to modify in-place + initial_user_action: The initial user message action from the full history + forgotten_event_ids: Set of event IDs that have been forgotten/condensed + """ if ( not events ): # Should have system message from previous step, but safety check @@ -837,6 +859,17 @@ class ConversationMemory: # Or raise? Let's log for now, _ensure_system_message should handle this. return + # Check if the initial user action has been condensed/forgotten. + # If so, we should NOT re-insert it to prevent re-execution of old instructions. + # The condensation summary already contains the context of what was requested. + initial_user_action_id = initial_user_action.id + if initial_user_action_id in forgotten_event_ids: + logger.info( + f'Initial user action (id={initial_user_action_id}) has been condensed. ' + 'Not re-inserting to prevent re-execution of old instructions.' + ) + return + # We expect events[0] to be SystemMessageAction after _ensure_system_message if len(events) == 1: # Only system message exists diff --git a/openhands/memory/view.py b/openhands/memory/view.py index 87a20b6340..81dd8bab5d 100644 --- a/openhands/memory/view.py +++ b/openhands/memory/view.py @@ -18,6 +18,8 @@ class View(BaseModel): events: list[Event] unhandled_condensation_request: bool = False + # Set of event IDs that have been forgotten/condensed + forgotten_event_ids: set[int] = set() def __len__(self) -> int: return len(self.events) @@ -90,4 +92,5 @@ class View(BaseModel): return View( events=kept_events, unhandled_condensation_request=unhandled_condensation_request, + forgotten_event_ids=forgotten_event_ids, ) diff --git a/tests/unit/agenthub/test_agents.py b/tests/unit/agenthub/test_agents.py index 2a90dcb668..09f28e991c 100644 --- a/tests/unit/agenthub/test_agents.py +++ b/tests/unit/agenthub/test_agents.py @@ -393,7 +393,7 @@ def test_mismatched_tool_call_events_and_auto_add_system_message( # 2. The action message # 3. The observation message mock_state.history = [initial_user_message, action, observation] - messages = agent._get_messages(mock_state.history, initial_user_message) + messages = agent._get_messages(mock_state.history, initial_user_message, set()) assert len(messages) == 4 # System + initial user + action + observation assert messages[0].role == 'system' # First message should be the system message assert ( @@ -404,7 +404,7 @@ def test_mismatched_tool_call_events_and_auto_add_system_message( # The same should hold if the events are presented out-of-order mock_state.history = [initial_user_message, observation, action] - messages = agent._get_messages(mock_state.history, initial_user_message) + messages = agent._get_messages(mock_state.history, initial_user_message, set()) assert len(messages) == 4 assert messages[0].role == 'system' # First message should be the system message assert ( @@ -414,7 +414,7 @@ def test_mismatched_tool_call_events_and_auto_add_system_message( # If only one of the two events is present, then we should just get the system message # plus any valid message from the event mock_state.history = [initial_user_message, action] - messages = agent._get_messages(mock_state.history, initial_user_message) + messages = agent._get_messages(mock_state.history, initial_user_message, set()) assert ( len(messages) == 2 ) # System + initial user message, action is waiting for its observation @@ -422,7 +422,7 @@ def test_mismatched_tool_call_events_and_auto_add_system_message( assert messages[1].role == 'user' mock_state.history = [initial_user_message, observation] - messages = agent._get_messages(mock_state.history, initial_user_message) + messages = agent._get_messages(mock_state.history, initial_user_message, set()) assert ( len(messages) == 2 ) # System + initial user message, observation has no matching action diff --git a/tests/unit/agenthub/test_prompt_caching.py b/tests/unit/agenthub/test_prompt_caching.py index 60cc0bb16f..2435b1320a 100644 --- a/tests/unit/agenthub/test_prompt_caching.py +++ b/tests/unit/agenthub/test_prompt_caching.py @@ -80,7 +80,7 @@ def test_get_messages(codeact_agent: CodeActAgent): history.append(message_action_5) codeact_agent.reset() - messages = codeact_agent._get_messages(history, message_action_1) + messages = codeact_agent._get_messages(history, message_action_1, set()) assert ( len(messages) == 6 @@ -122,7 +122,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent): history.append(message_action_agent) codeact_agent.reset() - messages = codeact_agent._get_messages(history, initial_user_message) + messages = codeact_agent._get_messages(history, initial_user_message, set()) # Check that only the last two user messages have cache_prompt=True cached_user_messages = [ diff --git a/tests/unit/memory/test_conversation_memory.py b/tests/unit/memory/test_conversation_memory.py index abaa8d9a3d..50fd48f49a 100644 --- a/tests/unit/memory/test_conversation_memory.py +++ b/tests/unit/memory/test_conversation_memory.py @@ -158,7 +158,8 @@ def test_ensure_initial_user_message_adds_if_only_system( system_message = SystemMessageAction(content='System') system_message._source = EventSource.AGENT events = [system_message] - conversation_memory._ensure_initial_user_message(events, initial_user_action) + # Pass empty set for forgotten_event_ids (no events have been condensed) + conversation_memory._ensure_initial_user_message(events, initial_user_action, set()) assert len(events) == 2 assert events[0] == system_message assert events[1] == initial_user_action @@ -177,7 +178,8 @@ def test_ensure_initial_user_message_correct_already_present( agent_message, ] original_events = list(events) - conversation_memory._ensure_initial_user_message(events, initial_user_action) + # Pass empty set for forgotten_event_ids (no events have been condensed) + conversation_memory._ensure_initial_user_message(events, initial_user_action, set()) assert events == original_events @@ -189,7 +191,8 @@ def test_ensure_initial_user_message_incorrect_at_index_1( incorrect_second_message = MessageAction(content='Assistant') incorrect_second_message._source = EventSource.AGENT events = [system_message, incorrect_second_message] - conversation_memory._ensure_initial_user_message(events, initial_user_action) + # Pass empty set for forgotten_event_ids (no events have been condensed) + conversation_memory._ensure_initial_user_message(events, initial_user_action, set()) assert len(events) == 3 assert events[0] == system_message assert events[1] == initial_user_action # Correct one inserted @@ -206,7 +209,8 @@ def test_ensure_initial_user_message_correct_present_later( # Correct initial message is present, but later in the list events = [system_message, incorrect_second_message] conversation_memory._ensure_system_message(events) - conversation_memory._ensure_initial_user_message(events, initial_user_action) + # Pass empty set for forgotten_event_ids (no events have been condensed) + conversation_memory._ensure_initial_user_message(events, initial_user_action, set()) assert len(events) == 3 # Should still insert at index 1, not remove the later one assert events[0] == system_message assert events[1] == initial_user_action # Correct one inserted at index 1 @@ -222,7 +226,8 @@ def test_ensure_initial_user_message_different_user_msg_at_index_1( different_user_message = MessageAction(content='Different User Message') different_user_message._source = EventSource.USER events = [system_message, different_user_message] - conversation_memory._ensure_initial_user_message(events, initial_user_action) + # Pass empty set for forgotten_event_ids (no events have been condensed) + conversation_memory._ensure_initial_user_message(events, initial_user_action, set()) assert len(events) == 2 assert events[0] == system_message assert events[1] == different_user_message # Original second message remains @@ -1583,3 +1588,132 @@ def test_process_ipython_observation_with_vision_disabled( assert isinstance(message.content[1], ImageContent) # Check that NO explanatory text about filtered images was added when vision is disabled assert 'invalid or empty image(s) were filtered' not in message.content[0].text + + +def test_ensure_initial_user_message_not_reinserted_when_condensed( + conversation_memory, initial_user_action +): + """Test that initial user message is NOT re-inserted when it has been condensed. + + This is a critical test for bug #11910: Old instructions should not be re-executed + after conversation condensation. If the initial user message has been condensed + (its ID is in the forgotten_event_ids set), we should NOT re-insert it to prevent + the LLM from seeing old instructions as fresh commands. + """ + system_message = SystemMessageAction(content='System') + system_message._source = EventSource.AGENT + + # Simulate that the initial_user_action has been condensed by adding its ID + # to the forgotten_event_ids set + initial_user_action._id = 1 # Assign an ID to the initial user action + forgotten_event_ids = {1} # The initial user action's ID is in the forgotten set + + events = [system_message] # Only system message, no user message + + # Call _ensure_initial_user_message with the condensed event ID + conversation_memory._ensure_initial_user_message( + events, initial_user_action, forgotten_event_ids + ) + + # The initial user action should NOT be inserted because it was condensed + assert len(events) == 1 + assert events[0] == system_message + # Verify the initial user action was NOT added + assert initial_user_action not in events + + +def test_ensure_initial_user_message_reinserted_when_not_condensed( + conversation_memory, initial_user_action +): + """Test that initial user message IS re-inserted when it has NOT been condensed. + + This ensures backward compatibility: when no condensation has happened, + the initial user message should still be inserted as before. + """ + system_message = SystemMessageAction(content='System') + system_message._source = EventSource.AGENT + + # The initial user action has NOT been condensed + initial_user_action._id = 1 + forgotten_event_ids = {5, 10, 15} # Different IDs, not including the initial action + + events = [system_message] + + # Call _ensure_initial_user_message with non-matching forgotten IDs + conversation_memory._ensure_initial_user_message( + events, initial_user_action, forgotten_event_ids + ) + + # The initial user action SHOULD be inserted because it was NOT condensed + assert len(events) == 2 + assert events[0] == system_message + assert events[1] == initial_user_action + + +def test_process_events_does_not_reinsert_condensed_initial_message( + conversation_memory, +): + """Test that process_events does not re-insert initial user message when condensed. + + This is an integration test for the full process_events flow, verifying that + when the initial user message has been condensed, it is not re-inserted into + the conversation sent to the LLM. + """ + # Create a system message + system_message = SystemMessageAction(content='System message') + system_message._source = EventSource.AGENT + system_message._id = 0 + + # Create the initial user message (will be marked as condensed) + initial_user_message = MessageAction(content='Do task A, B, and C') + initial_user_message._source = EventSource.USER + initial_user_message._id = 1 + + # Create a condensation summary observation + from openhands.events.observation.agent import AgentCondensationObservation + + condensation_summary = AgentCondensationObservation( + content='Summary: User requested tasks A, B, C. Task A was completed successfully.' + ) + condensation_summary._id = 2 + + # Create a recent user message (not condensed) + recent_user_message = MessageAction(content='Now continue with task D') + recent_user_message._source = EventSource.USER + recent_user_message._id = 3 + + # Simulate condensed history: system + summary + recent message + # The initial user message (id=1) has been condensed/forgotten + condensed_history = [system_message, condensation_summary, recent_user_message] + + # The initial user message's ID is in the forgotten set + forgotten_event_ids = {1} + + messages = conversation_memory.process_events( + condensed_history=condensed_history, + initial_user_action=initial_user_message, + forgotten_event_ids=forgotten_event_ids, + max_message_chars=None, + vision_is_active=False, + ) + + # Verify the structure of messages + # Should have: system, condensation summary, recent user message + # Should NOT have the initial user message "Do task A, B, and C" + assert len(messages) == 3 + assert messages[0].role == 'system' + assert messages[0].content[0].text == 'System message' + + # The second message should be the condensation summary, NOT the initial user message + assert messages[1].role == 'user' + assert 'Summary: User requested tasks A, B, C' in messages[1].content[0].text + + # The third message should be the recent user message + assert messages[2].role == 'user' + assert 'Now continue with task D' in messages[2].content[0].text + + # Critically, the old instruction should NOT appear + for msg in messages: + for content in msg.content: + if hasattr(content, 'text'): + assert 'Do task A, B, and C' not in content.text From 2c83e419dc1c3a77a1da7a328cd850e9f7ee0e0c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:16:54 +0400 Subject: [PATCH 05/14] chore(deps): bump the version-all group across 1 directory with 5 updates (#12071) --- frontend/package-lock.json | 159 +++++++++++++++++-------------------- frontend/package.json | 8 +- 2 files changed, 75 insertions(+), 92 deletions(-) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 08011449b8..e130cad40f 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -30,7 +30,7 @@ "isbot": "^5.1.32", "lucide-react": "^0.561.0", "monaco-editor": "^0.55.1", - "posthog-js": "^1.306.1", + "posthog-js": "^1.309.0", "react": "^19.2.3", "react-dom": "^19.2.3", "react-hot-toast": "^2.6.0", @@ -58,13 +58,13 @@ "@testing-library/jest-dom": "^6.9.1", "@testing-library/react": "^16.3.1", "@testing-library/user-event": "^14.6.1", - "@types/node": "^25.0.2", + "@types/node": "^25.0.3", "@types/react": "^19.2.7", "@types/react-dom": "^19.2.3", "@types/react-syntax-highlighter": "^15.5.13", "@typescript-eslint/eslint-plugin": "^7.18.0", "@typescript-eslint/parser": "^7.18.0", - "@vitest/coverage-v8": "^4.0.14", + "@vitest/coverage-v8": "^4.0.16", "cross-env": "^10.1.0", "eslint": "^8.57.0", "eslint-config-airbnb": "^19.0.4", @@ -85,7 +85,7 @@ "tailwindcss": "^4.1.8", "typescript": "^5.9.3", "vite-plugin-svgr": "^4.5.0", - "vite-tsconfig-paths": "^6.0.1", + "vite-tsconfig-paths": "^6.0.2", "vitest": "^4.0.14" }, "engines": { @@ -3192,10 +3192,9 @@ "license": "MIT" }, "node_modules/@posthog/core": { - "version": "1.7.1", - "resolved": "https://registry.npmjs.org/@posthog/core/-/core-1.7.1.tgz", - "integrity": "sha512-kjK0eFMIpKo9GXIbts8VtAknsoZ18oZorANdtuTj1CbgS28t4ZVq//HAWhnxEuXRTrtkd+SUJ6Ux3j2Af8NCuA==", - "license": "MIT", + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/@posthog/core/-/core-1.8.0.tgz", + "integrity": "sha512-SfmG1EdbR+2zpQccgBUxM/snCROB9WGkY7VH1r9iaoTNqoaN9IkmIEA/07cZLY4DxVP8jt6Vdfe3s84xksac1g==", "dependencies": { "cross-spawn": "^7.0.6" } @@ -4949,11 +4948,10 @@ "license": "MIT" }, "node_modules/@standard-schema/spec": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.0.0.tgz", - "integrity": "sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==", - "dev": true, - "license": "MIT" + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "dev": true }, "node_modules/@svgr/babel-plugin-add-jsx-attribute": { "version": "8.0.0", @@ -5684,7 +5682,6 @@ "resolved": "https://registry.npmjs.org/@types/chai/-/chai-5.2.3.tgz", "integrity": "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==", "dev": true, - "license": "MIT", "dependencies": { "@types/deep-eql": "*", "assertion-error": "^2.0.1" @@ -5703,8 +5700,7 @@ "version": "4.0.2", "resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz", "integrity": "sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==", - "dev": true, - "license": "MIT" + "dev": true }, "node_modules/@types/estree": { "version": "1.0.8", @@ -5759,9 +5755,9 @@ "license": "MIT" }, "node_modules/@types/node": { - "version": "25.0.2", - "resolved": "https://registry.npmjs.org/@types/node/-/node-25.0.2.tgz", - "integrity": "sha512-gWEkeiyYE4vqjON/+Obqcoeffmk0NF15WSBwSs7zwVA2bAbTaE0SJ7P0WNGoJn8uE7fiaV5a7dKYIJriEqOrmA==", + "version": "25.0.3", + "resolved": "https://registry.npmjs.org/@types/node/-/node-25.0.3.tgz", + "integrity": "sha512-W609buLVRVmeW693xKfzHeIV6nJGGz98uCPfeXI1ELMLXVeKYZ9m15fAMSaUPBHYLGFsVRcMmSCksQOrZV9BYA==", "devOptional": true, "dependencies": { "undici-types": "~7.16.0" @@ -6239,14 +6235,13 @@ "license": "ISC" }, "node_modules/@vitest/coverage-v8": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/coverage-v8/-/coverage-v8-4.0.15.tgz", - "integrity": "sha512-FUJ+1RkpTFW7rQITdgTi93qOCWJobWhBirEPCeXh2SW2wsTlFxy51apDz5gzG+ZEYt/THvWeNmhdAoS9DTwpCw==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/coverage-v8/-/coverage-v8-4.0.16.tgz", + "integrity": "sha512-2rNdjEIsPRzsdu6/9Eq0AYAzYdpP6Bx9cje9tL3FE5XzXRQF1fNU9pe/1yE8fCrS0HD+fBtt6gLPh6LI57tX7A==", "dev": true, - "license": "MIT", "dependencies": { "@bcoe/v8-coverage": "^1.0.2", - "@vitest/utils": "4.0.15", + "@vitest/utils": "4.0.16", "ast-v8-to-istanbul": "^0.3.8", "istanbul-lib-coverage": "^3.2.2", "istanbul-lib-report": "^3.0.1", @@ -6261,8 +6256,8 @@ "url": "https://opencollective.com/vitest" }, "peerDependencies": { - "@vitest/browser": "4.0.15", - "vitest": "4.0.15" + "@vitest/browser": "4.0.16", + "vitest": "4.0.16" }, "peerDependenciesMeta": { "@vitest/browser": { @@ -6271,16 +6266,15 @@ } }, "node_modules/@vitest/expect": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.0.15.tgz", - "integrity": "sha512-Gfyva9/GxPAWXIWjyGDli9O+waHDC0Q0jaLdFP1qPAUUfo1FEXPXUfUkp3eZA0sSq340vPycSyOlYUeM15Ft1w==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.0.16.tgz", + "integrity": "sha512-eshqULT2It7McaJkQGLkPjPjNph+uevROGuIMJdG3V+0BSR2w9u6J9Lwu+E8cK5TETlfou8GRijhafIMhXsimA==", "dev": true, - "license": "MIT", "dependencies": { "@standard-schema/spec": "^1.0.0", "@types/chai": "^5.2.2", - "@vitest/spy": "4.0.15", - "@vitest/utils": "4.0.15", + "@vitest/spy": "4.0.16", + "@vitest/utils": "4.0.16", "chai": "^6.2.1", "tinyrainbow": "^3.0.3" }, @@ -6289,13 +6283,12 @@ } }, "node_modules/@vitest/mocker": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.0.15.tgz", - "integrity": "sha512-CZ28GLfOEIFkvCFngN8Sfx5h+Se0zN+h4B7yOsPVCcgtiO7t5jt9xQh2E1UkFep+eb9fjyMfuC5gBypwb07fvQ==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.0.16.tgz", + "integrity": "sha512-yb6k4AZxJTB+q9ycAvsoxGn+j/po0UaPgajllBgt1PzoMAAmJGYFdDk0uCcRcxb3BrME34I6u8gHZTQlkqSZpg==", "dev": true, - "license": "MIT", "dependencies": { - "@vitest/spy": "4.0.15", + "@vitest/spy": "4.0.16", "estree-walker": "^3.0.3", "magic-string": "^0.30.21" }, @@ -6316,11 +6309,10 @@ } }, "node_modules/@vitest/pretty-format": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.0.15.tgz", - "integrity": "sha512-SWdqR8vEv83WtZcrfLNqlqeQXlQLh2iilO1Wk1gv4eiHKjEzvgHb2OVc3mIPyhZE6F+CtfYjNlDJwP5MN6Km7A==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.0.16.tgz", + "integrity": "sha512-eNCYNsSty9xJKi/UdVD8Ou16alu7AYiS2fCPRs0b1OdhJiV89buAXQLpTbe+X8V9L6qrs9CqyvU7OaAopJYPsA==", "dev": true, - "license": "MIT", "dependencies": { "tinyrainbow": "^3.0.3" }, @@ -6329,13 +6321,12 @@ } }, "node_modules/@vitest/runner": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.0.15.tgz", - "integrity": "sha512-+A+yMY8dGixUhHmNdPUxOh0la6uVzun86vAbuMT3hIDxMrAOmn5ILBHm8ajrqHE0t8R9T1dGnde1A5DTnmi3qw==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.0.16.tgz", + "integrity": "sha512-VWEDm5Wv9xEo80ctjORcTQRJ539EGPB3Pb9ApvVRAY1U/WkHXmmYISqU5E79uCwcW7xYUV38gwZD+RV755fu3Q==", "dev": true, - "license": "MIT", "dependencies": { - "@vitest/utils": "4.0.15", + "@vitest/utils": "4.0.16", "pathe": "^2.0.3" }, "funding": { @@ -6346,17 +6337,15 @@ "version": "2.0.3", "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", "integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==", - "dev": true, - "license": "MIT" + "dev": true }, "node_modules/@vitest/snapshot": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.0.15.tgz", - "integrity": "sha512-A7Ob8EdFZJIBjLjeO0DZF4lqR6U7Ydi5/5LIZ0xcI+23lYlsYJAfGn8PrIWTYdZQRNnSRlzhg0zyGu37mVdy5g==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.0.16.tgz", + "integrity": "sha512-sf6NcrYhYBsSYefxnry+DR8n3UV4xWZwWxYbCJUt2YdvtqzSPR7VfGrY0zsv090DAbjFZsi7ZaMi1KnSRyK1XA==", "dev": true, - "license": "MIT", "dependencies": { - "@vitest/pretty-format": "4.0.15", + "@vitest/pretty-format": "4.0.16", "magic-string": "^0.30.21", "pathe": "^2.0.3" }, @@ -6368,27 +6357,24 @@ "version": "2.0.3", "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", "integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==", - "dev": true, - "license": "MIT" + "dev": true }, "node_modules/@vitest/spy": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.0.15.tgz", - "integrity": "sha512-+EIjOJmnY6mIfdXtE/bnozKEvTC4Uczg19yeZ2vtCz5Yyb0QQ31QWVQ8hswJ3Ysx/K2EqaNsVanjr//2+P3FHw==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.0.16.tgz", + "integrity": "sha512-4jIOWjKP0ZUaEmJm00E0cOBLU+5WE0BpeNr3XN6TEF05ltro6NJqHWxXD0kA8/Zc8Nh23AT8WQxwNG+WeROupw==", "dev": true, - "license": "MIT", "funding": { "url": "https://opencollective.com/vitest" } }, "node_modules/@vitest/utils": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.0.15.tgz", - "integrity": "sha512-HXjPW2w5dxhTD0dLwtYHDnelK3j8sR8cWIaLxr22evTyY6q8pRCjZSmhRWVjBaOVXChQd6AwMzi9pucorXCPZA==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.0.16.tgz", + "integrity": "sha512-h8z9yYhV3e1LEfaQ3zdypIrnAg/9hguReGZoS7Gl0aBG5xgA410zBqECqmaF/+RkTggRsfnzc1XaAHA6bmUufA==", "dev": true, - "license": "MIT", "dependencies": { - "@vitest/pretty-format": "4.0.15", + "@vitest/pretty-format": "4.0.16", "tinyrainbow": "^3.0.3" }, "funding": { @@ -6729,7 +6715,6 @@ "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", "dev": true, - "license": "MIT", "engines": { "node": ">=12" } @@ -7125,7 +7110,6 @@ "resolved": "https://registry.npmjs.org/chai/-/chai-6.2.1.tgz", "integrity": "sha512-p4Z49OGG5W/WBCPSS/dH3jQ73kD6tiMmUM+bckNK6Jr5JHMG3k9bg/BvKR8lKmtVBKmOiuVaV2ws8s9oSbwysg==", "dev": true, - "license": "MIT", "engines": { "node": ">=18" } @@ -13419,11 +13403,11 @@ } }, "node_modules/posthog-js": { - "version": "1.306.1", - "resolved": "https://registry.npmjs.org/posthog-js/-/posthog-js-1.306.1.tgz", - "integrity": "sha512-wO7bliv/5tlAlfoKCUzwkGXZVNexk0dHigMf9tNp0q1rzs62wThogREY7Tz7h/iWKYiuXy1RumtVlTmHuBXa1w==", + "version": "1.309.0", + "resolved": "https://registry.npmjs.org/posthog-js/-/posthog-js-1.309.0.tgz", + "integrity": "sha512-SmFF0uKX3tNTgQOW4mR4shGLQ0YYG0FXyKTz13SbIH83/FtAJedppOIL7s0y9e7rjogBh6LsPekphhchs9Kh1Q==", "dependencies": { - "@posthog/core": "1.7.1", + "@posthog/core": "1.8.0", "core-js": "^3.38.1", "fflate": "^0.4.8", "preact": "^10.19.3", @@ -16127,9 +16111,9 @@ } }, "node_modules/vite-tsconfig-paths": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/vite-tsconfig-paths/-/vite-tsconfig-paths-6.0.1.tgz", - "integrity": "sha512-OQuYkfCQhc2T+n//0N7/oogTosgiSyZQ7dydrpUlH5SbnFTtplpekdY4GMi6jDwEpiwWlqeUJMyPfC2ePM1+2A==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/vite-tsconfig-paths/-/vite-tsconfig-paths-6.0.2.tgz", + "integrity": "sha512-c06LOO8fWB5RuEPpEIHXU9t7Dt4DoiPIljnKws9UygIaQo6PoFKawTftz5/QVcO+6pOs/HHWycnq4UrZkWVYnQ==", "dev": true, "dependencies": { "debug": "^4.1.1", @@ -16189,19 +16173,18 @@ } }, "node_modules/vitest": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.0.15.tgz", - "integrity": "sha512-n1RxDp8UJm6N0IbJLQo+yzLZ2sQCDyl1o0LeugbPWf8+8Fttp29GghsQBjYJVmWq3gBFfe9Hs1spR44vovn2wA==", + "version": "4.0.16", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.0.16.tgz", + "integrity": "sha512-E4t7DJ9pESL6E3I8nFjPa4xGUd3PmiWDLsDztS2qXSJWfHtbQnwAWylaBvSNY48I3vr8PTqIZlyK8TE3V3CA4Q==", "dev": true, - "license": "MIT", "dependencies": { - "@vitest/expect": "4.0.15", - "@vitest/mocker": "4.0.15", - "@vitest/pretty-format": "4.0.15", - "@vitest/runner": "4.0.15", - "@vitest/snapshot": "4.0.15", - "@vitest/spy": "4.0.15", - "@vitest/utils": "4.0.15", + "@vitest/expect": "4.0.16", + "@vitest/mocker": "4.0.16", + "@vitest/pretty-format": "4.0.16", + "@vitest/runner": "4.0.16", + "@vitest/snapshot": "4.0.16", + "@vitest/spy": "4.0.16", + "@vitest/utils": "4.0.16", "es-module-lexer": "^1.7.0", "expect-type": "^1.2.2", "magic-string": "^0.30.21", @@ -16229,10 +16212,10 @@ "@edge-runtime/vm": "*", "@opentelemetry/api": "^1.9.0", "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", - "@vitest/browser-playwright": "4.0.15", - "@vitest/browser-preview": "4.0.15", - "@vitest/browser-webdriverio": "4.0.15", - "@vitest/ui": "4.0.15", + "@vitest/browser-playwright": "4.0.16", + "@vitest/browser-preview": "4.0.16", + "@vitest/browser-webdriverio": "4.0.16", + "@vitest/ui": "4.0.16", "happy-dom": "*", "jsdom": "*" }, diff --git a/frontend/package.json b/frontend/package.json index 7dc0c5bcfb..90636fed77 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -29,7 +29,7 @@ "isbot": "^5.1.32", "lucide-react": "^0.561.0", "monaco-editor": "^0.55.1", - "posthog-js": "^1.306.1", + "posthog-js": "^1.309.0", "react": "^19.2.3", "react-dom": "^19.2.3", "react-hot-toast": "^2.6.0", @@ -89,13 +89,13 @@ "@testing-library/jest-dom": "^6.9.1", "@testing-library/react": "^16.3.1", "@testing-library/user-event": "^14.6.1", - "@types/node": "^25.0.2", + "@types/node": "^25.0.3", "@types/react": "^19.2.7", "@types/react-dom": "^19.2.3", "@types/react-syntax-highlighter": "^15.5.13", "@typescript-eslint/eslint-plugin": "^7.18.0", "@typescript-eslint/parser": "^7.18.0", - "@vitest/coverage-v8": "^4.0.14", + "@vitest/coverage-v8": "^4.0.16", "cross-env": "^10.1.0", "eslint": "^8.57.0", "eslint-config-airbnb": "^19.0.4", @@ -116,7 +116,7 @@ "tailwindcss": "^4.1.8", "typescript": "^5.9.3", "vite-plugin-svgr": "^4.5.0", - "vite-tsconfig-paths": "^6.0.1", + "vite-tsconfig-paths": "^6.0.2", "vitest": "^4.0.14" }, "packageManager": "npm@10.5.0", From 060761437299315f67e03bdea2315a34f04047bf Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Wed, 17 Dec 2025 22:29:18 +0700 Subject: [PATCH 06/14] feat(frontend): add refresh button to changes tab (#12036) Co-authored-by: Tim O'Farrell --- .../conversation-tab-title.test.tsx | 149 ++++++++++++++++++ .../conversation-tab-content.tsx | 34 +++- .../conversation-tab-title.tsx | 24 ++- .../query/use-unified-get-git-changes.ts | 1 + frontend/src/icons/u-refresh.svg | 3 + 5 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 frontend/__tests__/components/conversation-tab-title.test.tsx create mode 100644 frontend/src/icons/u-refresh.svg diff --git a/frontend/__tests__/components/conversation-tab-title.test.tsx b/frontend/__tests__/components/conversation-tab-title.test.tsx new file mode 100644 index 0000000000..4e3a0aa0fe --- /dev/null +++ b/frontend/__tests__/components/conversation-tab-title.test.tsx @@ -0,0 +1,149 @@ +import { render, screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { ConversationTabTitle } from "#/components/features/conversation/conversation-tabs/conversation-tab-title"; +import GitService from "#/api/git-service/git-service.api"; +import V1GitService from "#/api/git-service/v1-git-service.api"; + +// Mock the services that the hook depends on +vi.mock("#/api/git-service/git-service.api"); +vi.mock("#/api/git-service/v1-git-service.api"); + +// Mock the hooks that useUnifiedGetGitChanges depends on +vi.mock("#/hooks/use-conversation-id", () => ({ + useConversationId: () => ({ + conversationId: "test-conversation-id", + }), +})); + +vi.mock("#/hooks/query/use-active-conversation", () => ({ + useActiveConversation: () => ({ + data: { + conversation_version: "V0", + url: null, + session_api_key: null, + selected_repository: null, + }, + }), +})); + +vi.mock("#/hooks/use-runtime-is-ready", () => ({ + useRuntimeIsReady: () => true, +})); + +vi.mock("#/utils/get-git-path", () => ({ + getGitPath: () => "/workspace", +})); + +describe("ConversationTabTitle", () => { + let queryClient: QueryClient; + + beforeEach(() => { + queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }); + + // Mock GitService methods + vi.mocked(GitService.getGitChanges).mockResolvedValue([]); + vi.mocked(V1GitService.getGitChanges).mockResolvedValue([]); + }); + + afterEach(() => { + vi.clearAllMocks(); + queryClient.clear(); + }); + + const renderWithProviders = (ui: React.ReactElement) => { + return render( + {ui}, + ); + }; + + describe("Rendering", () => { + it("should render the title", () => { + // Arrange + const title = "Test Title"; + + // Act + renderWithProviders( + , + ); + + // Assert + expect(screen.getByText(title)).toBeInTheDocument(); + }); + + it("should show refresh button when conversationKey is 'editor'", () => { + // Arrange + const title = "Changes"; + + // Act + renderWithProviders( + , + ); + + // Assert + const refreshButton = screen.getByRole("button"); + expect(refreshButton).toBeInTheDocument(); + }); + + it("should not show refresh button when conversationKey is not 'editor'", () => { + // Arrange + const title = "Browser"; + + // Act + renderWithProviders( + , + ); + + // Assert + expect(screen.queryByRole("button")).not.toBeInTheDocument(); + }); + }); + + describe("User Interactions", () => { + it("should call refetch and trigger GitService.getGitChanges when refresh button is clicked", async () => { + // Arrange + const user = userEvent.setup(); + const title = "Changes"; + const mockGitChanges: Array<{ + path: string; + status: "M" | "A" | "D" | "R" | "U"; + }> = [ + { path: "file1.ts", status: "M" }, + { path: "file2.ts", status: "A" }, + ]; + + vi.mocked(GitService.getGitChanges).mockResolvedValue(mockGitChanges); + + renderWithProviders( + , + ); + + const refreshButton = screen.getByRole("button"); + + // Wait for initial query to complete + await waitFor(() => { + expect(GitService.getGitChanges).toHaveBeenCalled(); + }); + + // Clear the mock to track refetch calls + vi.mocked(GitService.getGitChanges).mockClear(); + + // Act + await user.click(refreshButton); + + // Assert - refetch should trigger another service call + await waitFor(() => { + expect(GitService.getGitChanges).toHaveBeenCalledWith( + "test-conversation-id", + ); + }); + }); + }); +}); diff --git a/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-content/conversation-tab-content.tsx b/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-content/conversation-tab-content.tsx index 70b45ea73a..39b68c9033 100644 --- a/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-content/conversation-tab-content.tsx +++ b/frontend/src/components/features/conversation/conversation-tabs/conversation-tab-content/conversation-tab-content.tsx @@ -82,13 +82,45 @@ export function ConversationTabContent() { isPlannerActive, ]); + const conversationKey = useMemo(() => { + if (isEditorActive) { + return "editor"; + } + if (isBrowserActive) { + return "browser"; + } + if (isServedActive) { + return "served"; + } + if (isVSCodeActive) { + return "vscode"; + } + if (isTerminalActive) { + return "terminal"; + } + if (isPlannerActive) { + return "planner"; + } + return ""; + }, [ + isEditorActive, + isBrowserActive, + isServedActive, + isVSCodeActive, + isTerminalActive, + isPlannerActive, + ]); + if (shouldShownAgentLoading) { return ; } return ( - + {tabs.map(({ key, component: Component, isActive }) => ( { + refetch(); + }; + return (
{title} + {conversationKey === "editor" && ( + + )}
); } diff --git a/frontend/src/hooks/query/use-unified-get-git-changes.ts b/frontend/src/hooks/query/use-unified-get-git-changes.ts index ae5600469a..6b0856031c 100644 --- a/frontend/src/hooks/query/use-unified-get-git-changes.ts +++ b/frontend/src/hooks/query/use-unified-get-git-changes.ts @@ -103,5 +103,6 @@ export const useUnifiedGetGitChanges = () => { isSuccess: result.isSuccess, isError: result.isError, error: result.error, + refetch: result.refetch, }; }; diff --git a/frontend/src/icons/u-refresh.svg b/frontend/src/icons/u-refresh.svg new file mode 100644 index 0000000000..9e3a2051d2 --- /dev/null +++ b/frontend/src/icons/u-refresh.svg @@ -0,0 +1,3 @@ + + + From f98e7fbc49e2698f78010d04d1e66ed638aafdc2 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Wed, 17 Dec 2025 22:34:28 +0700 Subject: [PATCH 07/14] fix(frontend): observation events and action events (v1 conversations) (#12066) Co-authored-by: openhands --- enterprise/storage/saas_settings_store.py | 3 +++ .../v1/chat/event-content-helpers/should-render-event.ts | 4 ++++ .../observation-pair-event-message.tsx | 7 ++++++- frontend/src/types/v1/type-guards.ts | 8 +++++++- 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/enterprise/storage/saas_settings_store.py b/enterprise/storage/saas_settings_store.py index 6cbcb50802..cfcbec7583 100644 --- a/enterprise/storage/saas_settings_store.py +++ b/enterprise/storage/saas_settings_store.py @@ -94,6 +94,9 @@ class SaasSettingsStore(SettingsStore): } self._decrypt_kwargs(kwargs) settings = Settings(**kwargs) + + settings.v1_enabled = True + return settings async def store(self, item: Settings): diff --git a/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts b/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts index a5fdc62252..1171c21c92 100644 --- a/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts +++ b/frontend/src/components/v1/chat/event-content-helpers/should-render-event.ts @@ -18,6 +18,10 @@ export const shouldRenderEvent = (event: OpenHandsEvent) => { // For V1, action is an object with kind property const actionType = event.action.kind; + if (!actionType) { + return false; + } + // Hide user commands from the chat interface if (actionType === "ExecuteBashAction" && event.source === "user") { return false; diff --git a/frontend/src/components/v1/chat/event-message-components/observation-pair-event-message.tsx b/frontend/src/components/v1/chat/event-message-components/observation-pair-event-message.tsx index aa0bbc09b4..221d758dd6 100644 --- a/frontend/src/components/v1/chat/event-message-components/observation-pair-event-message.tsx +++ b/frontend/src/components/v1/chat/event-message-components/observation-pair-event-message.tsx @@ -34,7 +34,12 @@ export function ObservationPairEventMessage({ .map((t) => t.text) .join("\n"); - if (thoughtContent && event.action.kind !== "ThinkAction") { + // Defensive check: ensure action exists and has kind property + if ( + thoughtContent && + event.action?.kind && + event.action.kind !== "ThinkAction" + ) { return (
diff --git a/frontend/src/types/v1/type-guards.ts b/frontend/src/types/v1/type-guards.ts index ee831ea489..dec1816209 100644 --- a/frontend/src/types/v1/type-guards.ts +++ b/frontend/src/types/v1/type-guards.ts @@ -54,7 +54,10 @@ export const isObservationEvent = ( ): event is ObservationEvent => event.source === "environment" && "action_id" in event && - "observation" in event; + "observation" in event && + event.observation !== null && + typeof event.observation === "object" && + "kind" in event.observation; /** * Type guard function to check if an event is an agent error event @@ -94,6 +97,9 @@ export const isUserMessageEvent = ( export const isActionEvent = (event: OpenHandsEvent): event is ActionEvent => event.source === "agent" && "action" in event && + event.action !== null && + typeof event.action === "object" && + "kind" in event.action && "tool_name" in event && "tool_call_id" in event && typeof event.tool_name === "string" && From 9ef11bf9302c4b3edb257e6ecfb0f41f70ad5bec Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Wed, 17 Dec 2025 23:25:10 +0700 Subject: [PATCH 08/14] feat: show available skills for v1 conversations (#12039) --- .openhands/microagents/repo.md | 2 +- .../conversation/conversation-name.test.tsx | 24 +- .../microagents/microagent-modal.test.tsx | 91 ---- .../modals/skills/skill-modal.test.tsx | 394 ++++++++++++++ .../v1-conversation-service.api.ts | 13 + .../v1-conversation-service.types.ts | 11 + .../features/controls/tools-context-menu.tsx | 30 +- .../components/features/controls/tools.tsx | 16 +- .../conversation-card-context-menu.tsx | 147 ----- .../conversation-card-context-menu.tsx | 12 +- ...croagent-content.tsx => skill-content.tsx} | 8 +- .../{microagent-item.tsx => skill-item.tsx} | 26 +- ...oagent-triggers.tsx => skill-triggers.tsx} | 6 +- ...empty-state.tsx => skills-empty-state.tsx} | 8 +- ...ing-state.tsx => skills-loading-state.tsx} | 2 +- ...dal-header.tsx => skills-modal-header.tsx} | 10 +- ...microagents-modal.tsx => skills-modal.tsx} | 74 ++- .../conversation-name-context-menu.tsx | 15 +- .../conversation/conversation-name.tsx | 20 +- ...roagents.ts => use-conversation-skills.ts} | 14 +- .../use-conversation-name-context-menu.ts | 17 +- frontend/src/i18n/declaration.ts | 15 +- frontend/src/i18n/translation.json | 172 +++--- .../app_conversation_models.py | 10 + .../app_conversation_router.py | 152 +++++- .../app_conversation_service_base.py | 6 +- .../test_app_conversation_service_base.py | 16 +- .../test_app_conversation_skills_endpoint.py | 503 ++++++++++++++++++ 28 files changed, 1325 insertions(+), 489 deletions(-) delete mode 100644 frontend/__tests__/components/modals/microagents/microagent-modal.test.tsx create mode 100644 frontend/__tests__/components/modals/skills/skill-modal.test.tsx delete mode 100644 frontend/src/components/features/conversation-panel/conversation-card-context-menu.tsx rename frontend/src/components/features/conversation-panel/{microagent-content.tsx => skill-content.tsx} (76%) rename frontend/src/components/features/conversation-panel/{microagent-item.tsx => skill-item.tsx} (65%) rename frontend/src/components/features/conversation-panel/{microagent-triggers.tsx => skill-triggers.tsx} (81%) rename frontend/src/components/features/conversation-panel/{microagents-empty-state.tsx => skills-empty-state.tsx} (63%) rename frontend/src/components/features/conversation-panel/{microagents-loading-state.tsx => skills-loading-state.tsx} (80%) rename frontend/src/components/features/conversation-panel/{microagents-modal-header.tsx => skills-modal-header.tsx} (82%) rename frontend/src/components/features/conversation-panel/{microagents-modal.tsx => skills-modal.tsx} (50%) rename frontend/src/hooks/query/{use-conversation-microagents.ts => use-conversation-skills.ts} (62%) create mode 100644 tests/unit/app_server/test_app_conversation_skills_endpoint.py diff --git a/.openhands/microagents/repo.md b/.openhands/microagents/repo.md index ceb87bc2f7..cd3ef33074 100644 --- a/.openhands/microagents/repo.md +++ b/.openhands/microagents/repo.md @@ -63,7 +63,7 @@ Frontend: - We use TanStack Query (fka React Query) for data fetching and cache management - Data Access Layer: API client methods are located in `frontend/src/api` and should never be called directly from UI components - they must always be wrapped with TanStack Query - Custom hooks are located in `frontend/src/hooks/query/` and `frontend/src/hooks/mutation/` - - Query hooks should follow the pattern use[Resource] (e.g., `useConversationMicroagents`) + - Query hooks should follow the pattern use[Resource] (e.g., `useConversationSkills`) - Mutation hooks should follow the pattern use[Action] (e.g., `useDeleteConversation`) - Architecture rule: UI components → TanStack Query hooks → Data Access Layer (`frontend/src/api`) → API endpoints diff --git a/frontend/__tests__/components/features/conversation/conversation-name.test.tsx b/frontend/__tests__/components/features/conversation/conversation-name.test.tsx index 572ca590b1..41078b69cb 100644 --- a/frontend/__tests__/components/features/conversation/conversation-name.test.tsx +++ b/frontend/__tests__/components/features/conversation/conversation-name.test.tsx @@ -42,7 +42,7 @@ vi.mock("react-i18next", async () => { BUTTON$EXPORT_CONVERSATION: "Export Conversation", BUTTON$DOWNLOAD_VIA_VSCODE: "Download via VS Code", BUTTON$SHOW_AGENT_TOOLS_AND_METADATA: "Show Agent Tools", - CONVERSATION$SHOW_MICROAGENTS: "Show Microagents", + CONVERSATION$SHOW_SKILLS: "Show Skills", BUTTON$DISPLAY_COST: "Display Cost", COMMON$CLOSE_CONVERSATION_STOP_RUNTIME: "Close Conversation (Stop Runtime)", @@ -290,7 +290,7 @@ describe("ConversationNameContextMenu", () => { onStop: vi.fn(), onDisplayCost: vi.fn(), onShowAgentTools: vi.fn(), - onShowMicroagents: vi.fn(), + onShowSkills: vi.fn(), onExportConversation: vi.fn(), onDownloadViaVSCode: vi.fn(), }; @@ -304,7 +304,7 @@ describe("ConversationNameContextMenu", () => { expect(screen.getByTestId("stop-button")).toBeInTheDocument(); expect(screen.getByTestId("display-cost-button")).toBeInTheDocument(); expect(screen.getByTestId("show-agent-tools-button")).toBeInTheDocument(); - expect(screen.getByTestId("show-microagents-button")).toBeInTheDocument(); + expect(screen.getByTestId("show-skills-button")).toBeInTheDocument(); expect( screen.getByTestId("export-conversation-button"), ).toBeInTheDocument(); @@ -321,9 +321,7 @@ describe("ConversationNameContextMenu", () => { expect( screen.queryByTestId("show-agent-tools-button"), ).not.toBeInTheDocument(); - expect( - screen.queryByTestId("show-microagents-button"), - ).not.toBeInTheDocument(); + expect(screen.queryByTestId("show-skills-button")).not.toBeInTheDocument(); expect( screen.queryByTestId("export-conversation-button"), ).not.toBeInTheDocument(); @@ -410,19 +408,19 @@ describe("ConversationNameContextMenu", () => { it("should call show microagents handler when show microagents button is clicked", async () => { const user = userEvent.setup(); - const onShowMicroagents = vi.fn(); + const onShowSkills = vi.fn(); renderWithProviders( , ); - const showMicroagentsButton = screen.getByTestId("show-microagents-button"); + const showMicroagentsButton = screen.getByTestId("show-skills-button"); await user.click(showMicroagentsButton); - expect(onShowMicroagents).toHaveBeenCalledTimes(1); + expect(onShowSkills).toHaveBeenCalledTimes(1); }); it("should call export conversation handler when export conversation button is clicked", async () => { @@ -519,7 +517,7 @@ describe("ConversationNameContextMenu", () => { onStop: vi.fn(), onDisplayCost: vi.fn(), onShowAgentTools: vi.fn(), - onShowMicroagents: vi.fn(), + onShowSkills: vi.fn(), onExportConversation: vi.fn(), onDownloadViaVSCode: vi.fn(), }; @@ -541,8 +539,8 @@ describe("ConversationNameContextMenu", () => { expect(screen.getByTestId("show-agent-tools-button")).toHaveTextContent( "Show Agent Tools", ); - expect(screen.getByTestId("show-microagents-button")).toHaveTextContent( - "Show Microagents", + expect(screen.getByTestId("show-skills-button")).toHaveTextContent( + "Show Skills", ); expect(screen.getByTestId("export-conversation-button")).toHaveTextContent( "Export Conversation", diff --git a/frontend/__tests__/components/modals/microagents/microagent-modal.test.tsx b/frontend/__tests__/components/modals/microagents/microagent-modal.test.tsx deleted file mode 100644 index 858c07207d..0000000000 --- a/frontend/__tests__/components/modals/microagents/microagent-modal.test.tsx +++ /dev/null @@ -1,91 +0,0 @@ -import { screen } from "@testing-library/react"; -import userEvent from "@testing-library/user-event"; -import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; -import { renderWithProviders } from "test-utils"; -import { MicroagentsModal } from "#/components/features/conversation-panel/microagents-modal"; -import ConversationService from "#/api/conversation-service/conversation-service.api"; -import { AgentState } from "#/types/agent-state"; -import { useAgentState } from "#/hooks/use-agent-state"; - -// Mock the agent state hook -vi.mock("#/hooks/use-agent-state", () => ({ - useAgentState: vi.fn(), -})); - -// Mock the conversation ID hook -vi.mock("#/hooks/use-conversation-id", () => ({ - useConversationId: () => ({ conversationId: "test-conversation-id" }), -})); - -describe("MicroagentsModal - Refresh Button", () => { - const mockOnClose = vi.fn(); - const conversationId = "test-conversation-id"; - - const defaultProps = { - onClose: mockOnClose, - conversationId, - }; - - const mockMicroagents = [ - { - name: "Test Agent 1", - type: "repo" as const, - triggers: ["test", "example"], - content: "This is test content for agent 1", - }, - { - name: "Test Agent 2", - type: "knowledge" as const, - triggers: ["help", "support"], - content: "This is test content for agent 2", - }, - ]; - - beforeEach(() => { - // Reset all mocks before each test - vi.clearAllMocks(); - - // Setup default mock for getMicroagents - vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({ - microagents: mockMicroagents, - }); - - // Mock the agent state to return a ready state - vi.mocked(useAgentState).mockReturnValue({ - curAgentState: AgentState.AWAITING_USER_INPUT, - }); - }); - - afterEach(() => { - vi.restoreAllMocks(); - }); - - describe("Refresh Button Rendering", () => { - it("should render the refresh button with correct text and test ID", async () => { - renderWithProviders(); - - // Wait for the component to load and render the refresh button - const refreshButton = await screen.findByTestId("refresh-microagents"); - expect(refreshButton).toBeInTheDocument(); - expect(refreshButton).toHaveTextContent("BUTTON$REFRESH"); - }); - }); - - describe("Refresh Button Functionality", () => { - it("should call refetch when refresh button is clicked", async () => { - const user = userEvent.setup(); - const refreshSpy = vi.spyOn(ConversationService, "getMicroagents"); - - renderWithProviders(); - - // Wait for the component to load and render the refresh button - const refreshButton = await screen.findByTestId("refresh-microagents"); - - refreshSpy.mockClear(); - - await user.click(refreshButton); - - expect(refreshSpy).toHaveBeenCalledTimes(1); - }); - }); -}); diff --git a/frontend/__tests__/components/modals/skills/skill-modal.test.tsx b/frontend/__tests__/components/modals/skills/skill-modal.test.tsx new file mode 100644 index 0000000000..33ab5098c8 --- /dev/null +++ b/frontend/__tests__/components/modals/skills/skill-modal.test.tsx @@ -0,0 +1,394 @@ +import { screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { renderWithProviders } from "test-utils"; +import { SkillsModal } from "#/components/features/conversation-panel/skills-modal"; +import ConversationService from "#/api/conversation-service/conversation-service.api"; +import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api"; +import { AgentState } from "#/types/agent-state"; +import { useAgentState } from "#/hooks/use-agent-state"; +import SettingsService from "#/api/settings-service/settings-service.api"; + +// Mock the agent state hook +vi.mock("#/hooks/use-agent-state", () => ({ + useAgentState: vi.fn(), +})); + +// Mock the conversation ID hook +vi.mock("#/hooks/use-conversation-id", () => ({ + useConversationId: () => ({ conversationId: "test-conversation-id" }), +})); + +describe("SkillsModal - Refresh Button", () => { + const mockOnClose = vi.fn(); + const conversationId = "test-conversation-id"; + + const defaultProps = { + onClose: mockOnClose, + conversationId, + }; + + const mockSkills = [ + { + name: "Test Agent 1", + type: "repo" as const, + triggers: ["test", "example"], + content: "This is test content for agent 1", + }, + { + name: "Test Agent 2", + type: "knowledge" as const, + triggers: ["help", "support"], + content: "This is test content for agent 2", + }, + ]; + + beforeEach(() => { + // Reset all mocks before each test + vi.clearAllMocks(); + + // Setup default mock for getMicroagents (V0) + vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({ + microagents: mockSkills, + }); + + // Mock the agent state to return a ready state + vi.mocked(useAgentState).mockReturnValue({ + curAgentState: AgentState.AWAITING_USER_INPUT, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("Refresh Button Rendering", () => { + it("should render the refresh button with correct text and test ID", async () => { + renderWithProviders(); + + // Wait for the component to load and render the refresh button + const refreshButton = await screen.findByTestId("refresh-skills"); + expect(refreshButton).toBeInTheDocument(); + expect(refreshButton).toHaveTextContent("BUTTON$REFRESH"); + }); + }); + + describe("Refresh Button Functionality", () => { + it("should call refetch when refresh button is clicked", async () => { + const user = userEvent.setup(); + const refreshSpy = vi.spyOn(ConversationService, "getMicroagents"); + + renderWithProviders(); + + // Wait for the component to load and render the refresh button + const refreshButton = await screen.findByTestId("refresh-skills"); + + // Clear previous calls to only track the click + refreshSpy.mockClear(); + + await user.click(refreshButton); + + // Verify the refresh triggered a new API call + expect(refreshSpy).toHaveBeenCalled(); + }); + }); +}); + +describe("useConversationSkills - V1 API Integration", () => { + const conversationId = "test-conversation-id"; + + const mockMicroagents = [ + { + name: "V0 Test Agent", + type: "repo" as const, + triggers: ["v0"], + content: "V0 skill content", + }, + ]; + + const mockSkills = [ + { + name: "V1 Test Skill", + type: "knowledge" as const, + triggers: ["v1", "skill"], + content: "V1 skill content", + }, + ]; + + beforeEach(() => { + vi.clearAllMocks(); + + // Mock agent state + vi.mocked(useAgentState).mockReturnValue({ + curAgentState: AgentState.AWAITING_USER_INPUT, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("V0 API Usage (v1_enabled: false)", () => { + it("should call v0 ConversationService.getMicroagents when v1_enabled is false", async () => { + // Arrange + const getMicroagentsSpy = vi + .spyOn(ConversationService, "getMicroagents") + .mockResolvedValue({ microagents: mockMicroagents }); + + vi.spyOn(SettingsService, "getSettings").mockResolvedValue({ + v1_enabled: false, + llm_model: "test-model", + llm_base_url: "", + agent: "test-agent", + language: "en", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + user_consents_to_analytics: null, + max_budget_per_task: null, + }); + + // Act + renderWithProviders(); + + // Assert + await screen.findByText("V0 Test Agent"); + expect(getMicroagentsSpy).toHaveBeenCalledWith(conversationId); + expect(getMicroagentsSpy).toHaveBeenCalledTimes(1); + }); + + it("should display v0 skills correctly", async () => { + // Arrange + vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({ + microagents: mockMicroagents, + }); + + vi.spyOn(SettingsService, "getSettings").mockResolvedValue({ + v1_enabled: false, + llm_model: "test-model", + llm_base_url: "", + agent: "test-agent", + language: "en", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + user_consents_to_analytics: null, + max_budget_per_task: null, + }); + + // Act + renderWithProviders(); + + // Assert + const agentName = await screen.findByText("V0 Test Agent"); + expect(agentName).toBeInTheDocument(); + }); + }); + + describe("V1 API Usage (v1_enabled: true)", () => { + it("should call v1 V1ConversationService.getSkills when v1_enabled is true", async () => { + // Arrange + const getSkillsSpy = vi + .spyOn(V1ConversationService, "getSkills") + .mockResolvedValue({ skills: mockSkills }); + + vi.spyOn(SettingsService, "getSettings").mockResolvedValue({ + v1_enabled: true, + llm_model: "test-model", + llm_base_url: "", + agent: "test-agent", + language: "en", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + user_consents_to_analytics: null, + max_budget_per_task: null, + }); + + // Act + renderWithProviders(); + + // Assert + await screen.findByText("V1 Test Skill"); + expect(getSkillsSpy).toHaveBeenCalledWith(conversationId); + expect(getSkillsSpy).toHaveBeenCalledTimes(1); + }); + + it("should display v1 skills correctly", async () => { + // Arrange + vi.spyOn(V1ConversationService, "getSkills").mockResolvedValue({ + skills: mockSkills, + }); + + vi.spyOn(SettingsService, "getSettings").mockResolvedValue({ + v1_enabled: true, + llm_model: "test-model", + llm_base_url: "", + agent: "test-agent", + language: "en", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + user_consents_to_analytics: null, + max_budget_per_task: null, + }); + + // Act + renderWithProviders(); + + // Assert + const skillName = await screen.findByText("V1 Test Skill"); + expect(skillName).toBeInTheDocument(); + }); + + it("should use v1 API when v1_enabled is true", async () => { + // Arrange + vi.spyOn(SettingsService, "getSettings").mockResolvedValue({ + v1_enabled: true, + llm_model: "test-model", + llm_base_url: "", + agent: "test-agent", + language: "en", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + user_consents_to_analytics: null, + max_budget_per_task: null, + }); + + const getSkillsSpy = vi + .spyOn(V1ConversationService, "getSkills") + .mockResolvedValue({ + skills: mockSkills, + }); + + // Act + renderWithProviders(); + + // Assert + await screen.findByText("V1 Test Skill"); + // Verify v1 API was called + expect(getSkillsSpy).toHaveBeenCalledWith(conversationId); + }); + }); + + describe("API Switching on Settings Change", () => { + it("should refetch using different API when v1_enabled setting changes", async () => { + // Arrange + const getMicroagentsSpy = vi + .spyOn(ConversationService, "getMicroagents") + .mockResolvedValue({ microagents: mockMicroagents }); + const getSkillsSpy = vi + .spyOn(V1ConversationService, "getSkills") + .mockResolvedValue({ skills: mockSkills }); + + const settingsSpy = vi + .spyOn(SettingsService, "getSettings") + .mockResolvedValue({ + v1_enabled: false, + llm_model: "test-model", + llm_base_url: "", + agent: "test-agent", + language: "en", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + user_consents_to_analytics: null, + max_budget_per_task: null, + }); + + // Act - Initial render with v1_enabled: false + const { rerender } = renderWithProviders( + , + ); + + // Assert - v0 API called initially + await screen.findByText("V0 Test Agent"); + expect(getMicroagentsSpy).toHaveBeenCalledWith(conversationId); + + // Arrange - Change settings to v1_enabled: true + settingsSpy.mockResolvedValue({ + v1_enabled: true, + llm_model: "test-model", + llm_base_url: "", + agent: "test-agent", + language: "en", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + user_consents_to_analytics: null, + max_budget_per_task: null, + }); + + // Act - Force re-render + rerender(); + + // Assert - v1 API should be called after settings change + await screen.findByText("V1 Test Skill"); + expect(getSkillsSpy).toHaveBeenCalledWith(conversationId); + }); + }); +}); diff --git a/frontend/src/api/conversation-service/v1-conversation-service.api.ts b/frontend/src/api/conversation-service/v1-conversation-service.api.ts index bd37fa8180..d2f8f51ff5 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.api.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.api.ts @@ -11,6 +11,7 @@ import type { V1AppConversationStartTask, V1AppConversationStartTaskPage, V1AppConversation, + GetSkillsResponse, } from "./v1-conversation-service.types"; class V1ConversationService { @@ -315,6 +316,18 @@ class V1ConversationService { ); return data; } + + /** + * Get all skills associated with a V1 conversation + * @param conversationId The conversation ID + * @returns The available skills associated with the conversation + */ + static async getSkills(conversationId: string): Promise { + const { data } = await openHands.get( + `/api/v1/app-conversations/${conversationId}/skills`, + ); + return data; + } } export default V1ConversationService; diff --git a/frontend/src/api/conversation-service/v1-conversation-service.types.ts b/frontend/src/api/conversation-service/v1-conversation-service.types.ts index 621283c274..7c8b04ccbf 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.types.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.types.ts @@ -99,3 +99,14 @@ export interface V1AppConversation { conversation_url: string | null; session_api_key: string | null; } + +export interface Skill { + name: string; + type: "repo" | "knowledge"; + content: string; + triggers: string[]; +} + +export interface GetSkillsResponse { + skills: Skill[]; +} diff --git a/frontend/src/components/features/controls/tools-context-menu.tsx b/frontend/src/components/features/controls/tools-context-menu.tsx index 39330e25e4..2089f95111 100644 --- a/frontend/src/components/features/controls/tools-context-menu.tsx +++ b/frontend/src/components/features/controls/tools-context-menu.tsx @@ -26,14 +26,14 @@ const contextMenuListItemClassName = cn( interface ToolsContextMenuProps { onClose: () => void; - onShowMicroagents: (event: React.MouseEvent) => void; + onShowSkills: (event: React.MouseEvent) => void; onShowAgentTools: (event: React.MouseEvent) => void; shouldShowAgentTools?: boolean; } export function ToolsContextMenu({ onClose, - onShowMicroagents, + onShowSkills, onShowAgentTools, shouldShowAgentTools = true, }: ToolsContextMenuProps) { @@ -41,7 +41,6 @@ export function ToolsContextMenu({ const { data: conversation } = useActiveConversation(); const { providers } = useUserProviders(); - // TODO: Hide microagent menu items for V1 conversations // This is a temporary measure and may be re-enabled in the future const isV1Conversation = conversation?.conversation_version === "V1"; @@ -130,20 +129,17 @@ export function ToolsContextMenu({ {(!isV1Conversation || shouldShowAgentTools) && } - {/* Show Available Microagents - Hidden for V1 conversations */} - {!isV1Conversation && ( - - } - text={t(I18nKey.CONVERSATION$SHOW_MICROAGENTS)} - className={CONTEXT_MENU_ICON_TEXT_CLASSNAME} - /> - - )} + + } + text={t(I18nKey.CONVERSATION$SHOW_SKILLS)} + className={CONTEXT_MENU_ICON_TEXT_CLASSNAME} + /> + {/* Show Agent Tools and Metadata - Only show if system message is available */} {shouldShowAgentTools && ( diff --git a/frontend/src/components/features/controls/tools.tsx b/frontend/src/components/features/controls/tools.tsx index 56ef58bc8e..80994cbe65 100644 --- a/frontend/src/components/features/controls/tools.tsx +++ b/frontend/src/components/features/controls/tools.tsx @@ -7,7 +7,7 @@ import { ToolsContextMenu } from "./tools-context-menu"; import { useConversationNameContextMenu } from "#/hooks/use-conversation-name-context-menu"; import { useActiveConversation } from "#/hooks/query/use-active-conversation"; import { SystemMessageModal } from "../conversation-panel/system-message-modal"; -import { MicroagentsModal } from "../conversation-panel/microagents-modal"; +import { SkillsModal } from "../conversation-panel/skills-modal"; export function Tools() { const { t } = useTranslation(); @@ -17,11 +17,11 @@ export function Tools() { const { handleShowAgentTools, - handleShowMicroagents, + handleShowSkills, systemModalVisible, setSystemModalVisible, - microagentsModalVisible, - setMicroagentsModalVisible, + skillsModalVisible, + setSkillsModalVisible, systemMessage, shouldShowAgentTools, } = useConversationNameContextMenu({ @@ -51,7 +51,7 @@ export function Tools() { {contextMenuOpen && ( setContextMenuOpen(false)} - onShowMicroagents={handleShowMicroagents} + onShowSkills={handleShowSkills} onShowAgentTools={handleShowAgentTools} shouldShowAgentTools={shouldShowAgentTools} /> @@ -64,9 +64,9 @@ export function Tools() { systemMessage={systemMessage ? systemMessage.args : null} /> - {/* Microagents Modal */} - {microagentsModalVisible && ( - setMicroagentsModalVisible(false)} /> + {/* Skills Modal */} + {skillsModalVisible && ( + setSkillsModalVisible(false)} /> )}
); diff --git a/frontend/src/components/features/conversation-panel/conversation-card-context-menu.tsx b/frontend/src/components/features/conversation-panel/conversation-card-context-menu.tsx deleted file mode 100644 index 63ea33152b..0000000000 --- a/frontend/src/components/features/conversation-panel/conversation-card-context-menu.tsx +++ /dev/null @@ -1,147 +0,0 @@ -import { - Trash, - Power, - Pencil, - Download, - Wallet, - Wrench, - Bot, -} from "lucide-react"; -import { useTranslation } from "react-i18next"; -import { useClickOutsideElement } from "#/hooks/use-click-outside-element"; -import { cn } from "#/utils/utils"; -import { ContextMenu } from "#/ui/context-menu"; -import { ContextMenuListItem } from "../context-menu/context-menu-list-item"; -import { Divider } from "#/ui/divider"; -import { I18nKey } from "#/i18n/declaration"; -import { ContextMenuIconText } from "../context-menu/context-menu-icon-text"; -import { useActiveConversation } from "#/hooks/query/use-active-conversation"; - -interface ConversationCardContextMenuProps { - onClose: () => void; - onDelete?: (event: React.MouseEvent) => void; - onStop?: (event: React.MouseEvent) => void; - onEdit?: (event: React.MouseEvent) => void; - onDisplayCost?: (event: React.MouseEvent) => void; - onShowAgentTools?: (event: React.MouseEvent) => void; - onShowMicroagents?: (event: React.MouseEvent) => void; - onDownloadViaVSCode?: (event: React.MouseEvent) => void; - position?: "top" | "bottom"; -} - -export function ConversationCardContextMenu({ - onClose, - onDelete, - onStop, - onEdit, - onDisplayCost, - onShowAgentTools, - onShowMicroagents, - onDownloadViaVSCode, - position = "bottom", -}: ConversationCardContextMenuProps) { - const { t } = useTranslation(); - const ref = useClickOutsideElement(onClose); - const { data: conversation } = useActiveConversation(); - - // TODO: Hide microagent menu items for V1 conversations - // This is a temporary measure and may be re-enabled in the future - const isV1Conversation = conversation?.conversation_version === "V1"; - - const hasEdit = Boolean(onEdit); - const hasDownload = Boolean(onDownloadViaVSCode); - const hasTools = Boolean(onShowAgentTools || onShowMicroagents); - const hasInfo = Boolean(onDisplayCost); - const hasControl = Boolean(onStop || onDelete); - - return ( - - {onEdit && ( - - - - )} - - {hasEdit && (hasDownload || hasTools || hasInfo || hasControl) && ( - - )} - - {onDownloadViaVSCode && ( - - - - )} - - {hasDownload && (hasTools || hasInfo || hasControl) && } - - {onShowAgentTools && ( - - - - )} - - {onShowMicroagents && !isV1Conversation && ( - - - - )} - - {hasTools && (hasInfo || hasControl) && } - - {onDisplayCost && ( - - - - )} - - {hasInfo && hasControl && } - - {onStop && ( - - - - )} - - {onDelete && ( - - - - )} - - ); -} diff --git a/frontend/src/components/features/conversation-panel/conversation-card/conversation-card-context-menu.tsx b/frontend/src/components/features/conversation-panel/conversation-card/conversation-card-context-menu.tsx index 6565a83a10..30a7ec42cb 100644 --- a/frontend/src/components/features/conversation-panel/conversation-card/conversation-card-context-menu.tsx +++ b/frontend/src/components/features/conversation-panel/conversation-card/conversation-card-context-menu.tsx @@ -22,7 +22,7 @@ interface ConversationCardContextMenuProps { onEdit?: (event: React.MouseEvent) => void; onDisplayCost?: (event: React.MouseEvent) => void; onShowAgentTools?: (event: React.MouseEvent) => void; - onShowMicroagents?: (event: React.MouseEvent) => void; + onShowSkills?: (event: React.MouseEvent) => void; onDownloadViaVSCode?: (event: React.MouseEvent) => void; position?: "top" | "bottom"; } @@ -37,7 +37,7 @@ export function ConversationCardContextMenu({ onEdit, onDisplayCost, onShowAgentTools, - onShowMicroagents, + onShowSkills, onDownloadViaVSCode, position = "bottom", }: ConversationCardContextMenuProps) { @@ -96,15 +96,15 @@ export function ConversationCardContextMenu({ /> ), - onShowMicroagents && ( + onShowSkills && ( } - text={t(I18nKey.CONVERSATION$SHOW_MICROAGENTS)} + text={t(I18nKey.CONVERSATION$SHOW_SKILLS)} /> ), diff --git a/frontend/src/components/features/conversation-panel/microagent-content.tsx b/frontend/src/components/features/conversation-panel/skill-content.tsx similarity index 76% rename from frontend/src/components/features/conversation-panel/microagent-content.tsx rename to frontend/src/components/features/conversation-panel/skill-content.tsx index fad0485607..9303047e3a 100644 --- a/frontend/src/components/features/conversation-panel/microagent-content.tsx +++ b/frontend/src/components/features/conversation-panel/skill-content.tsx @@ -3,17 +3,17 @@ import { I18nKey } from "#/i18n/declaration"; import { Typography } from "#/ui/typography"; import { Pre } from "#/ui/pre"; -interface MicroagentContentProps { +interface SkillContentProps { content: string; } -export function MicroagentContent({ content }: MicroagentContentProps) { +export function SkillContent({ content }: SkillContentProps) { const { t } = useTranslation(); return (
- {t(I18nKey.MICROAGENTS_MODAL$CONTENT)} + {t(I18nKey.COMMON$CONTENT)}
-        {content || t(I18nKey.MICROAGENTS_MODAL$NO_CONTENT)}
+        {content || t(I18nKey.SKILLS_MODAL$NO_CONTENT)}
       
); diff --git a/frontend/src/components/features/conversation-panel/microagent-item.tsx b/frontend/src/components/features/conversation-panel/skill-item.tsx similarity index 65% rename from frontend/src/components/features/conversation-panel/microagent-item.tsx rename to frontend/src/components/features/conversation-panel/skill-item.tsx index d23febb099..c76bf10be9 100644 --- a/frontend/src/components/features/conversation-panel/microagent-item.tsx +++ b/frontend/src/components/features/conversation-panel/skill-item.tsx @@ -1,35 +1,31 @@ import { ChevronDown, ChevronRight } from "lucide-react"; -import { Microagent } from "#/api/open-hands.types"; import { Typography } from "#/ui/typography"; -import { MicroagentTriggers } from "./microagent-triggers"; -import { MicroagentContent } from "./microagent-content"; +import { SkillTriggers } from "./skill-triggers"; +import { SkillContent } from "./skill-content"; +import { Skill } from "#/api/conversation-service/v1-conversation-service.types"; -interface MicroagentItemProps { - agent: Microagent; +interface SkillItemProps { + skill: Skill; isExpanded: boolean; onToggle: (agentName: string) => void; } -export function MicroagentItem({ - agent, - isExpanded, - onToggle, -}: MicroagentItemProps) { +export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) { return (