Tim O'Farrell f292f3a84d
V1 Integration (#11183)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2025-10-14 02:16:44 +00:00

249 lines
8.1 KiB
Python

import hashlib
import json
from datetime import timedelta
from pathlib import Path
from typing import Any, AsyncGenerator
import jwt
from fastapi import Request
from jose import jwe
from jose.constants import ALGORITHMS
from pydantic import BaseModel, PrivateAttr
from openhands.agent_server.utils import utc_now
from openhands.app_server.services.injector import Injector, InjectorState
from openhands.app_server.utils.encryption_key import (
EncryptionKey,
get_default_encryption_keys,
)
class JwtService:
    """Service for signing/verifying JWS tokens and encrypting/decrypting JWE tokens.

    Every key handed to the constructor (active or not) is registered for
    lookup by ID. The most recently created *active* key becomes the default
    used whenever a caller does not name a key explicitly.
    """

    def __init__(self, keys: list[EncryptionKey]):
        """Initialize the JWT service with a list of keys.

        Args:
            keys: EncryptionKey objects to register. At least one must be
                active.

        Raises:
            ValueError: If none of the supplied keys is active.
        """
        active_keys = [key for key in keys if key.active]
        if not active_keys:
            raise ValueError('At least one active key is required')

        # Store all keys (including inactive ones) by ID for quick lookup.
        self._keys = {key.id: key for key in keys}

        # The newest active key is the default.
        newest_key = max(active_keys, key=lambda k: k.created_at)
        self._default_key_id = newest_key.id

    @property
    def default_key_id(self) -> str:
        """Get the default key ID."""
        return self._default_key_id

    def _get_secret(self, key_id: str) -> Any:
        """Return the raw secret registered under ``key_id``.

        Raises:
            ValueError: If ``key_id`` is not a registered key ID.
        """
        if key_id not in self._keys:
            raise ValueError(f"Key ID '{key_id}' not found")
        return self._keys[key_id].key.get_secret_value()

    def _derive_jwe_key(self, key_id: str) -> bytes:
        """Return the SHA-256 digest of the secret for ``key_id``.

        The digest is always 32 bytes, which is the key size A256GCM needs
        regardless of how long the stored secret is.

        Raises:
            ValueError: If ``key_id`` is not a registered key ID.
        """
        secret_key = self._get_secret(key_id)
        key_bytes = secret_key.encode() if isinstance(secret_key, str) else secret_key
        return hashlib.sha256(key_bytes).digest()

    @staticmethod
    def _with_standard_claims(
        payload: dict[str, Any], expires_in: timedelta | None
    ) -> dict[str, Any]:
        """Return a copy of ``payload`` with ``iat`` and ``exp`` claims set.

        ``expires_in`` defaults to one hour. Any ``iat``/``exp`` already in
        ``payload`` is overwritten.
        """
        if expires_in is None:
            expires_in = timedelta(hours=1)
        now = utc_now()
        return {
            **payload,
            'iat': int(now.timestamp()),
            'exp': int((now + expires_in).timestamp()),
        }

    def create_jws_token(
        self,
        payload: dict[str, Any],
        key_id: str | None = None,
        expires_in: timedelta | None = None,
    ) -> str:
        """Create a JWS (JSON Web Signature) token.

        Args:
            payload: The JWT payload
            key_id: The key ID to use for signing. If None, uses the newest
                active key.
            expires_in: Token expiration time. If None, defaults to 1 hour.

        Returns:
            The signed JWS token

        Raises:
            ValueError: If key_id is invalid
        """
        if key_id is None:
            key_id = self._default_key_id

        # Validate the key before doing any other work.
        secret_key = self._get_secret(key_id)
        jwt_payload = self._with_standard_claims(payload, expires_in)

        # The key ID travels in the 'kid' header so verification can find it.
        return jwt.encode(
            jwt_payload, secret_key, algorithm='HS256', headers={'kid': key_id}
        )

    def verify_jws_token(self, token: str, key_id: str | None = None) -> dict[str, Any]:
        """Verify and decode a JWS token.

        Args:
            token: The JWS token to verify
            key_id: The key ID to use for verification. If None, extracts from
                token's kid header.

        Returns:
            The decoded JWT payload

        Raises:
            ValueError: If token is invalid or key_id is not found
            jwt.InvalidTokenError: If token verification fails
        """
        if key_id is None:
            # Try to extract key_id from the token's (unverified) kid header.
            try:
                unverified_header = jwt.get_unverified_header(token)
            except jwt.DecodeError as e:
                raise ValueError('Invalid JWT token format') from e
            key_id = unverified_header.get('kid')
            if not key_id:
                raise ValueError("Token does not contain 'kid' header with key ID")

        secret_key = self._get_secret(key_id)
        try:
            return jwt.decode(token, secret_key, algorithms=['HS256'])
        except jwt.InvalidTokenError as e:
            raise jwt.InvalidTokenError(f'Token verification failed: {str(e)}') from e

    def create_jwe_token(
        self,
        payload: dict[str, Any],
        key_id: str | None = None,
        expires_in: timedelta | None = None,
    ) -> str:
        """Create a JWE (JSON Web Encryption) token.

        Args:
            payload: The JWT payload to encrypt
            key_id: The key ID to use for encryption. If None, uses the newest
                active key.
            expires_in: Token expiration time. If None, defaults to 1 hour.

        Returns:
            The encrypted JWE token

        Raises:
            ValueError: If key_id is invalid
        """
        if key_id is None:
            key_id = self._default_key_id

        # Validate the key and derive the 256-bit key before anything else.
        key_256 = self._derive_jwe_key(key_id)
        jwt_payload = self._with_standard_claims(payload, expires_in)

        # Encrypt the payload (convert to JSON string first)
        payload_json = json.dumps(jwt_payload)
        encrypted_token = jwe.encrypt(
            payload_json,
            key_256,
            algorithm=ALGORITHMS.DIR,
            encryption=ALGORITHMS.A256GCM,
            kid=key_id,
        )

        # Ensure we return a string
        if isinstance(encrypted_token, bytes):
            return encrypted_token.decode('utf-8')
        return encrypted_token

    def decrypt_jwe_token(
        self, token: str, key_id: str | None = None
    ) -> dict[str, Any]:
        """Decrypt and decode a JWE token.

        Args:
            token: The JWE token to decrypt
            key_id: The key ID to use for decryption. If None, extracts
                from token header.

        Returns:
            The decrypted JWT payload

        Raises:
            ValueError: If token is invalid or key_id is not found
            Exception: If token decryption fails
        """
        if key_id is None:
            # Only header parsing sits inside the try: the missing-'kid'
            # error below must not be relabelled as a format error.
            try:
                header = jwe.get_unverified_header(token)
                key_id = header.get('kid')
            except Exception as e:
                raise ValueError('Invalid JWE token format') from e
            if not key_id:
                raise ValueError("Token does not contain 'kid' header with key ID")

        key_256 = self._derive_jwe_key(key_id)
        try:
            payload_json = jwe.decrypt(token, key_256)
            # Explicit check instead of `assert`, which is stripped under -O.
            if payload_json is None:
                raise ValueError('Decrypted payload is empty')
            # Parse the JSON string back to dictionary
            return json.loads(payload_json)
        except Exception as e:
            raise Exception(f'Token decryption failed: {str(e)}') from e
class JwtServiceInjector(BaseModel, Injector[JwtService]):
    """Injector that builds a JwtService on first use and then reuses it.

    The service is constructed from the keys returned by
    ``get_default_encryption_keys(persistence_dir)`` and cached on this
    injector instance.
    """

    # Directory handed to get_default_encryption_keys when loading keys.
    persistence_dir: Path
    # Cached service instance; None until get_jwt_service() is first called.
    _jwt_service: JwtService | None = PrivateAttr(default=None)

    def get_jwt_service(self) -> JwtService:
        """Return the cached JwtService, creating it on the first call."""
        if self._jwt_service is not None:
            return self._jwt_service

        # Cache only after construction succeeds.
        service = JwtService(keys=get_default_encryption_keys(self.persistence_dir))
        self._jwt_service = service
        return service

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[JwtService, None]:
        """Yield the shared JwtService; ``state`` and ``request`` are unused."""
        yield self.get_jwt_service()