Files
OpenHands/enterprise/storage/encrypt_utils.py
chuckbutkus d5e66b4f3a SAAS: Introducing orgs (phase 1) (#11265)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
Co-authored-by: Hiep Le <69354317+hieptl@users.noreply.github.com>
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
2026-01-15 22:03:31 -05:00

115 lines
3.3 KiB
Python

import binascii
import hashlib
from base64 import b64decode, b64encode
from cryptography.fernet import Fernet, InvalidToken
from pydantic import SecretStr
from server.config import get_config
# Lazily created singletons: get_jwt_service() fills the first and
# get_fernet() fills the second on first use.
_jwt_service, _fernet = None, None
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
    """Return the column values of ``model_instance`` with selected fields encrypted.

    Args:
        encrypt_keys: Names of the fields whose values should be encrypted.
        model_instance: A model exposing ``__table__.columns`` (see ``model_to_kwargs``).

    Returns:
        A dict of column name to (possibly encrypted) value.
    """
    column_values = model_to_kwargs(model_instance)
    return encrypt_kwargs(encrypt_keys, column_values)
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
    """Return the column values of ``model_instance`` with selected fields JWE-decrypted.

    Args:
        decrypt_keys: Names of the fields whose values should be decrypted.
        model_instance: A model exposing ``__table__.columns`` (see ``model_to_kwargs``).

    Returns:
        A dict of column name to (possibly decrypted) value.
    """
    column_values = model_to_kwargs(model_instance)
    return decrypt_kwargs(decrypt_keys, column_values)
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
    """Encrypt the values in ``kwargs`` whose key appears in ``encrypt_keys``.

    ``None`` values are skipped. Nested dicts are walked recursively, so
    matching keys at any depth are encrypted. A dict value is never itself
    encrypted, even when its key is listed; only its contents are processed.

    Each nested dict is shallow-copied before it is processed, and the copy
    replaces it in ``kwargs``. This leaves the caller's original nested
    objects untouched. For example, a dict returned by ``model_to_kwargs``
    still references the model instance's attribute values, and those
    attributes are not modified.

    Args:
        encrypt_keys: Names of the fields whose values must be encrypted.
        kwargs: Mapping of field name to value. Updated in place at the top
            level.

    Returns:
        The same ``kwargs`` object, for convenience.
    """
    for key, value in kwargs.items():
        if value is None:
            continue
        if isinstance(value, dict):
            # Recurse into a copy. Mutating ``value`` directly would also alter
            # any other object still referencing it (e.g. a model attribute).
            # Re-assigning an existing key does not change the dict's size, so
            # this is safe while iterating.
            kwargs[key] = encrypt_kwargs(encrypt_keys, dict(value))
            continue
        if key in encrypt_keys:
            kwargs[key] = encrypt_value(value)
    return kwargs
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
    """JWE-decrypt, in place, the values in ``kwargs`` whose key is in ``encrypt_keys``.

    ``None`` values are skipped. If decryption raises ``binascii.Error``, the
    value is left unchanged; the original code treats such values as being in
    the legacy format.

    NOTE(review): unlike ``encrypt_kwargs``, this does not descend into nested
    dicts, so nested values encrypted there are not decrypted here. Confirm
    whether callers handle that separately.

    Args:
        encrypt_keys: Names of the fields whose values should be decrypted.
        kwargs: Mapping of field name to value. Updated in place.

    Returns:
        The same ``kwargs`` object.
    """
    for name, stored in kwargs.items():
        if stored is None or name not in encrypt_keys:
            continue
        try:
            kwargs[name] = decrypt_value(stored)
        except binascii.Error:
            # Legacy-format value: keep it as stored.
            pass
    return kwargs
def encrypt_value(value: str | SecretStr) -> str:
    """Encrypt ``value`` as a JWE token whose payload is ``{'v': plaintext}``.

    A ``SecretStr`` is unwrapped to its secret value before encryption.

    Returns:
        The token produced by the JWT service's ``create_jwe_token``.
    """
    service = get_jwt_service()
    if isinstance(value, SecretStr):
        plaintext = value.get_secret_value()
    else:
        plaintext = value
    return service.create_jwe_token({'v': plaintext})
def decrypt_value(value: str | SecretStr) -> str:
    """Decrypt a JWE token and return the ``'v'`` entry of its payload.

    A ``SecretStr`` is unwrapped before decryption. Errors raised by the JWT
    service propagate to the caller.
    """
    service = get_jwt_service()
    token_text = value.get_secret_value() if isinstance(value, SecretStr) else value
    payload = service.decrypt_jwe_token(token_text)
    return payload['v']
def get_jwt_service():
    """Return the module-wide JWT service, creating it on first call.

    The service comes from the ``jwt`` injector of the global app-server
    config. It is cached in the module-level ``_jwt_service`` variable.

    Raises:
        RuntimeError: If the global config has no JWT injector.
    """
    # Imported inside the function, presumably to avoid an import cycle or
    # loading config at module import time. TODO confirm.
    from openhands.app_server.config import get_global_config

    global _jwt_service
    if _jwt_service is None:
        jwt_service_injector = get_global_config().jwt
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an AttributeError on None.
        if jwt_service_injector is None:
            raise RuntimeError('JWT service injector is not configured')
        _jwt_service = jwt_service_injector.get_jwt_service()
    return _jwt_service
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
    """Return the column values of ``model_instance`` with legacy Fernet fields decrypted.

    Args:
        decrypt_keys: Names of the fields whose values should be decrypted.
        model_instance: A model exposing ``__table__.columns`` (see ``model_to_kwargs``).

    Returns:
        A dict of column name to (possibly decrypted) value.
    """
    column_values = model_to_kwargs(model_instance)
    return decrypt_legacy_kwargs(decrypt_keys, column_values)
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
    """Fernet-decrypt, in place, the legacy-format values whose key is in ``encrypt_keys``.

    ``None`` values are skipped. A value is left unchanged if decrypting it
    raises either of these errors:

    * ``binascii.Error``: the outer base64 decoding failed (the original code
      describes this as the legacy format).
    * ``InvalidToken``: Fernet rejected the token, presumably because the
      value was never encrypted.

    Args:
        encrypt_keys: Names of the fields whose values should be decrypted.
        kwargs: Mapping of field name to value. Updated in place.

    Returns:
        The same ``kwargs`` object.
    """
    for name, stored in kwargs.items():
        if stored is None or name not in encrypt_keys:
            continue
        try:
            plaintext = decrypt_legacy_value(stored)
        except (binascii.Error, InvalidToken):
            continue
        kwargs[name] = plaintext
    return kwargs
def decrypt_legacy_value(value: str | SecretStr) -> str:
    """Decrypt a value stored in the legacy format: base64 wrapping a Fernet token.

    A ``SecretStr`` is unwrapped first. The text is base64-decoded, decrypted
    with the module Fernet instance, and the plaintext bytes are decoded to
    ``str``.

    Raises:
        binascii.Error: If the outer base64 decoding fails.
        cryptography.fernet.InvalidToken: If Fernet rejects the token.
    """
    fernet = get_fernet()
    encoded = value.get_secret_value() if isinstance(value, SecretStr) else value
    token = b64decode(encoded.encode())
    return fernet.decrypt(token).decode()
def get_fernet():
    """Return the module-wide Fernet instance, creating it on first call.

    The key is the base64 encoding of SHA-256 applied to the configured
    ``jwt_secret``. The instance is cached in the module-level ``_fernet``
    variable.
    """
    global _fernet
    if _fernet is not None:
        return _fernet
    secret_text = get_config().jwt_secret.get_secret_value()
    digest = hashlib.sha256(secret_text.encode()).digest()
    _fernet = Fernet(b64encode(digest))
    return _fernet
def model_to_kwargs(model_instance):
    """Map each column name in ``model_instance.__table__.columns`` to the instance's attribute value.

    Presumably ``model_instance`` is a SQLAlchemy declarative model. Values are
    returned by reference, not copied.
    """
    column_names = [column.name for column in model_instance.__table__.columns]
    return {name: getattr(model_instance, name) for name in column_names}