feat: Add BYOR export flag to org for LLM key access control (#12753)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
This commit is contained in:
Tim O'Farrell
2026-02-06 09:30:12 -07:00
committed by GitHub
parent 8cd8c011b2
commit d43ff82534
9 changed files with 262 additions and 10 deletions

View File

@@ -0,0 +1,46 @@
"""Add byor_export_enabled flag to org table.
Revision ID: 091
Revises: 090
Create Date: 2025-01-15 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '091'
down_revision: Union[str, None] = '090'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add byor_export_enabled column to org table with default false
op.add_column(
'org',
sa.Column(
'byor_export_enabled',
sa.Boolean,
nullable=False,
server_default=sa.text('false'),
),
)
# Set byor_export_enabled to true for orgs that have completed billing sessions
op.execute(
sa.text("""
UPDATE org SET byor_export_enabled = TRUE
WHERE id IN (
SELECT DISTINCT org_id FROM billing_sessions
WHERE status = 'completed' AND org_id IS NOT NULL
)
""")
)
def downgrade() -> None:
op.drop_column('org', 'byor_export_enabled')

View File

@@ -6,12 +6,30 @@ from storage.api_key_store import ApiKeyStore
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.org_store import OrgStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
async def check_byor_export_enabled(user_id: str) -> bool:
"""Check if BYOR export is enabled for the user's current org.
Returns True if the user's current org has byor_export_enabled set to True.
Returns False if the user is not found, has no current org, or the flag is False.
"""
user = await UserStore.get_user_by_id_async(user_id)
if not user or not user.current_org_id:
return False
org = OrgStore.get_org_by_id(user.current_org_id)
if not org:
return False
return org.byor_export_enabled
# Helper functions for BYOR API key management
async def get_byor_key_from_db(user_id: str) -> str | None:
"""Get the BYOR key from the database for a user."""
@@ -52,7 +70,6 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None:
async def generate_byor_key(user_id: str) -> str | None:
"""Generate a new BYOR key for a user."""
try:
user = await UserStore.get_user_by_id_async(user_id)
if not user:
@@ -148,6 +165,26 @@ class LlmApiKeyResponse(BaseModel):
key: str | None
class ByorPermittedResponse(BaseModel):
permitted: bool
@api_router.get('/llm/byor/permitted', response_model=ByorPermittedResponse)
async def check_byor_permitted(user_id: str = Depends(get_user_id)):
"""Check if BYOR key export is permitted for the user's current org."""
try:
permitted = await check_byor_export_enabled(user_id)
return {'permitted': permitted}
except Exception as e:
logger.exception(
'Error checking BYOR export permission', extra={'error': str(e)}
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to check BYOR export permission',
)
@api_router.post('', response_model=ApiKeyCreateResponse)
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
"""Create a new API key for the authenticated user."""
@@ -253,8 +290,17 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
This endpoint validates that the key exists in LiteLLM before returning it.
If validation fails, it automatically generates a new key to ensure users
always receive a working key.
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
"""
try:
# Check if BYOR export is enabled for the user's org
if not await check_byor_export_enabled(user_id):
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
)
# Check if the BYOR key exists in the database
byor_key = await get_byor_key_from_db(user_id)
if byor_key:
@@ -310,10 +356,20 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
"""
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
try:
# Check if BYOR export is enabled for the user's org
if not await check_byor_export_enabled(user_id):
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
)
# Get the existing BYOR key from the database
existing_byor_key = await get_byor_key_from_db(user_id)

View File

@@ -46,6 +46,7 @@ class Org(Base): # type: ignore
v1_enabled = Column(Boolean, nullable=True)
conversation_expiration = Column(Integer, nullable=True)
condenser_max_size = Column(Integer, nullable=True)
byor_export_enabled = Column(Boolean, nullable=False, default=False)
# Relationships
org_members = relationship('OrgMember', back_populates='org')

View File

@@ -172,6 +172,19 @@ class UserStore:
)
decrypted_user_settings = UserSettings(**kwargs)
with session_maker() as session:
# Check if user has completed billing sessions to enable BYOR export
from storage.billing_session import BillingSession
has_completed_billing = (
session.query(BillingSession)
.filter(
BillingSession.user_id == user_id,
BillingSession.status == 'completed',
)
.first()
is not None
)
# create personal org
org = Org(
id=uuid.UUID(user_id),
@@ -180,6 +193,7 @@ class UserStore:
contact_name=resolve_display_name(user_info)
or user_info.get('username', ''),
contact_email=user_info['email'],
byor_export_enabled=has_completed_billing,
)
session.add(org)

