diff --git a/enterprise/allhands-realm-github-provider.json.tmpl b/enterprise/allhands-realm-github-provider.json.tmpl index 6cdaa34383..35ff5f0afc 100644 --- a/enterprise/allhands-realm-github-provider.json.tmpl +++ b/enterprise/allhands-realm-github-provider.json.tmpl @@ -721,6 +721,7 @@ "https://$WEB_HOST/oauth/keycloak/callback", "https://$WEB_HOST/oauth/keycloak/offline/callback", "https://$WEB_HOST/slack/keycloak-callback", + "https://$WEB_HOST/oauth/device/keycloak-callback", "https://$WEB_HOST/api/email/verified", "/realms/$KEYCLOAK_REALM_NAME/$KEYCLOAK_CLIENT_ID/*" ], diff --git a/enterprise/migrations/versions/084_create_device_codes_table.py b/enterprise/migrations/versions/084_create_device_codes_table.py new file mode 100644 index 0000000000..0898e09ef5 --- /dev/null +++ b/enterprise/migrations/versions/084_create_device_codes_table.py @@ -0,0 +1,49 @@ +"""Create device_codes table for OAuth 2.0 Device Flow + +Revision ID: 084 +Revises: 083 +Create Date: 2024-12-10 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '084' +down_revision = '083' +branch_labels = None +depends_on = None + + +def upgrade(): + """Create device_codes table for OAuth 2.0 Device Flow.""" + op.create_table( + 'device_codes', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('device_code', sa.String(length=128), nullable=False), + sa.Column('user_code', sa.String(length=16), nullable=False), + sa.Column('status', sa.String(length=32), nullable=False), + sa.Column('keycloak_user_id', sa.String(length=255), nullable=True), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('authorized_at', sa.DateTime(timezone=True), nullable=True), + # Rate limiting fields for RFC 8628 section 3.5 compliance + sa.Column('last_poll_time', sa.DateTime(timezone=True), nullable=True), + sa.Column('current_interval', sa.Integer(), nullable=False, default=5), + sa.PrimaryKeyConstraint('id'), + ) + + # Create indexes for efficient lookups + op.create_index( + 'ix_device_codes_device_code', 'device_codes', ['device_code'], unique=True + ) + op.create_index( + 'ix_device_codes_user_code', 'device_codes', ['user_code'], unique=True + ) + + +def downgrade(): + """Drop device_codes table.""" + op.drop_index('ix_device_codes_user_code', table_name='device_codes') + op.drop_index('ix_device_codes_device_code', table_name='device_codes') + op.drop_table('device_codes') diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index 4c3c7c49ba..96e19a9815 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -34,6 +34,7 @@ from server.routes.integration.jira_dc import jira_dc_integration_router # noqa from server.routes.integration.linear import linear_integration_router # noqa: E402 from server.routes.integration.slack import slack_router # noqa: E402 from server.routes.mcp_patch import patch_mcp_server # noqa: E402 +from server.routes.oauth_device import oauth_device_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402 @@ -60,6 +61,7 @@ base_app.mount('/internal/metrics', metrics_app()) base_app.include_router(readiness_router) # Add routes for readiness checks base_app.include_router(api_router) # Add additional route for github auth base_app.include_router(oauth_router) # Add additional route for oauth callback +base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes base_app.include_router(saas_user_router) # Add additional route SAAS user calls base_app.include_router( billing_router diff --git a/enterprise/server/middleware.py b/enterprise/server/middleware.py index 2972c1ec38..54e3319595 100644 --- a/enterprise/server/middleware.py +++ b/enterprise/server/middleware.py @@ -152,17 +152,22 @@ class SetAuthCookieMiddleware: return False path = request.url.path - is_api_that_should_attach = path.startswith('/api') and path not in ( + ignore_paths = ( '/api/options/config', '/api/keycloak/callback', '/api/billing/success', '/api/billing/cancel', '/api/billing/customer-setup-success', '/api/billing/stripe-webhook', + '/oauth/device/authorize', + '/oauth/device/token', ) + if path in ignore_paths: + return False is_mcp = path.startswith('/mcp') - return is_api_that_should_attach or is_mcp + is_api_route = path.startswith('/api') + return is_api_route or is_mcp async def _logout(self, request: Request): # Log out of keycloak - this prevents issues where you did not log in with the idp you believe you used diff --git a/enterprise/server/routes/oauth_device.py b/enterprise/server/routes/oauth_device.py new file mode 100644 index 0000000000..39ff9a4081 --- /dev/null +++ b/enterprise/server/routes/oauth_device.py @@ -0,0 +1,324 @@ +"""OAuth 2.0 Device Flow endpoints for CLI authentication.""" + +from datetime import UTC, datetime, timedelta +from typing import Optional + +from fastapi import APIRouter, Depends, Form, HTTPException, Request, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from storage.api_key_store import ApiKeyStore +from storage.database import session_maker +from storage.device_code_store import DeviceCodeStore + +from openhands.core.logger import openhands_logger as logger +from openhands.server.user_auth import get_user_id + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes +DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds + +API_KEY_NAME = 'Device Link Access Key' +KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + + +class DeviceAuthorizationResponse(BaseModel): + device_code: str + user_code: str + verification_uri: str + verification_uri_complete: str + expires_in: int + interval: int + + +class DeviceTokenResponse(BaseModel): + access_token: str # This will be the user's API key + token_type: str = 'Bearer' + expires_in: Optional[int] = None # API keys may not have expiration + + +class DeviceTokenErrorResponse(BaseModel): + error: str + error_description: Optional[str] = None + interval: Optional[int] = None # Required for slow_down error + + +# --------------------------------------------------------------------------- +# Router + stores +# --------------------------------------------------------------------------- + +oauth_device_router = APIRouter(prefix='/oauth/device') +device_code_store = DeviceCodeStore(session_maker) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _oauth_error( + status_code: int, + error: str, + description: str, + interval: Optional[int] = None, +) -> JSONResponse: + """Return a JSON OAuth-style error response.""" + return JSONResponse( + status_code=status_code, + content=DeviceTokenErrorResponse( + error=error, + error_description=description, + interval=interval, + ).model_dump(), + ) + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@oauth_device_router.post('/authorize', response_model=DeviceAuthorizationResponse) +async def device_authorization( + http_request: Request, +) -> DeviceAuthorizationResponse: + """Start device flow by generating device and user codes.""" + try: + device_code_entry = device_code_store.create_device_code( + expires_in=DEVICE_CODE_EXPIRES_IN, + ) + + base_url = str(http_request.base_url).rstrip('/') + verification_uri = f'{base_url}/oauth/device/verify' + verification_uri_complete = ( + f'{verification_uri}?user_code={device_code_entry.user_code}' + ) + + logger.info( + 'Device authorization initiated', + extra={'user_code': device_code_entry.user_code}, + ) + + return DeviceAuthorizationResponse( + device_code=device_code_entry.device_code, + user_code=device_code_entry.user_code, + verification_uri=verification_uri, + verification_uri_complete=verification_uri_complete, + expires_in=DEVICE_CODE_EXPIRES_IN, + interval=device_code_entry.current_interval, + ) + except Exception as e: + logger.exception('Error in device authorization: %s', str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Internal server error', + ) from e + + +@oauth_device_router.post('/token') +async def device_token(device_code: str = Form(...)): + """Poll for a token until the user authorizes or the code expires.""" + try: + device_code_entry = device_code_store.get_by_device_code(device_code) + + if not device_code_entry: + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'invalid_grant', + 'Invalid device code', + ) + + # Check rate limiting (RFC 8628 section 3.5) + is_too_fast, current_interval = device_code_entry.check_rate_limit() + if is_too_fast: + # Update poll time and increase interval + device_code_store.update_poll_time(device_code, increase_interval=True) + logger.warning( + 'Client polling too fast, returning slow_down error', + extra={ + 'device_code': device_code[:8] + '...', # Log partial for privacy + 'new_interval': current_interval, + }, + ) + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'slow_down', + f'Polling too frequently. Wait at least {current_interval} seconds between requests.', + interval=current_interval, + ) + + # Update poll time for successful rate limit check + device_code_store.update_poll_time(device_code, increase_interval=False) + + if device_code_entry.is_expired(): + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'expired_token', + 'Device code has expired', + ) + + if device_code_entry.status == 'denied': + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'access_denied', + 'User denied the authorization request', + ) + + if device_code_entry.status == 'pending': + return _oauth_error( + status.HTTP_400_BAD_REQUEST, + 'authorization_pending', + 'User has not yet completed authorization', + ) + + if device_code_entry.status == 'authorized': + # Retrieve the specific API key for this device using the user_code + api_key_store = ApiKeyStore.get_instance() + device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})' + device_api_key = api_key_store.retrieve_api_key_by_name( + device_code_entry.keycloak_user_id, device_key_name + ) + + if not device_api_key: + logger.error( + 'No device API key found for authorized device', + extra={ + 'user_id': device_code_entry.keycloak_user_id, + 'user_code': device_code_entry.user_code, + }, + ) + return _oauth_error( + status.HTTP_500_INTERNAL_SERVER_ERROR, + 'server_error', + 'API key not found', + ) + + # Return the API key as access_token + return DeviceTokenResponse( + access_token=device_api_key, + ) + + # Fallback for unexpected status values + logger.error( + 'Unknown device code status', + extra={'status': device_code_entry.status}, + ) + return _oauth_error( + status.HTTP_500_INTERNAL_SERVER_ERROR, + 'server_error', + 'Unknown device code status', + ) + + except Exception as e: + logger.exception('Error in device token: %s', str(e)) + return _oauth_error( + status.HTTP_500_INTERNAL_SERVER_ERROR, + 'server_error', + 'Internal server error', + ) + + +@oauth_device_router.post('/verify-authenticated') +async def device_verification_authenticated( + user_code: str = Form(...), + user_id: str = Depends(get_user_id), +): + """Process device verification for authenticated users (called by frontend).""" + try: + if not user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Authentication required', + ) + + # Validate device code + device_code_entry = device_code_store.get_by_user_code(user_code) + if not device_code_entry: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='The device code is invalid or has expired.', + ) + + if not device_code_entry.is_pending(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='This device code has already been processed.', + ) + + # First, authorize the device code + success = device_code_store.authorize_device_code( + user_code=user_code, + user_id=user_id, + ) + + if not success: + logger.error( + 'Failed to authorize device code', + extra={'user_code': user_code, 'user_id': user_id}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to authorize the device. Please try again.', + ) + + # Only create API key AFTER successful authorization + api_key_store = ApiKeyStore.get_instance() + try: + # Create a unique API key for this device using user_code in the name + device_key_name = f'{API_KEY_NAME} ({user_code})' + api_key_store.create_api_key( + user_id, + name=device_key_name, + expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME, + ) + logger.info( + 'Created new device API key for user after successful authorization', + extra={'user_id': user_id, 'user_code': user_code}, + ) + except Exception as e: + logger.exception( + 'Failed to create device API key after authorization: %s', str(e) + ) + + # Clean up: revert the device authorization since API key creation failed + # This prevents the device from being in an authorized state without an API key + try: + device_code_store.deny_device_code(user_code) + logger.info( + 'Reverted device authorization due to API key creation failure', + extra={'user_code': user_code, 'user_id': user_id}, + ) + except Exception as cleanup_error: + logger.exception( + 'Failed to revert device authorization during cleanup: %s', + str(cleanup_error), + ) + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to create API key for device access.', + ) + + logger.info( + 'Device code authorized with API key successfully', + extra={'user_code': user_code, 'user_id': user_id}, + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Device authorized successfully!'}, + ) + + except HTTPException: + raise + except Exception as e: + logger.exception('Error in device verification: %s', str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='An unexpected error occurred. Please try again.', + ) diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index 162ed415c1..693bfdb321 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -57,9 +57,15 @@ class ApiKeyStore: return None # Check if the key has expired - if key_record.expires_at and key_record.expires_at < now: - logger.info(f'API key has expired: {key_record.id}') - return None + if key_record.expires_at: + # Handle timezone-naive datetime from database by assuming it's UTC + expires_at = key_record.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + + if expires_at < now: + logger.info(f'API key has expired: {key_record.id}') + return None # Update last_used_at timestamp session.execute( @@ -125,6 +131,33 @@ class ApiKeyStore: return None + def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None: + """Retrieve an API key by name for a specific user.""" + with self.session_maker() as session: + key_record = ( + session.query(ApiKey) + .filter(ApiKey.user_id == user_id, ApiKey.name == name) + .first() + ) + return key_record.key if key_record else None + + def delete_api_key_by_name(self, user_id: str, name: str) -> bool: + """Delete an API key by name for a specific user.""" + with self.session_maker() as session: + key_record = ( + session.query(ApiKey) + .filter(ApiKey.user_id == user_id, ApiKey.name == name) + .first() + ) + + if not key_record: + return False + + session.delete(key_record) + session.commit() + + return True + @classmethod def get_instance(cls) -> ApiKeyStore: """Get an instance of the ApiKeyStore.""" diff --git a/enterprise/storage/device_code.py b/enterprise/storage/device_code.py new file mode 100644 index 0000000000..47e18b51bc --- /dev/null +++ b/enterprise/storage/device_code.py @@ -0,0 +1,109 @@ +"""Device code storage model for OAuth 2.0 Device Flow.""" + +from datetime import datetime, timezone +from enum import Enum + +from sqlalchemy import Column, DateTime, Integer, String +from storage.base import Base + + +class DeviceCodeStatus(Enum): + """Status of a device code authorization request.""" + + PENDING = 'pending' + AUTHORIZED = 'authorized' + EXPIRED = 'expired' + DENIED = 'denied' + + +class DeviceCode(Base): + """Device code for OAuth 2.0 Device Flow. + + This stores the device codes issued during the device authorization flow, + along with their status and associated user information once authorized. + """ + + __tablename__ = 'device_codes' + + id = Column(Integer, primary_key=True, autoincrement=True) + device_code = Column(String(128), unique=True, nullable=False, index=True) + user_code = Column(String(16), unique=True, nullable=False, index=True) + status = Column(String(32), nullable=False, default=DeviceCodeStatus.PENDING.value) + + # Keycloak user ID who authorized the device (set during verification) + keycloak_user_id = Column(String(255), nullable=True) + + # Timestamps + expires_at = Column(DateTime(timezone=True), nullable=False) + authorized_at = Column(DateTime(timezone=True), nullable=True) + + # Rate limiting fields for RFC 8628 section 3.5 compliance + last_poll_time = Column(DateTime(timezone=True), nullable=True) + current_interval = Column(Integer, nullable=False, default=5) + + def __repr__(self) -> str: + return f"" + + def is_expired(self) -> bool: + """Check if the device code has expired.""" + now = datetime.now(timezone.utc) + return now > self.expires_at + + def is_pending(self) -> bool: + """Check if the device code is still pending authorization.""" + return self.status == DeviceCodeStatus.PENDING.value and not self.is_expired() + + def is_authorized(self) -> bool: + """Check if the device code has been authorized.""" + return self.status == DeviceCodeStatus.AUTHORIZED.value + + def authorize(self, user_id: str) -> None: + """Mark the device code as authorized.""" + self.status = DeviceCodeStatus.AUTHORIZED.value + self.keycloak_user_id = user_id # Set the Keycloak user ID during authorization + self.authorized_at = datetime.now(timezone.utc) + + def deny(self) -> None: + """Mark the device code as denied.""" + self.status = DeviceCodeStatus.DENIED.value + + def expire(self) -> None: + """Mark the device code as expired.""" + self.status = DeviceCodeStatus.EXPIRED.value + + def check_rate_limit(self) -> tuple[bool, int]: + """Check if the client is polling too fast. + + Returns: + tuple: (is_too_fast, current_interval) + - is_too_fast: True if client should receive slow_down error + - current_interval: Current polling interval to use + """ + now = datetime.now(timezone.utc) + + # If this is the first poll, allow it + if self.last_poll_time is None: + return False, self.current_interval + + # Calculate time since last poll + time_since_last_poll = (now - self.last_poll_time).total_seconds() + + # Check if polling too fast + if time_since_last_poll < self.current_interval: + # Increase interval for slow_down (RFC 8628 section 3.5) + new_interval = min(self.current_interval + 5, 60) # Cap at 60 seconds + return True, new_interval + + return False, self.current_interval + + def update_poll_time(self, increase_interval: bool = False) -> None: + """Update the last poll time and optionally increase the interval. + + Args: + increase_interval: If True, increase the current interval for slow_down + """ + self.last_poll_time = datetime.now(timezone.utc) + + if increase_interval: + # Increase interval by 5 seconds, cap at 60 seconds (RFC 8628) + self.current_interval = min(self.current_interval + 5, 60) diff --git a/enterprise/storage/device_code_store.py b/enterprise/storage/device_code_store.py new file mode 100644 index 0000000000..de2fe29cc4 --- /dev/null +++ b/enterprise/storage/device_code_store.py @@ -0,0 +1,167 @@ +"""Device code store for OAuth 2.0 Device Flow.""" + +import secrets +import string +from datetime import datetime, timedelta, timezone + +from sqlalchemy.exc import IntegrityError +from storage.device_code import DeviceCode + + +class DeviceCodeStore: + """Store for managing OAuth 2.0 device codes.""" + + def __init__(self, session_maker): + self.session_maker = session_maker + + def generate_user_code(self) -> str: + """Generate a human-readable user code (8 characters, uppercase letters and digits).""" + # Use a mix of uppercase letters and digits, avoiding confusing characters + alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ23456789' # No I, O, 0, 1 + return ''.join(secrets.choice(alphabet) for _ in range(8)) + + def generate_device_code(self) -> str: + """Generate a secure device code (128 characters).""" + alphabet = string.ascii_letters + string.digits + return ''.join(secrets.choice(alphabet) for _ in range(128)) + + def create_device_code( + self, + expires_in: int = 600, # 10 minutes default + max_attempts: int = 10, + ) -> DeviceCode: + """Create a new device code entry. + + Uses database constraints to ensure uniqueness, avoiding TOCTOU race conditions. + Retries on constraint violations until unique codes are generated. + + Args: + expires_in: Expiration time in seconds + max_attempts: Maximum number of attempts to generate unique codes + + Returns: + The created DeviceCode instance + + Raises: + RuntimeError: If unable to generate unique codes after max_attempts + """ + for attempt in range(max_attempts): + user_code = self.generate_user_code() + device_code = self.generate_device_code() + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + device_code_entry = DeviceCode( + device_code=device_code, + user_code=user_code, + keycloak_user_id=None, # Will be set during authorization + expires_at=expires_at, + ) + + try: + with self.session_maker() as session: + session.add(device_code_entry) + session.commit() + session.refresh(device_code_entry) + session.expunge(device_code_entry) # Detach from session cleanly + return device_code_entry + except IntegrityError: + # Constraint violation - codes already exist, retry with new codes + continue + + raise RuntimeError( + f'Failed to generate unique device codes after {max_attempts} attempts' + ) + + def get_by_device_code(self, device_code: str) -> DeviceCode | None: + """Get device code entry by device code.""" + with self.session_maker() as session: + result = ( + session.query(DeviceCode).filter_by(device_code=device_code).first() + ) + if result: + session.expunge(result) # Detach from session cleanly + return result + + def get_by_user_code(self, user_code: str) -> DeviceCode | None: + """Get device code entry by user code.""" + with self.session_maker() as session: + result = session.query(DeviceCode).filter_by(user_code=user_code).first() + if result: + session.expunge(result) # Detach from session cleanly + return result + + def authorize_device_code(self, user_code: str, user_id: str) -> bool: + """Authorize a device code. + + Args: + user_code: The user code to authorize + user_id: The user ID from Keycloak + + Returns: + True if authorization was successful, False otherwise + """ + with self.session_maker() as session: + device_code_entry = ( + session.query(DeviceCode).filter_by(user_code=user_code).first() + ) + + if not device_code_entry: + return False + + if not device_code_entry.is_pending(): + return False + + device_code_entry.authorize(user_id) + session.commit() + + return True + + def deny_device_code(self, user_code: str) -> bool: + """Deny a device code authorization. + + Args: + user_code: The user code to deny + + Returns: + True if denial was successful, False otherwise + """ + with self.session_maker() as session: + device_code_entry = ( + session.query(DeviceCode).filter_by(user_code=user_code).first() + ) + + if not device_code_entry: + return False + + if not device_code_entry.is_pending(): + return False + + device_code_entry.deny() + session.commit() + + return True + + def update_poll_time( + self, device_code: str, increase_interval: bool = False + ) -> bool: + """Update the poll time for a device code and optionally increase interval. + + Args: + device_code: The device code to update + increase_interval: If True, increase the polling interval for slow_down + + Returns: + True if update was successful, False otherwise + """ + with self.session_maker() as session: + device_code_entry = ( + session.query(DeviceCode).filter_by(device_code=device_code).first() + ) + + if not device_code_entry: + return False + + device_code_entry.update_poll_time(increase_interval) + session.commit() + + return True diff --git a/enterprise/tests/unit/conftest.py b/enterprise/tests/unit/conftest.py index 08516fd813..873f7b775f 100644 --- a/enterprise/tests/unit/conftest.py +++ b/enterprise/tests/unit/conftest.py @@ -12,6 +12,7 @@ from storage.base import Base # Anything not loaded here may not have a table created for it. from storage.billing_session import BillingSession from storage.conversation_work import ConversationWork +from storage.device_code import DeviceCode # noqa: F401 from storage.feedback import Feedback from storage.github_app_installation import GithubAppInstallation from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus diff --git a/enterprise/tests/unit/server/routes/test_oauth_device.py b/enterprise/tests/unit/server/routes/test_oauth_device.py new file mode 100644 index 0000000000..53682e65f0 --- /dev/null +++ b/enterprise/tests/unit/server/routes/test_oauth_device.py @@ -0,0 +1,610 @@ +"""Unit tests for OAuth2 Device Flow endpoints.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException, Request +from fastapi.responses import JSONResponse +from server.routes.oauth_device import ( + device_authorization, + device_token, + device_verification_authenticated, +) +from storage.device_code import DeviceCode + + +@pytest.fixture +def mock_device_code_store(): + """Mock device code store.""" + return MagicMock() + + +@pytest.fixture +def mock_api_key_store(): + """Mock API key store.""" + return MagicMock() + + +@pytest.fixture +def mock_token_manager(): + """Mock token manager.""" + return MagicMock() + + +@pytest.fixture +def mock_request(): + """Mock FastAPI request.""" + request = MagicMock(spec=Request) + request.base_url = 'https://test.example.com/' + return request + + +class TestDeviceAuthorization: + """Test device authorization endpoint.""" + + @patch('server.routes.oauth_device.device_code_store') + async def test_device_authorization_success(self, mock_store, mock_request): + """Test successful device authorization.""" + mock_device = DeviceCode( + device_code='test-device-code-123', + user_code='ABC12345', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + current_interval=5, # Default interval + ) + mock_store.create_device_code.return_value = mock_device + + result = await device_authorization(mock_request) + + assert result.device_code == 'test-device-code-123' + assert result.user_code == 'ABC12345' + assert result.expires_in == 600 + assert result.interval == 5 # Should match device's current_interval + assert 'verify' in result.verification_uri + assert 'ABC12345' in result.verification_uri_complete + + @patch('server.routes.oauth_device.device_code_store') + async def test_device_authorization_with_increased_interval( + self, mock_store, mock_request + ): + """Test device authorization returns increased interval from rate limiting.""" + mock_device = DeviceCode( + device_code='test-device-code-456', + user_code='XYZ98765', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + current_interval=15, # Increased interval from previous rate limiting + ) + mock_store.create_device_code.return_value = mock_device + + result = await device_authorization(mock_request) + + assert result.device_code == 'test-device-code-456' + assert result.user_code == 'XYZ98765' + assert result.expires_in == 600 + assert result.interval == 15 # Should match device's increased current_interval + assert 'verify' in result.verification_uri + assert 'XYZ98765' in result.verification_uri_complete + + +class TestDeviceToken: + """Test device token endpoint.""" + + @pytest.mark.parametrize( + 'device_exists,status,expected_error', + [ + (False, None, 'invalid_grant'), + (True, 'expired', 'expired_token'), + (True, 'denied', 'access_denied'), + (True, 'pending', 'authorization_pending'), + ], + ) + @patch('server.routes.oauth_device.device_code_store') + async def test_device_token_error_cases( + self, mock_store, device_exists, status, expected_error + ): + """Test various error cases for device token endpoint.""" + device_code = 'test-device-code' + + if device_exists: + mock_device = MagicMock() + mock_device.is_expired.return_value = status == 'expired' + mock_device.status = status + # Mock rate limiting - return False (not too fast) and default interval + mock_device.check_rate_limit.return_value = (False, 5) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + else: + mock_store.get_by_device_code.return_value = None + + result = await device_token(device_code=device_code) + + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + # Check error in response content + content = result.body.decode() + assert expected_error in content + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_device_token_success(self, mock_store, mock_api_key_class): + """Test successful device token retrieval.""" + device_code = 'test-device-code' + + # Mock authorized device + mock_device = MagicMock() + mock_device.is_expired.return_value = False + mock_device.status = 'authorized' + mock_device.keycloak_user_id = 'user-123' + mock_device.user_code = ( + 'ABC12345' # Add user_code for device-specific API key lookup + ) + # Mock rate limiting - return False (not too fast) and default interval + mock_device.check_rate_limit.return_value = (False, 5) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + # Mock API key retrieval + mock_api_key_store = MagicMock() + mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key' + mock_api_key_class.get_instance.return_value = mock_api_key_store + + result = await device_token(device_code=device_code) + + # Check that result is a DeviceTokenResponse + assert result.access_token == 'test-api-key' + assert result.token_type == 'Bearer' + + # Verify that the correct device-specific API key name was used + mock_api_key_store.retrieve_api_key_by_name.assert_called_once_with( + 'user-123', 'Device Link Access Key (ABC12345)' + ) + + +class TestDeviceVerificationAuthenticated: + """Test device verification authenticated endpoint.""" + + async def test_verification_unauthenticated_user(self): + """Test verification with unauthenticated user.""" + with pytest.raises(HTTPException): + await device_verification_authenticated(user_code='ABC12345', user_id=None) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_verification_invalid_device_code( + self, mock_store, mock_api_key_class + ): + """Test verification with invalid device code.""" + mock_store.get_by_user_code.return_value = None + + with pytest.raises(HTTPException): + await device_verification_authenticated( + user_code='INVALID', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_verification_already_processed(self, mock_store, mock_api_key_class): + """Test verification with already processed device code.""" + mock_device = MagicMock() + mock_device.is_pending.return_value = False + mock_store.get_by_user_code.return_value = mock_device + + with pytest.raises(HTTPException): + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_verification_success(self, mock_store, mock_api_key_class): + """Test successful device verification.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True + + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + result = await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert isinstance(result, JSONResponse) + assert result.status_code == 200 + # Should NOT delete existing API keys (multiple devices allowed) + mock_api_key_store.delete_api_key_by_name.assert_not_called() + # Should create a new API key with device-specific name + mock_api_key_store.create_api_key.assert_called_once() + call_args = mock_api_key_store.create_api_key.call_args + assert call_args[1]['name'] == 'Device Link Access Key (ABC12345)' + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_multiple_device_authentication(self, mock_store, mock_api_key_class): + """Test that multiple devices can authenticate simultaneously.""" + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Simulate two different devices + device1_code = 'ABC12345' + device2_code = 'XYZ67890' + user_id = 'user-123' + + # Mock device codes + mock_device1 = MagicMock() + mock_device1.is_pending.return_value = True + mock_device2 = MagicMock() + mock_device2.is_pending.return_value = True + + # Configure mock store to return appropriate device for each user_code + def get_by_user_code_side_effect(user_code): + if user_code == device1_code: + return mock_device1 + elif user_code == device2_code: + return mock_device2 + return None + + mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect + mock_store.authorize_device_code.return_value = True + + # Authenticate first device + result1 = await device_verification_authenticated( + user_code=device1_code, user_id=user_id + ) + + # Authenticate second device + result2 = await device_verification_authenticated( + user_code=device2_code, user_id=user_id + ) + + # Both should succeed + assert isinstance(result1, JSONResponse) + assert result1.status_code == 200 + assert isinstance(result2, JSONResponse) + assert result2.status_code == 200 + + # Should create two separate API keys with different names + assert mock_api_key_store.create_api_key.call_count == 2 + + # Check that each device got a unique API key name + call_args_list = mock_api_key_store.create_api_key.call_args_list + device1_name = call_args_list[0][1]['name'] + device2_name = call_args_list[1][1]['name'] + + assert device1_name == f'Device Link Access Key ({device1_code})' + assert device2_name == f'Device Link Access Key ({device2_code})' + assert device1_name != device2_name # Ensure they're different + + # Should NOT delete any existing API keys + mock_api_key_store.delete_api_key_by_name.assert_not_called() + + +class TestDeviceTokenRateLimiting: + """Test rate limiting for device token polling (RFC 8628 section 3.5).""" + + @patch('server.routes.oauth_device.device_code_store') + async def test_first_poll_allowed(self, mock_store): + """Test that the first poll is always allowed.""" + # Create a device code with no previous poll time + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=None, # First poll + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return authorization_pending, not slow_down + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'authorization_pending' in content + assert 'slow_down' not in content + + # Should update poll time without increasing interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=False + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_normal_polling_allowed(self, mock_store): + """Test that normal polling (respecting interval) is allowed.""" + # Create a device code with last poll time 6 seconds ago (interval is 5) + last_poll = datetime.now(UTC) - timedelta(seconds=6) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return authorization_pending, not slow_down + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'authorization_pending' in content + assert 'slow_down' not in content + + # Should update poll time without increasing interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=False + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_fast_polling_returns_slow_down(self, mock_store): + """Test that polling too fast returns slow_down error.""" + # Create a device code with last poll time 2 seconds ago (interval is 5) + last_poll = datetime.now(UTC) - timedelta(seconds=2) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return slow_down error + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + assert 'interval' in content + assert '10' in content # New interval should be 5 + 5 = 10 + + # Should update poll time and increase interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=True + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_interval_increases_with_repeated_fast_polling(self, mock_store): + """Test that interval increases with repeated fast polling.""" + # Create a device code with higher current interval from previous slow_down + last_poll = datetime.now(UTC) - timedelta(seconds=5) # 5 seconds ago + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=15, # Already increased from previous slow_down + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return slow_down error with increased interval + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + assert '20' in content # New interval should be 15 + 5 = 20 + + # Should update poll time and increase interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=True + ) + + @patch('server.routes.oauth_device.device_code_store') + async def test_interval_caps_at_maximum(self, mock_store): + """Test that interval is capped at maximum value.""" + # Create a device code with interval near maximum + last_poll = datetime.now(UTC) - timedelta(seconds=30) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='pending', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=58, # Near maximum of 60 + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should return slow_down error with capped interval + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + assert '60' in content # Should be capped at 60, not 63 + + @patch('server.routes.oauth_device.device_code_store') + async def test_rate_limiting_with_authorized_device(self, mock_store): + """Test that rate limiting still applies to authorized devices.""" + # Create an authorized device code with recent poll + last_poll = datetime.now(UTC) - timedelta(seconds=2) + mock_device = DeviceCode( + device_code='test_device_code', + user_code='ABC123', + status='authorized', # Device is authorized + keycloak_user_id='user123', + expires_at=datetime.now(UTC) + timedelta(minutes=10), + last_poll_time=last_poll, + current_interval=5, + ) + mock_store.get_by_device_code.return_value = mock_device + mock_store.update_poll_time.return_value = True + + device_code = 'test_device_code' + result = await device_token(device_code=device_code) + + # Should still return slow_down error even for authorized device + assert isinstance(result, JSONResponse) + assert result.status_code == 400 + content = result.body.decode() + assert 'slow_down' in content + + # Should update poll time and increase interval + mock_store.update_poll_time.assert_called_with( + 'test_device_code', increase_interval=True + ) + + +class TestDeviceVerificationTransactionIntegrity: + """Test transaction integrity for device verification to prevent orphaned API keys.""" + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_authorization_failure_prevents_api_key_creation( + self, mock_store, mock_api_key_class + ): + """Test that if device authorization fails, no API key is created.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = False # Authorization fails + + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Should raise HTTPException due to authorization failure + with pytest.raises(HTTPException) as exc_info: + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to authorize the device' in exc_info.value.detail + + # API key should NOT be created since authorization failed + mock_api_key_store.create_api_key.assert_not_called() + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_api_key_creation_failure_reverts_authorization( + self, mock_store, mock_api_key_class + ): + """Test that if API key creation fails after authorization, the authorization is reverted.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True # Authorization succeeds + mock_store.deny_device_code.return_value = True # Cleanup succeeds + + # Mock API key store to fail on creation + mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key.side_effect = Exception('Database error') + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Should raise HTTPException due to API key creation failure + with pytest.raises(HTTPException) as exc_info: + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to create API key for device access' in exc_info.value.detail + + # Authorization should have been attempted first + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + + # API key creation should have been attempted after authorization + mock_api_key_store.create_api_key.assert_called_once() + + # Authorization should be reverted due to API key creation failure + mock_store.deny_device_code.assert_called_once_with('ABC12345') + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_api_key_creation_failure_cleanup_failure_logged( + self, mock_store, mock_api_key_class + ): + """Test that cleanup failure is logged but doesn't prevent the main error from being raised.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True # Authorization succeeds + mock_store.deny_device_code.side_effect = Exception( + 'Cleanup failed' + ) # Cleanup fails + + # Mock API key store to fail on creation + mock_api_key_store = MagicMock() + mock_api_key_store.create_api_key.side_effect = Exception('Database error') + mock_api_key_class.get_instance.return_value = mock_api_key_store + + # Should still raise HTTPException for the original API key creation failure + with pytest.raises(HTTPException) as exc_info: + await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to create API key for device access' in exc_info.value.detail + + # Both operations should have been attempted + mock_store.authorize_device_code.assert_called_once() + mock_api_key_store.create_api_key.assert_called_once() + mock_store.deny_device_code.assert_called_once_with('ABC12345') + + @patch('server.routes.oauth_device.ApiKeyStore') + @patch('server.routes.oauth_device.device_code_store') + async def test_successful_flow_creates_api_key_after_authorization( + self, mock_store, mock_api_key_class + ): + """Test that in the successful flow, API key is created only after authorization.""" + # Mock device code + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_store.get_by_user_code.return_value = mock_device + mock_store.authorize_device_code.return_value = True # Authorization succeeds + + # Mock API key store + mock_api_key_store = MagicMock() + mock_api_key_class.get_instance.return_value = mock_api_key_store + + result = await device_verification_authenticated( + user_code='ABC12345', user_id='user-123' + ) + + assert isinstance(result, JSONResponse) + assert result.status_code == 200 + + # Verify the order: authorization first, then API key creation + mock_store.authorize_device_code.assert_called_once_with( + user_code='ABC12345', user_id='user-123' + ) + mock_api_key_store.create_api_key.assert_called_once() + + # No cleanup should be needed in successful case + mock_store.deny_device_code.assert_not_called() diff --git a/enterprise/tests/unit/storage/test_device_code.py b/enterprise/tests/unit/storage/test_device_code.py new file mode 100644 index 0000000000..0d2193075b --- /dev/null +++ b/enterprise/tests/unit/storage/test_device_code.py @@ -0,0 +1,83 @@ +"""Unit tests for DeviceCode model.""" + +from datetime import datetime, timedelta, timezone + +import pytest +from storage.device_code import DeviceCode, DeviceCodeStatus + + +class TestDeviceCode: + """Test cases for DeviceCode model.""" + + @pytest.fixture + def device_code(self): + """Create a test device code.""" + return DeviceCode( + device_code='test-device-code-123', + user_code='ABC12345', + expires_at=datetime.now(timezone.utc) + timedelta(minutes=10), + ) + + @pytest.mark.parametrize( + 'expires_delta,expected', + [ + (timedelta(minutes=5), False), # Future expiry + (timedelta(minutes=-5), True), # Past expiry + (timedelta(seconds=1), False), # Just future (not expired) + ], + ) + def test_is_expired(self, expires_delta, expected): + """Test expiration check with various time deltas.""" + device_code = DeviceCode( + device_code='test-device-code', + user_code='ABC12345', + expires_at=datetime.now(timezone.utc) + expires_delta, + ) + assert device_code.is_expired() == expected + + @pytest.mark.parametrize( + 'status,expired,expected', + [ + (DeviceCodeStatus.PENDING.value, False, True), + (DeviceCodeStatus.PENDING.value, True, False), + (DeviceCodeStatus.AUTHORIZED.value, False, False), + (DeviceCodeStatus.DENIED.value, False, False), + ], + ) + def test_is_pending(self, status, expired, expected): + """Test pending status check.""" + expires_at = ( + datetime.now(timezone.utc) - timedelta(minutes=1) + if expired + else datetime.now(timezone.utc) + timedelta(minutes=10) + ) + device_code = DeviceCode( + device_code='test-device-code', + user_code='ABC12345', + status=status, + expires_at=expires_at, + ) + assert device_code.is_pending() == expected + + def test_authorize(self, device_code): + """Test device authorization.""" + user_id = 'test-user-123' + + device_code.authorize(user_id) + + assert device_code.status == DeviceCodeStatus.AUTHORIZED.value + assert device_code.keycloak_user_id == user_id + assert device_code.authorized_at is not None + assert isinstance(device_code.authorized_at, datetime) + + @pytest.mark.parametrize( + 'method,expected_status', + [ + ('deny', DeviceCodeStatus.DENIED.value), + ('expire', DeviceCodeStatus.EXPIRED.value), + ], + ) + def test_status_changes(self, device_code, method, expected_status): + """Test status change methods.""" + getattr(device_code, method)() + assert device_code.status == expected_status diff --git a/enterprise/tests/unit/storage/test_device_code_store.py b/enterprise/tests/unit/storage/test_device_code_store.py new file mode 100644 index 0000000000..65a58cda8a --- /dev/null +++ b/enterprise/tests/unit/storage/test_device_code_store.py @@ -0,0 +1,193 @@ +"""Unit tests for DeviceCodeStore.""" + +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.exc import IntegrityError +from storage.device_code import DeviceCode +from storage.device_code_store import DeviceCodeStore + + +@pytest.fixture +def mock_session(): + """Mock database session.""" + session = MagicMock() + return session + + +@pytest.fixture +def mock_session_maker(mock_session): + """Mock session maker.""" + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = mock_session + session_maker.return_value.__exit__.return_value = None + return session_maker + + +@pytest.fixture +def device_code_store(mock_session_maker): + """Create DeviceCodeStore instance.""" + return DeviceCodeStore(mock_session_maker) + + +class TestDeviceCodeStore: + """Test cases for DeviceCodeStore.""" + + def test_generate_user_code(self, device_code_store): + """Test user code generation.""" + code = device_code_store.generate_user_code() + + assert len(code) == 8 + assert code.isupper() + # Should not contain confusing characters + assert not any(char in code for char in 'IO01') + + def test_generate_device_code(self, device_code_store): + """Test device code generation.""" + code = device_code_store.generate_device_code() + + assert len(code) == 128 + assert code.isalnum() + + def test_create_device_code_success(self, device_code_store, mock_session): + """Test successful device code creation.""" + # Mock successful creation (no IntegrityError) + mock_device_code = MagicMock(spec=DeviceCode) + mock_device_code.device_code = 'test-device-code-123' + mock_device_code.user_code = 'TESTCODE' + + # Mock the session to return our mock device code after refresh + def mock_refresh(obj): + obj.device_code = mock_device_code.device_code + obj.user_code = mock_device_code.user_code + + mock_session.refresh.side_effect = mock_refresh + + result = device_code_store.create_device_code(expires_in=600) + + assert isinstance(result, DeviceCode) + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + mock_session.expunge.assert_called_once() + + def test_create_device_code_with_retries( + self, device_code_store, mock_session_maker + ): + """Test device code creation with constraint violation retries.""" + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_session_maker.return_value.__exit__.return_value = None + + # First attempt fails with IntegrityError, second succeeds + mock_session.commit.side_effect = [IntegrityError('', '', ''), None] + + mock_device_code = MagicMock(spec=DeviceCode) + mock_device_code.device_code = 'test-device-code-456' + mock_device_code.user_code = 'TESTCD2' + + def mock_refresh(obj): + obj.device_code = mock_device_code.device_code + obj.user_code = mock_device_code.user_code + + mock_session.refresh.side_effect = mock_refresh + + store = DeviceCodeStore(mock_session_maker) + result = store.create_device_code(expires_in=600) + + assert isinstance(result, DeviceCode) + assert mock_session.add.call_count == 2 # Two attempts + assert mock_session.commit.call_count == 2 # Two attempts + + def test_create_device_code_max_attempts_exceeded( + self, device_code_store, mock_session_maker + ): + """Test device code creation failure after max attempts.""" + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_session_maker.return_value.__exit__.return_value = None + + # All attempts fail with IntegrityError + mock_session.commit.side_effect = IntegrityError('', '', '') + + store = DeviceCodeStore(mock_session_maker) + + with pytest.raises( + RuntimeError, + match='Failed to generate unique device codes after 3 attempts', + ): + store.create_device_code(expires_in=600, max_attempts=3) + + @pytest.mark.parametrize( + 'lookup_method,lookup_field', + [ + ('get_by_device_code', 'device_code'), + ('get_by_user_code', 'user_code'), + ], + ) + def test_lookup_methods( + self, device_code_store, mock_session, lookup_method, lookup_field + ): + """Test device code lookup methods.""" + test_code = 'test-code-123' + mock_device_code = MagicMock() + mock_session.query.return_value.filter_by.return_value.first.return_value = ( + mock_device_code + ) + + result = getattr(device_code_store, lookup_method)(test_code) + + assert result == mock_device_code + mock_session.query.assert_called_once_with(DeviceCode) + mock_session.query.return_value.filter_by.assert_called_once_with( + **{lookup_field: test_code} + ) + + @pytest.mark.parametrize( + 'device_exists,is_pending,expected_result', + [ + (True, True, True), # Success case + (False, True, False), # Device not found + (True, False, False), # Device not pending + ], + ) + def test_authorize_device_code( + self, + device_code_store, + mock_session, + device_exists, + is_pending, + expected_result, + ): + """Test device code authorization.""" + user_code = 'ABC12345' + user_id = 'test-user-123' + + if device_exists: + mock_device = MagicMock() + mock_device.is_pending.return_value = is_pending + mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device + else: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + result = device_code_store.authorize_device_code(user_code, user_id) + + assert result == expected_result + if expected_result: + mock_device.authorize.assert_called_once_with(user_id) + mock_session.commit.assert_called_once() + + def test_deny_device_code(self, device_code_store, mock_session): + """Test device code denial.""" + user_code = 'ABC12345' + mock_device = MagicMock() + mock_device.is_pending.return_value = True + mock_session.query.return_value.filter_by.return_value.first.return_value = ( + mock_device + ) + + result = device_code_store.deny_device_code(user_code) + + assert result is True + mock_device.deny.assert_called_once() + mock_session.commit.assert_called_once() diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index ea386cb69c..c1c6a98f3d 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -90,6 +90,50 @@ def test_validate_api_key_expired(api_key_store, mock_session): mock_session.commit.assert_not_called() +def test_validate_api_key_expired_timezone_naive(api_key_store, mock_session): + """Test validating an expired API key with timezone-naive datetime from database.""" + # Setup + api_key = 'test-api-key' + mock_key_record = MagicMock() + # Simulate timezone-naive datetime as returned from database + mock_key_record.expires_at = datetime.now() - timedelta(days=1) # No UTC timezone + mock_key_record.id = 1 + mock_session.query.return_value.filter.return_value.first.return_value = ( + mock_key_record + ) + + # Execute + result = api_key_store.validate_api_key(api_key) + + # Verify + assert result is None + mock_session.execute.assert_not_called() + mock_session.commit.assert_not_called() + + +def test_validate_api_key_valid_timezone_naive(api_key_store, mock_session): + """Test validating a valid API key with timezone-naive datetime from database.""" + # Setup + api_key = 'test-api-key' + user_id = 'test-user-123' + mock_key_record = MagicMock() + mock_key_record.user_id = user_id + # Simulate timezone-naive datetime as returned from database (future date) + mock_key_record.expires_at = datetime.now() + timedelta(days=1) # No UTC timezone + mock_key_record.id = 1 + mock_session.query.return_value.filter.return_value.first.return_value = ( + mock_key_record + ) + + # Execute + result = api_key_store.validate_api_key(api_key) + + # Verify + assert result == user_id + mock_session.execute.assert_called_once() + mock_session.commit.assert_called_once() + + def test_validate_api_key_not_found(api_key_store, mock_session): """Test validating a non-existent API key.""" # Setup diff --git a/frontend/src/routes.ts b/frontend/src/routes.ts index 4c3c48adc5..ecee511688 100644 --- a/frontend/src/routes.ts +++ b/frontend/src/routes.ts @@ -21,5 +21,6 @@ export default [ ]), route("conversations/:conversationId", "routes/conversation.tsx"), route("microagent-management", "routes/microagent-management.tsx"), + route("oauth/device/verify", "routes/device-verify.tsx"), ]), ] satisfies RouteConfig; diff --git a/frontend/src/routes/device-verify.tsx b/frontend/src/routes/device-verify.tsx new file mode 100644 index 0000000000..f306d660a5 --- /dev/null +++ b/frontend/src/routes/device-verify.tsx @@ -0,0 +1,274 @@ +/* eslint-disable i18next/no-literal-string */ +import React, { useState } from "react"; +import { useSearchParams } from "react-router"; +import { useIsAuthed } from "#/hooks/query/use-is-authed"; + +export default function DeviceVerify() { + const [searchParams] = useSearchParams(); + const { data: isAuthed, isLoading: isAuthLoading } = useIsAuthed(); + const [verificationResult, setVerificationResult] = useState<{ + success: boolean; + message: string; + } | null>(null); + const [isProcessing, setIsProcessing] = useState(false); + + // Get user_code from URL parameters + const userCode = searchParams.get("user_code"); + + const processDeviceVerification = async (code: string) => { + try { + setIsProcessing(true); + + // Call the backend API endpoint to process device verification + const response = await fetch("/oauth/device/verify-authenticated", { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: `user_code=${encodeURIComponent(code)}`, + credentials: "include", // Include cookies for authentication + }); + + if (response.ok) { + // Show success message + setVerificationResult({ + success: true, + message: + "Device authorized successfully! You can now return to your CLI and close this window.", + }); + } else { + const errorText = await response.text(); + setVerificationResult({ + success: false, + message: errorText || "Failed to authorize device. Please try again.", + }); + } + } catch (error) { + setVerificationResult({ + success: false, + message: + "An error occurred while authorizing the device. Please try again.", + }); + } finally { + setIsProcessing(false); + } + }; + + // Remove automatic verification - require explicit user consent + + const handleManualSubmit = (event: React.FormEvent) => { + event.preventDefault(); + const formData = new FormData(event.currentTarget); + const code = formData.get("user_code") as string; + if (code && isAuthed) { + processDeviceVerification(code); + } + }; + + // Show verification result if we have one + if (verificationResult) { + return ( +
+
+
+
+ {verificationResult.success ? ( + + + + ) : ( + + + + )} +
+

+ {verificationResult.success ? "Success!" : "Error"} +

+

+ {verificationResult.message} +

+ {!verificationResult.success && ( + + )} +
+
+
+ ); + } + + // Show processing state + if (isProcessing) { + return ( +
+
+
+
+

+ Processing device verification... +

+
+
+
+ ); + } + + // Show device authorization confirmation if user is authenticated and code is provided + if (isAuthed && userCode) { + return ( +
+
+

+ Device Authorization Request +

+
+

Device Code:

+

+ {userCode} +

+
+
+
+ + + +
+

+ Security Notice +

+

+ Only authorize this device if you initiated this request from + your CLI or application. +

+
+
+
+

+ Do you want to authorize this device to access your OpenHands + account? +

+
+ + +
+
+
+ ); + } + + // Show manual code entry form if no code in URL but user is authenticated + if (isAuthed && !userCode) { + return ( +
+
+

+ Device Authorization +

+

+ Enter the code displayed on your device: +

+
+
+ + +
+ +
+
+
+ ); + } + + // Show loading state while checking authentication + if (isAuthLoading) { + return ( +
+
+
+

+ Processing device verification... +

+
+
+ ); + } + + // Show authentication required message (this will trigger the auth modal via root layout) + return ( +
+
+

Authentication Required

+

+ Please sign in to authorize your device. +

+
+
+ ); +}