mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Replace legacy session.query().filter().limit().all() pattern with the modern select().where().limit() + execute().scalars().all() pattern, which is more idiomatic and consistent with other parts of the codebase (e.g., gitlab_webhook_store.py). Co-authored-by: openhands <openhands@all-hands.dev>
207 lines
7.0 KiB
Python
207 lines
7.0 KiB
Python
"""Device code store for OAuth 2.0 Device Flow."""
|
|
|
|
import secrets
|
|
import string
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from storage.device_code import DeviceCode
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
|
|
|
|
class DeviceCodeStore:
    """Store for managing OAuth 2.0 device codes.

    Every public method opens its own short-lived session via
    ``session_maker`` and closes it before returning. Entities handed back
    to the caller are expunged first so they remain usable after the
    session is gone.
    """

    def __init__(self, session_maker):
        """Remember the session factory.

        Args:
            session_maker: Zero-argument callable returning a session that
                works as a context manager (e.g. a SQLAlchemy
                ``sessionmaker``).
        """
        self.session_maker = session_maker

    def generate_user_code(self) -> str:
        """Generate a human-readable user code (8 characters, uppercase letters and digits)."""
        # Use a mix of uppercase letters and digits, avoiding confusing characters
        alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ23456789'  # No I, O, 0, 1
        return ''.join(secrets.choice(alphabet) for _ in range(8))

    def generate_device_code(self) -> str:
        """Generate a secure device code (128 characters, ASCII letters and digits)."""
        alphabet = string.ascii_letters + string.digits
        return ''.join(secrets.choice(alphabet) for _ in range(128))

    def create_device_code(
        self,
        expires_in: int = 600,  # 10 minutes default
        max_attempts: int = 10,
    ) -> DeviceCode:
        """Create a new device code entry.

        Uses database constraints to ensure uniqueness, avoiding TOCTOU race conditions.
        Retries on constraint violations until unique codes are generated.

        Args:
            expires_in: Expiration time in seconds
            max_attempts: Maximum number of attempts to generate unique codes

        Returns:
            The created DeviceCode instance, detached from its session

        Raises:
            RuntimeError: If unable to generate unique codes after max_attempts.
                The last IntegrityError (if any) is attached as ``__cause__``.
        """
        last_error = None

        for attempt in range(1, max_attempts + 1):
            user_code = self.generate_user_code()
            device_code = self.generate_device_code()
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)

            device_code_entry = DeviceCode(
                device_code=device_code,
                user_code=user_code,
                keycloak_user_id=None,  # Will be set during authorization
                expires_at=expires_at,
            )

            try:
                with self.session_maker() as session:
                    session.add(device_code_entry)
                    session.commit()
                    # Reload database-populated attributes before detaching
                    session.refresh(device_code_entry)
                    session.expunge(device_code_entry)  # Detach from session cleanly
                    return device_code_entry
            except IntegrityError as exc:
                # Constraint violation - codes already exist, retry with new codes.
                # Keep the error so that a persistent, non-collision failure is
                # diagnosable instead of being silently retried. The exception
                # text is deliberately not logged here: it may embed the
                # statement parameters, i.e. the generated codes.
                last_error = exc
                logger.warning(
                    'IntegrityError while creating device code '
                    '(attempt %d/%d); retrying with new codes',
                    attempt,
                    max_attempts,
                )
                continue

        raise RuntimeError(
            f'Failed to generate unique device codes after {max_attempts} attempts'
        ) from last_error

    def _get_detached(self, **criteria) -> DeviceCode | None:
        """Return the first DeviceCode matching ``criteria``, detached, or None."""
        with self.session_maker() as session:
            result = session.query(DeviceCode).filter_by(**criteria).first()
            if result:
                session.expunge(result)  # Detach from session cleanly
            return result

    def get_by_device_code(self, device_code: str) -> DeviceCode | None:
        """Get device code entry by device code.

        Returns:
            The matching entry detached from its session, or None if absent.
        """
        return self._get_detached(device_code=device_code)

    def get_by_user_code(self, user_code: str) -> DeviceCode | None:
        """Get device code entry by user code.

        Returns:
            The matching entry detached from its session, or None if absent.
        """
        return self._get_detached(user_code=user_code)

    def authorize_device_code(self, user_code: str, user_id: str) -> bool:
        """Authorize a device code.

        Args:
            user_code: The user code to authorize
            user_id: The user ID from Keycloak

        Returns:
            True if authorization was successful, False otherwise
            (no entry with that user code, or the entry is not pending)
        """
        with self.session_maker() as session:
            device_code_entry = (
                session.query(DeviceCode).filter_by(user_code=user_code).first()
            )

            # Only an existing entry that reports itself as pending may be
            # transitioned; anything else is rejected without writing.
            if not device_code_entry or not device_code_entry.is_pending():
                return False

            device_code_entry.authorize(user_id)
            session.commit()
            return True

    def deny_device_code(self, user_code: str) -> bool:
        """Deny a device code authorization.

        Args:
            user_code: The user code to deny

        Returns:
            True if denial was successful, False otherwise
            (no entry with that user code, or the entry is not pending)
        """
        with self.session_maker() as session:
            device_code_entry = (
                session.query(DeviceCode).filter_by(user_code=user_code).first()
            )

            # Same guard as authorization: unknown or non-pending entries are
            # left untouched.
            if not device_code_entry or not device_code_entry.is_pending():
                return False

            device_code_entry.deny()
            session.commit()
            return True

    def update_poll_time(
        self, device_code: str, increase_interval: bool = False
    ) -> bool:
        """Update the poll time for a device code and optionally increase interval.

        Args:
            device_code: The device code to update
            increase_interval: If True, increase the polling interval for slow_down

        Returns:
            True if update was successful, False if no entry matches device_code
        """
        with self.session_maker() as session:
            device_code_entry = (
                session.query(DeviceCode).filter_by(device_code=device_code).first()
            )

            if not device_code_entry:
                return False

            # The actual bookkeeping lives on the entity; this method only
            # loads the row and persists the change.
            device_code_entry.update_poll_time(increase_interval)
            session.commit()
            return True

    def cleanup_stale_device_codes(self, limit: int = 100) -> int:
        """Clean up expired device codes, oldest first.

        Removes device codes that are expired (past their expires_at time).

        Args:
            limit: Maximum number of codes to delete

        Returns:
            Total number of device codes deleted
        """
        with self.session_maker() as session:
            # Get expired device codes, ordered by oldest first (using ID as proxy for creation order)
            query = (
                select(DeviceCode)
                .where(DeviceCode.expires_at < datetime.now(timezone.utc))
                .order_by(DeviceCode.id.asc())
                .limit(limit)
            )
            expired_codes = session.execute(query).scalars().all()

            if not expired_codes:
                logger.info('No expired device codes found')
                return 0

            # Delete the expired codes by primary key so that exactly the
            # batch selected above (and nothing newer) is removed.
            code_ids = [code.id for code in expired_codes]
            delete_stmt = delete(DeviceCode).where(DeviceCode.id.in_(code_ids))
            result = session.execute(delete_stmt)
            session.commit()

            deleted_count = result.rowcount
            logger.info(f'Deleted {deleted_count} expired device codes')
            return deleted_count
|