View File

@@ -182,16 +182,18 @@ class TestGetLlmApiKeyForByor:
"""Test the get_llm_api_key_for_byor endpoint."""
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_no_key_in_database_generates_new(
self, mock_get_key, mock_generate_key, mock_store_key
self, mock_get_key, mock_generate_key, mock_store_key, mock_check_enabled
):
"""Test that when no key exists in database, a new one is generated."""
# Arrange
user_id = 'user-123'
new_key = 'sk-new-generated-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = None
mock_generate_key.return_value = new_key
mock_store_key.return_value = None
@@ -201,20 +203,23 @@ class TestGetLlmApiKeyForByor:
# Assert
assert result == {'key': new_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_generate_key.assert_called_once_with(user_id)
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_valid_key_in_database_returns_key(
self, mock_get_key, mock_verify_key
self, mock_get_key, mock_verify_key, mock_check_enabled
):
"""Test that when a valid key exists in database, it is returned."""
# Arrange
user_id = 'user-123'
existing_key = 'sk-existing-valid-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = existing_key
mock_verify_key.return_value = True
@@ -223,10 +228,12 @@ class TestGetLlmApiKeyForByor:
# Assert
assert result == {'key': existing_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(existing_key, user_id)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@@ -239,12 +246,14 @@ class TestGetLlmApiKeyForByor:
mock_delete_key,
mock_generate_key,
mock_store_key,
mock_check_enabled,
):
"""Test that when an invalid key exists in database, it is regenerated."""
# Arrange
user_id = 'user-123'
invalid_key = 'sk-invalid-key'
new_key = 'sk-new-generated-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = invalid_key
mock_verify_key.return_value = False
mock_delete_key.return_value = True
@@ -256,6 +265,7 @@ class TestGetLlmApiKeyForByor:
# Assert
assert result == {'key': new_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(invalid_key, user_id)
mock_delete_key.assert_called_once_with(user_id, invalid_key)
@@ -263,6 +273,7 @@ class TestGetLlmApiKeyForByor:
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@@ -275,12 +286,14 @@ class TestGetLlmApiKeyForByor:
mock_delete_key,
mock_generate_key,
mock_store_key,
mock_check_enabled,
):
"""Test that even if deletion fails, regeneration still proceeds."""
# Arrange
user_id = 'user-123'
invalid_key = 'sk-invalid-key'
new_key = 'sk-new-generated-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = invalid_key
mock_verify_key.return_value = False
mock_delete_key.return_value = False # Deletion fails
@@ -292,19 +305,22 @@ class TestGetLlmApiKeyForByor:
# Assert
assert result == {'key': new_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_delete_key.assert_called_once_with(user_id, invalid_key)
mock_generate_key.assert_called_once_with(user_id)
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_key_generation_failure_raises_exception(
self, mock_get_key, mock_generate_key
self, mock_get_key, mock_generate_key, mock_check_enabled
):
"""Test that when key generation fails, an HTTPException is raised."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = True
mock_get_key.return_value = None
mock_generate_key.return_value = None
@@ -316,11 +332,15 @@ class TestGetLlmApiKeyForByor:
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_database_error_raises_exception(self, mock_get_key):
async def test_database_error_raises_exception(
self, mock_get_key, mock_check_enabled
):
"""Test that database errors are properly handled."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = True
mock_get_key.side_effect = Exception('Database connection error')
# Act & Assert
@@ -330,6 +350,21 @@ class TestGetLlmApiKeyForByor:
assert exc_info.value.status_code == 500
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
async def test_byor_export_disabled_returns_402(self, mock_check_enabled):
"""Test that when BYOR export is disabled, 402 is returned."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = False
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_llm_api_key_for_byor(user_id=user_id)
assert exc_info.value.status_code == 402
assert 'BYOR key export is not enabled' in exc_info.value.detail
class TestDeleteByorKeyFromLitellm:
"""Test the delete_byor_key_from_litellm function with alias cleanup."""