mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Refactor]: Modularize settings storage logic (#7868)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
9262babc3b
commit
b2a4b4ed90
@ -134,12 +134,10 @@ async def reset_settings(request: Request) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
async def store_settings(
|
||||
request: Request,
|
||||
settings: POSTSettingsModel,
|
||||
) -> JSONResponse:
|
||||
# Check provider tokens are valid
|
||||
|
||||
async def check_provider_tokens(request: Request,
|
||||
settings: POSTSettingsModel) -> str:
|
||||
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
@ -154,12 +152,84 @@ async def store_settings(
|
||||
SecretStr(token_value)
|
||||
)
|
||||
if not confirmed_token_type or confirmed_token_type.value != token_type:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Invalid token. Please make sure it is a valid {token_type} token.'
|
||||
},
|
||||
)
|
||||
return f"Invalid token. Please make sure it is a valid {token_type} token."
|
||||
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
if settings.provider_tokens:
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in list(settings.provider_tokens.items()):
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||
settings.provider_tokens = {
|
||||
provider.value: data.token.get_secret_value()
|
||||
if data.token
|
||||
else None
|
||||
for provider, data in provider_tokens.items()
|
||||
}
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> POSTSettingsModel:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
# Keep existing LLM settings if not provided
|
||||
if settings.llm_api_key is None:
|
||||
settings.llm_api_key = existing_settings.llm_api_key
|
||||
if settings.llm_model is None:
|
||||
settings.llm_model = existing_settings.llm_model
|
||||
if settings.llm_base_url is None:
|
||||
settings.llm_base_url = existing_settings.llm_base_url
|
||||
|
||||
return settings
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
async def store_settings(
|
||||
request: Request,
|
||||
settings: POSTSettingsModel,
|
||||
) -> JSONResponse:
|
||||
# Check provider tokens are valid
|
||||
provider_err_msg = await check_provider_tokens(request, settings)
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': provider_err_msg
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
@ -169,13 +239,7 @@ async def store_settings(
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
# Keep existing LLM settings if not provided
|
||||
if settings.llm_api_key is None:
|
||||
settings.llm_api_key = existing_settings.llm_api_key
|
||||
if settings.llm_model is None:
|
||||
settings.llm_model = existing_settings.llm_model
|
||||
if settings.llm_base_url is None:
|
||||
settings.llm_base_url = existing_settings.llm_base_url
|
||||
settings = await store_llm_settings(request, settings)
|
||||
|
||||
# Keep existing analytics consent if not provided
|
||||
if settings.user_consents_to_analytics is None:
|
||||
@ -183,35 +247,8 @@ async def store_settings(
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
# Only merge if not unsetting tokens
|
||||
if settings.provider_tokens:
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||
settings.provider_tokens = {
|
||||
provider.value: data.token.get_secret_value()
|
||||
if data.token
|
||||
else None
|
||||
for provider, data in provider_tokens.items()
|
||||
}
|
||||
settings = await store_provider_tokens(request, settings)
|
||||
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
|
||||
261
tests/unit/test_settings_store_functions.py
Normal file
261
tests/unit/test_settings_store_functions.py
Normal file
@ -0,0 +1,261 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, SecretStore
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.routes.settings import (
|
||||
check_provider_tokens,
|
||||
store_llm_settings,
|
||||
store_provider_tokens,
|
||||
)
|
||||
from openhands.server.settings import POSTSettingsModel, Settings
|
||||
|
||||
|
||||
# Mock functions to simulate the actual functions in settings.py
|
||||
async def get_settings_store(request):
|
||||
"""Mock function to get settings store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# Tests for check_provider_tokens
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_valid():
|
||||
"""Test check_provider_tokens with valid tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'valid-token'})
|
||||
|
||||
# Mock the validate_provider_token function to return GITHUB for valid tokens
|
||||
with patch(
|
||||
'openhands.server.routes.settings.validate_provider_token'
|
||||
) as mock_validate:
|
||||
mock_validate.return_value = ProviderType.GITHUB
|
||||
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return empty string for valid token
|
||||
assert result == ''
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_invalid():
|
||||
"""Test check_provider_tokens with invalid tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'invalid-token'})
|
||||
|
||||
# Mock the validate_provider_token function to return None for invalid tokens
|
||||
with patch(
|
||||
'openhands.server.routes.settings.validate_provider_token'
|
||||
) as mock_validate:
|
||||
mock_validate.return_value = None
|
||||
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return error message for invalid token
|
||||
assert 'Invalid token' in result
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_wrong_type():
|
||||
"""Test check_provider_tokens with unsupported provider type."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'unsupported': 'some-token'})
|
||||
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return empty string for unsupported provider
|
||||
assert result == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_no_tokens():
|
||||
"""Test check_provider_tokens with no tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={})
|
||||
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return empty string when no tokens provided
|
||||
assert result == ''
|
||||
|
||||
|
||||
# Tests for store_llm_settings
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_llm_settings_new_settings():
|
||||
"""Test store_llm_settings with new settings."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
llm_model='gpt-4',
|
||||
llm_api_key='test-api-key',
|
||||
llm_base_url='https://api.example.com',
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
mock_store.load = AsyncMock(return_value=None) # No existing settings
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_llm_settings(mock_request, settings)
|
||||
|
||||
# Should return settings with the provided values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
assert result.llm_base_url == 'https://api.example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_llm_settings_update_existing():
|
||||
"""Test store_llm_settings updates existing settings."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
llm_model='gpt-4',
|
||||
llm_api_key='new-api-key',
|
||||
llm_base_url='https://new.example.com',
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings
|
||||
existing_settings = Settings(
|
||||
llm_model='gpt-3.5',
|
||||
llm_api_key=SecretStr('old-api-key'),
|
||||
llm_base_url='https://old.example.com',
|
||||
)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_llm_settings(mock_request, settings)
|
||||
|
||||
# Should return settings with the updated values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
assert result.llm_api_key.get_secret_value() == 'new-api-key'
|
||||
assert result.llm_base_url == 'https://new.example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_llm_settings_partial_update():
|
||||
"""Test store_llm_settings with partial update."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
llm_model='gpt-4' # Only updating model
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings
|
||||
existing_settings = Settings(
|
||||
llm_model='gpt-3.5',
|
||||
llm_api_key=SecretStr('existing-api-key'),
|
||||
llm_base_url='https://existing.example.com',
|
||||
)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_llm_settings(mock_request, settings)
|
||||
|
||||
# Should return settings with updated model but keep other values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
# For SecretStr objects, we need to compare the secret value
|
||||
assert result.llm_api_key.get_secret_value() == 'existing-api-key'
|
||||
assert result.llm_base_url == 'https://existing.example.com'
|
||||
|
||||
|
||||
# Tests for store_provider_tokens
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_provider_tokens_new_tokens():
|
||||
"""Test store_provider_tokens with new tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'new-token'})
|
||||
|
||||
# Mock the settings store
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
mock_store.load = AsyncMock(return_value=None) # No existing settings
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return settings with the provided tokens
|
||||
assert result.provider_tokens == {'github': 'new-token'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_provider_tokens_update_existing():
|
||||
"""Test store_provider_tokens updates existing tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'updated-token'})
|
||||
|
||||
# Mock the settings store
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings with a GitHub token
|
||||
github_token = ProviderToken(token=SecretStr('old-token'))
|
||||
provider_tokens = {ProviderType.GITHUB: github_token}
|
||||
|
||||
# Create a SecretStore with the provider tokens
|
||||
secrets_store = SecretStore(provider_tokens=provider_tokens)
|
||||
|
||||
# Create existing settings with the secrets store
|
||||
existing_settings = Settings(secrets_store=secrets_store)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return settings with the updated tokens
|
||||
assert result.provider_tokens == {'github': 'updated-token'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_provider_tokens_keep_existing():
|
||||
"""Test store_provider_tokens keeps existing tokens when empty string provided."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
provider_tokens={'github': ''} # Empty string should keep existing token
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings with a GitHub token
|
||||
github_token = ProviderToken(token=SecretStr('existing-token'))
|
||||
provider_tokens = {ProviderType.GITHUB: github_token}
|
||||
|
||||
# Create a SecretStore with the provider tokens
|
||||
secrets_store = SecretStore(provider_tokens=provider_tokens)
|
||||
|
||||
# Create existing settings with the secrets store
|
||||
existing_settings = Settings(secrets_store=secrets_store)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return settings with the existing token preserved
|
||||
assert result.provider_tokens == {'github': 'existing-token'}
|
||||
Loading…
x
Reference in New Issue
Block a user