diff --git a/openhands/integrations/provider.py b/openhands/integrations/provider.py index 0076ea1c91..b918526875 100644 --- a/openhands/integrations/provider.py +++ b/openhands/integrations/provider.py @@ -1,6 +1,16 @@ -from enum import Enum +from __future__ import annotations -from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer +from enum import Enum +from types import MappingProxyType + +from pydantic import ( + BaseModel, + Field, + SecretStr, + SerializationInfo, + field_serializer, + model_validator, +) from pydantic.json import pydantic_encoder from openhands.integrations.github.github_service import GithubServiceImpl @@ -19,43 +29,42 @@ class ProviderType(Enum): class ProviderToken(BaseModel): - token: SecretStr | None - user_id: str | None + token: SecretStr | None = Field(default=None) + user_id: str | None = Field(default=None) + + model_config = { + 'frozen': True, # Makes the entire model immutable + 'validate_assignment': True, + } + + @classmethod + def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken: + """Factory method to create a ProviderToken from various input types""" + if isinstance(token_value, ProviderToken): + return token_value + elif isinstance(token_value, dict): + token_str = token_value.get('token') + user_id = token_value.get('user_id') + return cls(token=SecretStr(token_str), user_id=user_id) + + else: + raise ValueError('Unsupport Provider token type') -PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken] -CUSTOM_SECRETS_TYPE = dict[str, SecretStr] +PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken] +CUSTOM_SECRETS_TYPE = MappingProxyType[str, SecretStr] class SecretStore(BaseModel): - provider_tokens: PROVIDER_TOKEN_TYPE = {} + provider_tokens: PROVIDER_TOKEN_TYPE = Field( + default_factory=lambda: MappingProxyType({}) + ) - @classmethod - def _convert_token( - cls, token_value: str | ProviderToken | SecretStr - ) -> ProviderToken: - if isinstance(token_value, ProviderToken): - return token_value - elif isinstance(token_value, str): - return ProviderToken(token=SecretStr(token_value), user_id=None) - elif isinstance(token_value, SecretStr): - return ProviderToken(token=token_value, user_id=None) - else: - raise ValueError(f'Invalid token type: {type(token_value)}') - - def model_post_init(self, __context) -> None: - # Convert any string tokens to ProviderToken objects - converted_tokens = {} - for token_type, token_value in self.provider_tokens.items(): - if token_value: # Only convert non-empty tokens - try: - if isinstance(token_type, str): - token_type = ProviderType(token_type) - converted_tokens[token_type] = self._convert_token(token_value) - except ValueError: - # Skip invalid provider types or tokens - continue - self.provider_tokens = converted_tokens + model_config = { + 'frozen': True, + 'validate_assignment': True, + 'arbitrary_types_allowed': True, + } @field_serializer('provider_tokens') def provider_tokens_serializer( @@ -82,6 +91,40 @@ class SecretStore(BaseModel): return tokens + @model_validator(mode='before') + @classmethod + def convert_dict_to_mappingproxy( + cls, data: dict[str, dict[str, dict[str, str]]] | PROVIDER_TOKEN_TYPE + ) -> dict[str, MappingProxyType]: + """Custom deserializer to convert dictionary into MappingProxyType""" + if not isinstance(data, dict): + raise ValueError('SecretStore must be initialized with a dictionary') + + new_data = {} + + if 'provider_tokens' in data: + tokens = data['provider_tokens'] + if isinstance( + tokens, dict + ): # Ensure conversion happens only for dict inputs + converted_tokens = {} + for key, value in tokens.items(): + try: + provider_type = ( + ProviderType(key) if isinstance(key, str) else key + ) + converted_tokens[provider_type] = ProviderToken.from_value( + value + ) + except ValueError: + # Skip invalid provider types or tokens + continue + + # Convert to MappingProxyType + new_data['provider_tokens'] = MappingProxyType(converted_tokens) + + return new_data + class ProviderHandler: def __init__( @@ -94,8 +137,14 @@ class ProviderHandler: ProviderType.GITLAB: GitLabServiceImpl, } - self.provider_tokens = provider_tokens + # Create immutable copy through SecretStore self.external_auth_token = external_auth_token + self._provider_tokens = provider_tokens + + @property + def provider_tokens(self) -> PROVIDER_TOKEN_TYPE: + """Read-only access to provider tokens.""" + return self._provider_tokens def _get_service(self, provider: ProviderType) -> GitService: """Helper method to instantiate a service for a given provider""" diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 223fd93800..a953e7a005 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse from pydantic import SecretStr from openhands.core.logger import openhands_logger as logger -from openhands.integrations.provider import ProviderToken, ProviderType +from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore from openhands.integrations.utils import validate_provider_token from openhands.server.auth import get_provider_tokens, get_user_id from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings @@ -26,12 +26,11 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse: github_token_is_set = bool(user_id) or bool(get_provider_tokens(request)) settings_with_token_data = GETSettingsModel( - **settings.model_dump(), + **settings.model_dump(exclude='secrets_store'), github_token_is_set=github_token_is_set, ) - settings_with_token_data.llm_api_key = settings.llm_api_key - del settings_with_token_data.secrets_store + settings_with_token_data.llm_api_key = settings.llm_api_key return settings_with_token_data except Exception as e: logger.warning(f'Invalid token: {e}') @@ -90,9 +89,10 @@ async def store_settings( existing_settings.user_consents_to_analytics ) + # Handle token updates immutably if settings.unset_github_token: - settings.secrets_store.provider_tokens = {} - settings.provider_tokens = {} + settings = settings.model_copy(update={"secrets_store": SecretStore()}) + else: # Only merge if not unsetting tokens if settings.provider_tokens: if existing_settings.secrets_store: @@ -156,16 +156,22 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings # Convert the `llm_api_key` to a `SecretStr` instance filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key - # Create a new Settings instance without provider tokens + # Create a new Settings instance with empty SecretStore settings = Settings(**filtered_settings_data) - # Update provider tokens if any are provided + # Create new provider tokens immutably if settings_with_token_data.provider_tokens: + tokens = {} for token_type, token_value in settings_with_token_data.provider_tokens.items(): if token_value: provider = ProviderType(token_type) - settings.secrets_store.provider_tokens[provider] = ProviderToken( + tokens[provider] = ProviderToken( token=SecretStr(token_value), user_id=None ) + + # Create new SecretStore with tokens + settings = settings.model_copy(update={"secrets_store": SecretStore( + provider_tokens=tokens + )}) - return settings + return settings \ No newline at end of file diff --git a/openhands/server/settings.py b/openhands/server/settings.py index a16ab3336c..d14a93bcbe 100644 --- a/openhands/server/settings.py +++ b/openhands/server/settings.py @@ -2,6 +2,7 @@ from __future__ import annotations from pydantic import ( BaseModel, + Field, SecretStr, SerializationInfo, field_serializer, @@ -11,7 +12,7 @@ from pydantic.json import pydantic_encoder from openhands.core.config.llm_config import LLMConfig from openhands.core.config.utils import load_app_config -from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore +from openhands.integrations.provider import SecretStore class Settings(BaseModel): @@ -28,11 +29,15 @@ class Settings(BaseModel): llm_api_key: SecretStr | None = None llm_base_url: str | None = None remote_runtime_resource_factor: int | None = None - secrets_store: SecretStore = SecretStore() + secrets_store: SecretStore = Field(default_factory=SecretStore, frozen=True) enable_default_condenser: bool = False enable_sound_notifications: bool = False user_consents_to_analytics: bool | None = None + model_config = { + 'validate_assignment': True, + } + @field_serializer('llm_api_key') def llm_api_key_serializer(self, llm_api_key: SecretStr, info: SerializationInfo): """Custom serializer for the LLM API key. @@ -45,23 +50,6 @@ class Settings(BaseModel): return pydantic_encoder(llm_api_key) if llm_api_key else None - @staticmethod - def _convert_token_value( - token_type: ProviderType, token_value: str | dict - ) -> ProviderToken | None: - """Convert a token value to a ProviderToken object.""" - if isinstance(token_value, dict): - token_str = token_value.get('token') - if not token_str: - return None - return ProviderToken( - token=SecretStr(token_str), - user_id=token_value.get('user_id'), - ) - if isinstance(token_value, str) and token_value: - return ProviderToken(token=SecretStr(token_value), user_id=None) - return None - @model_validator(mode='before') @classmethod def convert_provider_tokens(cls, data: dict | object) -> dict | object: @@ -77,21 +65,7 @@ class Settings(BaseModel): if not isinstance(tokens, dict): return data - converted_tokens = {} - for token_type_str, token_value in tokens.items(): - if not token_value: - continue - - try: - token_type = ProviderType(token_type_str) - except ValueError: - continue - - provider_token = cls._convert_token_value(token_type, token_value) - if provider_token: - converted_tokens[token_type] = provider_token - - data['secrets_store'] = SecretStore(provider_tokens=converted_tokens) + data['secrets_store'] = SecretStore(provider_tokens=tokens) return data @field_serializer('secrets_store') diff --git a/tests/unit/test_provider_immutability.py b/tests/unit/test_provider_immutability.py new file mode 100644 index 0000000000..c8988dbe8e --- /dev/null +++ b/tests/unit/test_provider_immutability.py @@ -0,0 +1,229 @@ +from types import MappingProxyType + +import pytest +from pydantic import SecretStr, ValidationError + +from openhands.integrations.provider import ( + ProviderHandler, + ProviderToken, + ProviderType, + SecretStore, +) +from openhands.server.routes.settings import convert_to_settings +from openhands.server.settings import POSTSettingsModel, Settings + + +def test_provider_token_immutability(): + """Test that ProviderToken is immutable""" + token = ProviderToken(token=SecretStr('test'), user_id='user1') + + # Test direct attribute modification + with pytest.raises(ValidationError): + token.token = SecretStr('new') + + with pytest.raises(ValidationError): + token.user_id = 'new_user' + + # Test that __setattr__ is blocked + with pytest.raises(ValidationError): + setattr(token, 'token', SecretStr('new')) + + # Verify original values are unchanged + assert token.token.get_secret_value() == 'test' + assert token.user_id == 'user1' + + +def test_secret_store_immutability(): + """Test that SecretStore is immutable""" + store = SecretStore( + provider_tokens={ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))} + ) + + # Test direct attribute modification + with pytest.raises(ValidationError): + store.provider_tokens = {} + + # Test dictionary mutation attempts + with pytest.raises((TypeError, AttributeError)): + store.provider_tokens[ProviderType.GITHUB] = ProviderToken( + token=SecretStr('new') + ) + + with pytest.raises((TypeError, AttributeError)): + store.provider_tokens.clear() + + with pytest.raises((TypeError, AttributeError)): + store.provider_tokens.update( + {ProviderType.GITLAB: ProviderToken(token=SecretStr('test'))} + ) + + # Test nested immutability + github_token = store.provider_tokens[ProviderType.GITHUB] + with pytest.raises(ValidationError): + github_token.token = SecretStr('new') + + # Verify original values are unchanged + assert store.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'test' + + +def test_settings_immutability(): + """Test that Settings secrets_store is immutable""" + settings = Settings( + secrets_store=SecretStore( + provider_tokens={ + ProviderType.GITHUB: ProviderToken(token=SecretStr('test')) + } + ) + ) + + # Test direct modification of secrets_store + with pytest.raises(ValidationError): + settings.secrets_store = SecretStore() + + # Test nested modification attempts + with pytest.raises((TypeError, AttributeError)): + settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken( + token=SecretStr('new') + ) + + # Test model_copy creates new instance + new_store = SecretStore( + provider_tokens={ + ProviderType.GITHUB: ProviderToken(token=SecretStr('new_token')) + } + ) + new_settings = settings.model_copy(update={'secrets_store': new_store}) + + # Verify original is unchanged and new has updated values + assert ( + settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token.get_secret_value() + == 'test' + ) + assert ( + new_settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token.get_secret_value() + == 'new_token' + ) + + with pytest.raises(ValidationError): + new_settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token = SecretStr('') + + +def test_post_settings_conversion(): + """Test that POSTSettingsModel correctly converts to Settings""" + # Create POST model with token data + post_data = POSTSettingsModel( + provider_tokens={'github': 'test_token', 'gitlab': 'gitlab_token'} + ) + + # Convert to settings using convert_to_settings function + settings = convert_to_settings(post_data) + + # Verify tokens were converted correctly + assert ( + settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token.get_secret_value() + == 'test_token' + ) + assert ( + settings.secrets_store.provider_tokens[ + ProviderType.GITLAB + ].token.get_secret_value() + == 'gitlab_token' + ) + assert settings.secrets_store.provider_tokens[ProviderType.GITLAB].user_id is None + + # Verify immutability of converted settings + with pytest.raises(ValidationError): + settings.secrets_store = SecretStore() + + +def test_provider_handler_immutability(): + """Test that ProviderHandler maintains token immutability""" + + # Create initial tokens + tokens = MappingProxyType( + {ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))} + ) + + handler = ProviderHandler(provider_tokens=tokens) + + # Try to modify tokens (should raise TypeError due to frozen dict) + with pytest.raises((TypeError, AttributeError)): + handler.provider_tokens[ProviderType.GITHUB] = ProviderToken( + token=SecretStr('new') + ) + + # Try to modify the handler's tokens property + with pytest.raises((ValidationError, TypeError, AttributeError)): + handler.provider_tokens = {} + + # Original token should be unchanged + assert ( + handler.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'test' + ) + + +def test_token_conversion(): + """Test token conversion in SecretStore.create""" + # Test with string token + store1 = Settings( + secrets_store=SecretStore( + provider_tokens={ + ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token')) + } + ) + ) + + assert ( + store1.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token.get_secret_value() + == 'test_token' + ) + assert store1.secrets_store.provider_tokens[ProviderType.GITHUB].user_id is None + + # Test with dict token + store2 = SecretStore( + provider_tokens={'github': {'token': 'test_token', 'user_id': 'user1'}} + ) + assert ( + store2.provider_tokens[ProviderType.GITHUB].token.get_secret_value() + == 'test_token' + ) + assert store2.provider_tokens[ProviderType.GITHUB].user_id == 'user1' + + # Test with ProviderToken + token = ProviderToken(token=SecretStr('test_token'), user_id='user2') + store3 = SecretStore(provider_tokens={ProviderType.GITHUB: token}) + assert ( + store3.provider_tokens[ProviderType.GITHUB].token.get_secret_value() + == 'test_token' + ) + assert store3.provider_tokens[ProviderType.GITHUB].user_id == 'user2' + + store4 = SecretStore( + provider_tokens={ + ProviderType.GITHUB: 123 # Invalid type + } + ) + + assert ProviderType.GITHUB not in store4.provider_tokens + + # Test with empty/None token + store5 = SecretStore(provider_tokens={ProviderType.GITHUB: None}) + assert ProviderType.GITHUB not in store5.provider_tokens + + store6 = SecretStore( + provider_tokens={ + 'invalid_provider': 'test_token' # Invalid provider type + } + ) + + assert len(store6.provider_tokens.keys()) == 0 diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 3c5a01027b..64e99049c6 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -6,7 +6,7 @@ from openhands.core.config.app_config import AppConfig from openhands.core.config.llm_config import LLMConfig from openhands.core.config.sandbox_config import SandboxConfig from openhands.core.config.security_config import SecurityConfig -from openhands.integrations.provider import ProviderToken, ProviderType +from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore from openhands.server.routes.settings import convert_to_settings from openhands.server.settings import POSTSettingsModel, Settings @@ -81,10 +81,14 @@ def test_settings_handles_sensitive_data(): llm_api_key='test-key', llm_base_url='https://test.example.com', remote_runtime_resource_factor=2, - ) - settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken( - token=SecretStr('test-token'), - user_id=None, + secrets_store=SecretStore( + provider_tokens={ + ProviderType.GITHUB: ProviderToken( + token=SecretStr('test-token'), + user_id=None, + ) + } + ), ) assert str(settings.llm_api_key) == '**********'