mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Fix]: Make provider tokens immutable (#7317)
This commit is contained in:
parent
dde90fc636
commit
3150af1ad7
@ -1,6 +1,16 @@
|
||||
from enum import Enum
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
|
||||
from enum import Enum
|
||||
from types import MappingProxyType
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
SerializationInfo,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
@ -19,43 +29,42 @@ class ProviderType(Enum):
|
||||
|
||||
|
||||
class ProviderToken(BaseModel):
|
||||
token: SecretStr | None
|
||||
user_id: str | None
|
||||
token: SecretStr | None = Field(default=None)
|
||||
user_id: str | None = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
'frozen': True, # Makes the entire model immutable
|
||||
'validate_assignment': True,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
|
||||
"""Factory method to create a ProviderToken from various input types"""
|
||||
if isinstance(token_value, ProviderToken):
|
||||
return token_value
|
||||
elif isinstance(token_value, dict):
|
||||
token_str = token_value.get('token')
|
||||
user_id = token_value.get('user_id')
|
||||
return cls(token=SecretStr(token_str), user_id=user_id)
|
||||
|
||||
else:
|
||||
raise ValueError('Unsupport Provider token type')
|
||||
|
||||
|
||||
PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken]
|
||||
CUSTOM_SECRETS_TYPE = dict[str, SecretStr]
|
||||
PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken]
|
||||
CUSTOM_SECRETS_TYPE = MappingProxyType[str, SecretStr]
|
||||
|
||||
|
||||
class SecretStore(BaseModel):
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = {}
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Field(
|
||||
default_factory=lambda: MappingProxyType({})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _convert_token(
|
||||
cls, token_value: str | ProviderToken | SecretStr
|
||||
) -> ProviderToken:
|
||||
if isinstance(token_value, ProviderToken):
|
||||
return token_value
|
||||
elif isinstance(token_value, str):
|
||||
return ProviderToken(token=SecretStr(token_value), user_id=None)
|
||||
elif isinstance(token_value, SecretStr):
|
||||
return ProviderToken(token=token_value, user_id=None)
|
||||
else:
|
||||
raise ValueError(f'Invalid token type: {type(token_value)}')
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
# Convert any string tokens to ProviderToken objects
|
||||
converted_tokens = {}
|
||||
for token_type, token_value in self.provider_tokens.items():
|
||||
if token_value: # Only convert non-empty tokens
|
||||
try:
|
||||
if isinstance(token_type, str):
|
||||
token_type = ProviderType(token_type)
|
||||
converted_tokens[token_type] = self._convert_token(token_value)
|
||||
except ValueError:
|
||||
# Skip invalid provider types or tokens
|
||||
continue
|
||||
self.provider_tokens = converted_tokens
|
||||
model_config = {
|
||||
'frozen': True,
|
||||
'validate_assignment': True,
|
||||
'arbitrary_types_allowed': True,
|
||||
}
|
||||
|
||||
@field_serializer('provider_tokens')
|
||||
def provider_tokens_serializer(
|
||||
@ -82,6 +91,40 @@ class SecretStore(BaseModel):
|
||||
|
||||
return tokens
|
||||
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def convert_dict_to_mappingproxy(
|
||||
cls, data: dict[str, dict[str, dict[str, str]]] | PROVIDER_TOKEN_TYPE
|
||||
) -> dict[str, MappingProxyType]:
|
||||
"""Custom deserializer to convert dictionary into MappingProxyType"""
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError('SecretStore must be initialized with a dictionary')
|
||||
|
||||
new_data = {}
|
||||
|
||||
if 'provider_tokens' in data:
|
||||
tokens = data['provider_tokens']
|
||||
if isinstance(
|
||||
tokens, dict
|
||||
): # Ensure conversion happens only for dict inputs
|
||||
converted_tokens = {}
|
||||
for key, value in tokens.items():
|
||||
try:
|
||||
provider_type = (
|
||||
ProviderType(key) if isinstance(key, str) else key
|
||||
)
|
||||
converted_tokens[provider_type] = ProviderToken.from_value(
|
||||
value
|
||||
)
|
||||
except ValueError:
|
||||
# Skip invalid provider types or tokens
|
||||
continue
|
||||
|
||||
# Convert to MappingProxyType
|
||||
new_data['provider_tokens'] = MappingProxyType(converted_tokens)
|
||||
|
||||
return new_data
|
||||
|
||||
|
||||
class ProviderHandler:
|
||||
def __init__(
|
||||
@ -94,8 +137,14 @@ class ProviderHandler:
|
||||
ProviderType.GITLAB: GitLabServiceImpl,
|
||||
}
|
||||
|
||||
self.provider_tokens = provider_tokens
|
||||
# Create immutable copy through SecretStore
|
||||
self.external_auth_token = external_auth_token
|
||||
self._provider_tokens = provider_tokens
|
||||
|
||||
@property
|
||||
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
|
||||
"""Read-only access to provider tokens."""
|
||||
return self._provider_tokens
|
||||
|
||||
def _get_service(self, provider: ProviderType) -> GitService:
|
||||
"""Helper method to instantiate a service for a given provider"""
|
||||
|
||||
@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
from openhands.server.auth import get_provider_tokens, get_user_id
|
||||
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
|
||||
@ -26,12 +26,11 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
|
||||
github_token_is_set = bool(user_id) or bool(get_provider_tokens(request))
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(),
|
||||
**settings.model_dump(exclude='secrets_store'),
|
||||
github_token_is_set=github_token_is_set,
|
||||
)
|
||||
settings_with_token_data.llm_api_key = settings.llm_api_key
|
||||
|
||||
del settings_with_token_data.secrets_store
|
||||
settings_with_token_data.llm_api_key = settings.llm_api_key
|
||||
return settings_with_token_data
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid token: {e}')
|
||||
@ -90,9 +89,10 @@ async def store_settings(
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
# Handle token updates immutably
|
||||
if settings.unset_github_token:
|
||||
settings.secrets_store.provider_tokens = {}
|
||||
settings.provider_tokens = {}
|
||||
settings = settings.model_copy(update={"secrets_store": SecretStore()})
|
||||
|
||||
else: # Only merge if not unsetting tokens
|
||||
if settings.provider_tokens:
|
||||
if existing_settings.secrets_store:
|
||||
@ -156,16 +156,22 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
|
||||
# Convert the `llm_api_key` to a `SecretStr` instance
|
||||
filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key
|
||||
|
||||
# Create a new Settings instance without provider tokens
|
||||
# Create a new Settings instance with empty SecretStore
|
||||
settings = Settings(**filtered_settings_data)
|
||||
|
||||
# Update provider tokens if any are provided
|
||||
# Create new provider tokens immutably
|
||||
if settings_with_token_data.provider_tokens:
|
||||
tokens = {}
|
||||
for token_type, token_value in settings_with_token_data.provider_tokens.items():
|
||||
if token_value:
|
||||
provider = ProviderType(token_type)
|
||||
settings.secrets_store.provider_tokens[provider] = ProviderToken(
|
||||
tokens[provider] = ProviderToken(
|
||||
token=SecretStr(token_value), user_id=None
|
||||
)
|
||||
|
||||
# Create new SecretStore with tokens
|
||||
settings = settings.model_copy(update={"secrets_store": SecretStore(
|
||||
provider_tokens=tokens
|
||||
)})
|
||||
|
||||
return settings
|
||||
return settings
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
SerializationInfo,
|
||||
field_serializer,
|
||||
@ -11,7 +12,7 @@ from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.utils import load_app_config
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
from openhands.integrations.provider import SecretStore
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
@ -28,11 +29,15 @@ class Settings(BaseModel):
|
||||
llm_api_key: SecretStr | None = None
|
||||
llm_base_url: str | None = None
|
||||
remote_runtime_resource_factor: int | None = None
|
||||
secrets_store: SecretStore = SecretStore()
|
||||
secrets_store: SecretStore = Field(default_factory=SecretStore, frozen=True)
|
||||
enable_default_condenser: bool = False
|
||||
enable_sound_notifications: bool = False
|
||||
user_consents_to_analytics: bool | None = None
|
||||
|
||||
model_config = {
|
||||
'validate_assignment': True,
|
||||
}
|
||||
|
||||
@field_serializer('llm_api_key')
|
||||
def llm_api_key_serializer(self, llm_api_key: SecretStr, info: SerializationInfo):
|
||||
"""Custom serializer for the LLM API key.
|
||||
@ -45,23 +50,6 @@ class Settings(BaseModel):
|
||||
|
||||
return pydantic_encoder(llm_api_key) if llm_api_key else None
|
||||
|
||||
@staticmethod
|
||||
def _convert_token_value(
|
||||
token_type: ProviderType, token_value: str | dict
|
||||
) -> ProviderToken | None:
|
||||
"""Convert a token value to a ProviderToken object."""
|
||||
if isinstance(token_value, dict):
|
||||
token_str = token_value.get('token')
|
||||
if not token_str:
|
||||
return None
|
||||
return ProviderToken(
|
||||
token=SecretStr(token_str),
|
||||
user_id=token_value.get('user_id'),
|
||||
)
|
||||
if isinstance(token_value, str) and token_value:
|
||||
return ProviderToken(token=SecretStr(token_value), user_id=None)
|
||||
return None
|
||||
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def convert_provider_tokens(cls, data: dict | object) -> dict | object:
|
||||
@ -77,21 +65,7 @@ class Settings(BaseModel):
|
||||
if not isinstance(tokens, dict):
|
||||
return data
|
||||
|
||||
converted_tokens = {}
|
||||
for token_type_str, token_value in tokens.items():
|
||||
if not token_value:
|
||||
continue
|
||||
|
||||
try:
|
||||
token_type = ProviderType(token_type_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
provider_token = cls._convert_token_value(token_type, token_value)
|
||||
if provider_token:
|
||||
converted_tokens[token_type] = provider_token
|
||||
|
||||
data['secrets_store'] = SecretStore(provider_tokens=converted_tokens)
|
||||
data['secrets_store'] = SecretStore(provider_tokens=tokens)
|
||||
return data
|
||||
|
||||
@field_serializer('secrets_store')
|
||||
|
||||
229
tests/unit/test_provider_immutability.py
Normal file
229
tests/unit/test_provider_immutability.py
Normal file
@ -0,0 +1,229 @@
|
||||
from types import MappingProxyType
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr, ValidationError
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
ProviderHandler,
|
||||
ProviderToken,
|
||||
ProviderType,
|
||||
SecretStore,
|
||||
)
|
||||
from openhands.server.routes.settings import convert_to_settings
|
||||
from openhands.server.settings import POSTSettingsModel, Settings
|
||||
|
||||
|
||||
def test_provider_token_immutability():
|
||||
"""Test that ProviderToken is immutable"""
|
||||
token = ProviderToken(token=SecretStr('test'), user_id='user1')
|
||||
|
||||
# Test direct attribute modification
|
||||
with pytest.raises(ValidationError):
|
||||
token.token = SecretStr('new')
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
token.user_id = 'new_user'
|
||||
|
||||
# Test that __setattr__ is blocked
|
||||
with pytest.raises(ValidationError):
|
||||
setattr(token, 'token', SecretStr('new'))
|
||||
|
||||
# Verify original values are unchanged
|
||||
assert token.token.get_secret_value() == 'test'
|
||||
assert token.user_id == 'user1'
|
||||
|
||||
|
||||
def test_secret_store_immutability():
|
||||
"""Test that SecretStore is immutable"""
|
||||
store = SecretStore(
|
||||
provider_tokens={ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))}
|
||||
)
|
||||
|
||||
# Test direct attribute modification
|
||||
with pytest.raises(ValidationError):
|
||||
store.provider_tokens = {}
|
||||
|
||||
# Test dictionary mutation attempts
|
||||
with pytest.raises((TypeError, AttributeError)):
|
||||
store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
|
||||
token=SecretStr('new')
|
||||
)
|
||||
|
||||
with pytest.raises((TypeError, AttributeError)):
|
||||
store.provider_tokens.clear()
|
||||
|
||||
with pytest.raises((TypeError, AttributeError)):
|
||||
store.provider_tokens.update(
|
||||
{ProviderType.GITLAB: ProviderToken(token=SecretStr('test'))}
|
||||
)
|
||||
|
||||
# Test nested immutability
|
||||
github_token = store.provider_tokens[ProviderType.GITHUB]
|
||||
with pytest.raises(ValidationError):
|
||||
github_token.token = SecretStr('new')
|
||||
|
||||
# Verify original values are unchanged
|
||||
assert store.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'test'
|
||||
|
||||
|
||||
def test_settings_immutability():
|
||||
"""Test that Settings secrets_store is immutable"""
|
||||
settings = Settings(
|
||||
secrets_store=SecretStore(
|
||||
provider_tokens={
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Test direct modification of secrets_store
|
||||
with pytest.raises(ValidationError):
|
||||
settings.secrets_store = SecretStore()
|
||||
|
||||
# Test nested modification attempts
|
||||
with pytest.raises((TypeError, AttributeError)):
|
||||
settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
|
||||
token=SecretStr('new')
|
||||
)
|
||||
|
||||
# Test model_copy creates new instance
|
||||
new_store = SecretStore(
|
||||
provider_tokens={
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('new_token'))
|
||||
}
|
||||
)
|
||||
new_settings = settings.model_copy(update={'secrets_store': new_store})
|
||||
|
||||
# Verify original is unchanged and new has updated values
|
||||
assert (
|
||||
settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test'
|
||||
)
|
||||
assert (
|
||||
new_settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'new_token'
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
new_settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token = SecretStr('')
|
||||
|
||||
|
||||
def test_post_settings_conversion():
|
||||
"""Test that POSTSettingsModel correctly converts to Settings"""
|
||||
# Create POST model with token data
|
||||
post_data = POSTSettingsModel(
|
||||
provider_tokens={'github': 'test_token', 'gitlab': 'gitlab_token'}
|
||||
)
|
||||
|
||||
# Convert to settings using convert_to_settings function
|
||||
settings = convert_to_settings(post_data)
|
||||
|
||||
# Verify tokens were converted correctly
|
||||
assert (
|
||||
settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test_token'
|
||||
)
|
||||
assert (
|
||||
settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITLAB
|
||||
].token.get_secret_value()
|
||||
== 'gitlab_token'
|
||||
)
|
||||
assert settings.secrets_store.provider_tokens[ProviderType.GITLAB].user_id is None
|
||||
|
||||
# Verify immutability of converted settings
|
||||
with pytest.raises(ValidationError):
|
||||
settings.secrets_store = SecretStore()
|
||||
|
||||
|
||||
def test_provider_handler_immutability():
|
||||
"""Test that ProviderHandler maintains token immutability"""
|
||||
|
||||
# Create initial tokens
|
||||
tokens = MappingProxyType(
|
||||
{ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))}
|
||||
)
|
||||
|
||||
handler = ProviderHandler(provider_tokens=tokens)
|
||||
|
||||
# Try to modify tokens (should raise TypeError due to frozen dict)
|
||||
with pytest.raises((TypeError, AttributeError)):
|
||||
handler.provider_tokens[ProviderType.GITHUB] = ProviderToken(
|
||||
token=SecretStr('new')
|
||||
)
|
||||
|
||||
# Try to modify the handler's tokens property
|
||||
with pytest.raises((ValidationError, TypeError, AttributeError)):
|
||||
handler.provider_tokens = {}
|
||||
|
||||
# Original token should be unchanged
|
||||
assert (
|
||||
handler.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'test'
|
||||
)
|
||||
|
||||
|
||||
def test_token_conversion():
|
||||
"""Test token conversion in SecretStore.create"""
|
||||
# Test with string token
|
||||
store1 = Settings(
|
||||
secrets_store=SecretStore(
|
||||
provider_tokens={
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token'))
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert (
|
||||
store1.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test_token'
|
||||
)
|
||||
assert store1.secrets_store.provider_tokens[ProviderType.GITHUB].user_id is None
|
||||
|
||||
# Test with dict token
|
||||
store2 = SecretStore(
|
||||
provider_tokens={'github': {'token': 'test_token', 'user_id': 'user1'}}
|
||||
)
|
||||
assert (
|
||||
store2.provider_tokens[ProviderType.GITHUB].token.get_secret_value()
|
||||
== 'test_token'
|
||||
)
|
||||
assert store2.provider_tokens[ProviderType.GITHUB].user_id == 'user1'
|
||||
|
||||
# Test with ProviderToken
|
||||
token = ProviderToken(token=SecretStr('test_token'), user_id='user2')
|
||||
store3 = SecretStore(provider_tokens={ProviderType.GITHUB: token})
|
||||
assert (
|
||||
store3.provider_tokens[ProviderType.GITHUB].token.get_secret_value()
|
||||
== 'test_token'
|
||||
)
|
||||
assert store3.provider_tokens[ProviderType.GITHUB].user_id == 'user2'
|
||||
|
||||
store4 = SecretStore(
|
||||
provider_tokens={
|
||||
ProviderType.GITHUB: 123 # Invalid type
|
||||
}
|
||||
)
|
||||
|
||||
assert ProviderType.GITHUB not in store4.provider_tokens
|
||||
|
||||
# Test with empty/None token
|
||||
store5 = SecretStore(provider_tokens={ProviderType.GITHUB: None})
|
||||
assert ProviderType.GITHUB not in store5.provider_tokens
|
||||
|
||||
store6 = SecretStore(
|
||||
provider_tokens={
|
||||
'invalid_provider': 'test_token' # Invalid provider type
|
||||
}
|
||||
)
|
||||
|
||||
assert len(store6.provider_tokens.keys()) == 0
|
||||
@ -6,7 +6,7 @@ from openhands.core.config.app_config import AppConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
from openhands.server.routes.settings import convert_to_settings
|
||||
from openhands.server.settings import POSTSettingsModel, Settings
|
||||
|
||||
@ -81,10 +81,14 @@ def test_settings_handles_sensitive_data():
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='https://test.example.com',
|
||||
remote_runtime_resource_factor=2,
|
||||
)
|
||||
settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
|
||||
token=SecretStr('test-token'),
|
||||
user_id=None,
|
||||
secrets_store=SecretStore(
|
||||
provider_tokens={
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
token=SecretStr('test-token'),
|
||||
user_id=None,
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
assert str(settings.llm_api_key) == '**********'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user