[Fix]: Make provider tokens immutable (#7317)

This commit is contained in:
Rohit Malhotra 2025-03-18 10:50:13 -04:00 committed by GitHub
parent dde90fc636
commit 3150af1ad7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 345 additions and 83 deletions

View File

@ -1,6 +1,16 @@
from enum import Enum
from __future__ import annotations
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
from enum import Enum
from types import MappingProxyType
from pydantic import (
BaseModel,
Field,
SecretStr,
SerializationInfo,
field_serializer,
model_validator,
)
from pydantic.json import pydantic_encoder
from openhands.integrations.github.github_service import GithubServiceImpl
@ -19,43 +29,42 @@ class ProviderType(Enum):
class ProviderToken(BaseModel):
token: SecretStr | None
user_id: str | None
token: SecretStr | None = Field(default=None)
user_id: str | None = Field(default=None)
model_config = {
'frozen': True, # Makes the entire model immutable
'validate_assignment': True,
}
@classmethod
def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
"""Factory method to create a ProviderToken from various input types"""
if isinstance(token_value, ProviderToken):
return token_value
elif isinstance(token_value, dict):
token_str = token_value.get('token')
user_id = token_value.get('user_id')
return cls(token=SecretStr(token_str), user_id=user_id)
else:
raise ValueError('Unsupport Provider token type')
PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken]
CUSTOM_SECRETS_TYPE = dict[str, SecretStr]
PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken]
CUSTOM_SECRETS_TYPE = MappingProxyType[str, SecretStr]
class SecretStore(BaseModel):
provider_tokens: PROVIDER_TOKEN_TYPE = {}
provider_tokens: PROVIDER_TOKEN_TYPE = Field(
default_factory=lambda: MappingProxyType({})
)
@classmethod
def _convert_token(
cls, token_value: str | ProviderToken | SecretStr
) -> ProviderToken:
if isinstance(token_value, ProviderToken):
return token_value
elif isinstance(token_value, str):
return ProviderToken(token=SecretStr(token_value), user_id=None)
elif isinstance(token_value, SecretStr):
return ProviderToken(token=token_value, user_id=None)
else:
raise ValueError(f'Invalid token type: {type(token_value)}')
def model_post_init(self, __context) -> None:
# Convert any string tokens to ProviderToken objects
converted_tokens = {}
for token_type, token_value in self.provider_tokens.items():
if token_value: # Only convert non-empty tokens
try:
if isinstance(token_type, str):
token_type = ProviderType(token_type)
converted_tokens[token_type] = self._convert_token(token_value)
except ValueError:
# Skip invalid provider types or tokens
continue
self.provider_tokens = converted_tokens
model_config = {
'frozen': True,
'validate_assignment': True,
'arbitrary_types_allowed': True,
}
@field_serializer('provider_tokens')
def provider_tokens_serializer(
@ -82,6 +91,40 @@ class SecretStore(BaseModel):
return tokens
@model_validator(mode='before')
@classmethod
def convert_dict_to_mappingproxy(
cls, data: dict[str, dict[str, dict[str, str]]] | PROVIDER_TOKEN_TYPE
) -> dict[str, MappingProxyType]:
"""Custom deserializer to convert dictionary into MappingProxyType"""
if not isinstance(data, dict):
raise ValueError('SecretStore must be initialized with a dictionary')
new_data = {}
if 'provider_tokens' in data:
tokens = data['provider_tokens']
if isinstance(
tokens, dict
): # Ensure conversion happens only for dict inputs
converted_tokens = {}
for key, value in tokens.items():
try:
provider_type = (
ProviderType(key) if isinstance(key, str) else key
)
converted_tokens[provider_type] = ProviderToken.from_value(
value
)
except ValueError:
# Skip invalid provider types or tokens
continue
# Convert to MappingProxyType
new_data['provider_tokens'] = MappingProxyType(converted_tokens)
return new_data
class ProviderHandler:
def __init__(
@ -94,8 +137,14 @@ class ProviderHandler:
ProviderType.GITLAB: GitLabServiceImpl,
}
self.provider_tokens = provider_tokens
# Create immutable copy through SecretStore
self.external_auth_token = external_auth_token
self._provider_tokens = provider_tokens
@property
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
"""Read-only access to provider tokens."""
return self._provider_tokens
def _get_service(self, provider: ProviderType) -> GitService:
"""Helper method to instantiate a service for a given provider"""

View File

@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
from openhands.integrations.utils import validate_provider_token
from openhands.server.auth import get_provider_tokens, get_user_id
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
@ -26,12 +26,11 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
github_token_is_set = bool(user_id) or bool(get_provider_tokens(request))
settings_with_token_data = GETSettingsModel(
**settings.model_dump(),
**settings.model_dump(exclude='secrets_store'),
github_token_is_set=github_token_is_set,
)
settings_with_token_data.llm_api_key = settings.llm_api_key
del settings_with_token_data.secrets_store
settings_with_token_data.llm_api_key = settings.llm_api_key
return settings_with_token_data
except Exception as e:
logger.warning(f'Invalid token: {e}')
@ -90,9 +89,10 @@ async def store_settings(
existing_settings.user_consents_to_analytics
)
# Handle token updates immutably
if settings.unset_github_token:
settings.secrets_store.provider_tokens = {}
settings.provider_tokens = {}
settings = settings.model_copy(update={"secrets_store": SecretStore()})
else: # Only merge if not unsetting tokens
if settings.provider_tokens:
if existing_settings.secrets_store:
@ -156,16 +156,22 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
# Convert the `llm_api_key` to a `SecretStr` instance
filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key
# Create a new Settings instance without provider tokens
# Create a new Settings instance with empty SecretStore
settings = Settings(**filtered_settings_data)
# Update provider tokens if any are provided
# Create new provider tokens immutably
if settings_with_token_data.provider_tokens:
tokens = {}
for token_type, token_value in settings_with_token_data.provider_tokens.items():
if token_value:
provider = ProviderType(token_type)
settings.secrets_store.provider_tokens[provider] = ProviderToken(
tokens[provider] = ProviderToken(
token=SecretStr(token_value), user_id=None
)
# Create new SecretStore with tokens
settings = settings.model_copy(update={"secrets_store": SecretStore(
provider_tokens=tokens
)})
return settings
return settings

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from pydantic import (
BaseModel,
Field,
SecretStr,
SerializationInfo,
field_serializer,
@ -11,7 +12,7 @@ from pydantic.json import pydantic_encoder
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.utils import load_app_config
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
from openhands.integrations.provider import SecretStore
class Settings(BaseModel):
@ -28,11 +29,15 @@ class Settings(BaseModel):
llm_api_key: SecretStr | None = None
llm_base_url: str | None = None
remote_runtime_resource_factor: int | None = None
secrets_store: SecretStore = SecretStore()
secrets_store: SecretStore = Field(default_factory=SecretStore, frozen=True)
enable_default_condenser: bool = False
enable_sound_notifications: bool = False
user_consents_to_analytics: bool | None = None
model_config = {
'validate_assignment': True,
}
@field_serializer('llm_api_key')
def llm_api_key_serializer(self, llm_api_key: SecretStr, info: SerializationInfo):
"""Custom serializer for the LLM API key.
@ -45,23 +50,6 @@ class Settings(BaseModel):
return pydantic_encoder(llm_api_key) if llm_api_key else None
@staticmethod
def _convert_token_value(
token_type: ProviderType, token_value: str | dict
) -> ProviderToken | None:
"""Convert a token value to a ProviderToken object."""
if isinstance(token_value, dict):
token_str = token_value.get('token')
if not token_str:
return None
return ProviderToken(
token=SecretStr(token_str),
user_id=token_value.get('user_id'),
)
if isinstance(token_value, str) and token_value:
return ProviderToken(token=SecretStr(token_value), user_id=None)
return None
@model_validator(mode='before')
@classmethod
def convert_provider_tokens(cls, data: dict | object) -> dict | object:
@ -77,21 +65,7 @@ class Settings(BaseModel):
if not isinstance(tokens, dict):
return data
converted_tokens = {}
for token_type_str, token_value in tokens.items():
if not token_value:
continue
try:
token_type = ProviderType(token_type_str)
except ValueError:
continue
provider_token = cls._convert_token_value(token_type, token_value)
if provider_token:
converted_tokens[token_type] = provider_token
data['secrets_store'] = SecretStore(provider_tokens=converted_tokens)
data['secrets_store'] = SecretStore(provider_tokens=tokens)
return data
@field_serializer('secrets_store')

View File

@ -0,0 +1,229 @@
from types import MappingProxyType
import pytest
from pydantic import SecretStr, ValidationError
from openhands.integrations.provider import (
ProviderHandler,
ProviderToken,
ProviderType,
SecretStore,
)
from openhands.server.routes.settings import convert_to_settings
from openhands.server.settings import POSTSettingsModel, Settings
def test_provider_token_immutability():
"""Test that ProviderToken is immutable"""
token = ProviderToken(token=SecretStr('test'), user_id='user1')
# Test direct attribute modification
with pytest.raises(ValidationError):
token.token = SecretStr('new')
with pytest.raises(ValidationError):
token.user_id = 'new_user'
# Test that __setattr__ is blocked
with pytest.raises(ValidationError):
setattr(token, 'token', SecretStr('new'))
# Verify original values are unchanged
assert token.token.get_secret_value() == 'test'
assert token.user_id == 'user1'
def test_secret_store_immutability():
"""Test that SecretStore is immutable"""
store = SecretStore(
provider_tokens={ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))}
)
# Test direct attribute modification
with pytest.raises(ValidationError):
store.provider_tokens = {}
# Test dictionary mutation attempts
with pytest.raises((TypeError, AttributeError)):
store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
token=SecretStr('new')
)
with pytest.raises((TypeError, AttributeError)):
store.provider_tokens.clear()
with pytest.raises((TypeError, AttributeError)):
store.provider_tokens.update(
{ProviderType.GITLAB: ProviderToken(token=SecretStr('test'))}
)
# Test nested immutability
github_token = store.provider_tokens[ProviderType.GITHUB]
with pytest.raises(ValidationError):
github_token.token = SecretStr('new')
# Verify original values are unchanged
assert store.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'test'
def test_settings_immutability():
"""Test that Settings secrets_store is immutable"""
settings = Settings(
secrets_store=SecretStore(
provider_tokens={
ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))
}
)
)
# Test direct modification of secrets_store
with pytest.raises(ValidationError):
settings.secrets_store = SecretStore()
# Test nested modification attempts
with pytest.raises((TypeError, AttributeError)):
settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
token=SecretStr('new')
)
# Test model_copy creates new instance
new_store = SecretStore(
provider_tokens={
ProviderType.GITHUB: ProviderToken(token=SecretStr('new_token'))
}
)
new_settings = settings.model_copy(update={'secrets_store': new_store})
# Verify original is unchanged and new has updated values
assert (
settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'test'
)
assert (
new_settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'new_token'
)
with pytest.raises(ValidationError):
new_settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token = SecretStr('')
def test_post_settings_conversion():
"""Test that POSTSettingsModel correctly converts to Settings"""
# Create POST model with token data
post_data = POSTSettingsModel(
provider_tokens={'github': 'test_token', 'gitlab': 'gitlab_token'}
)
# Convert to settings using convert_to_settings function
settings = convert_to_settings(post_data)
# Verify tokens were converted correctly
assert (
settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'test_token'
)
assert (
settings.secrets_store.provider_tokens[
ProviderType.GITLAB
].token.get_secret_value()
== 'gitlab_token'
)
assert settings.secrets_store.provider_tokens[ProviderType.GITLAB].user_id is None
# Verify immutability of converted settings
with pytest.raises(ValidationError):
settings.secrets_store = SecretStore()
def test_provider_handler_immutability():
"""Test that ProviderHandler maintains token immutability"""
# Create initial tokens
tokens = MappingProxyType(
{ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))}
)
handler = ProviderHandler(provider_tokens=tokens)
# Try to modify tokens (should raise TypeError due to frozen dict)
with pytest.raises((TypeError, AttributeError)):
handler.provider_tokens[ProviderType.GITHUB] = ProviderToken(
token=SecretStr('new')
)
# Try to modify the handler's tokens property
with pytest.raises((ValidationError, TypeError, AttributeError)):
handler.provider_tokens = {}
# Original token should be unchanged
assert (
handler.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'test'
)
def test_token_conversion():
"""Test token conversion in SecretStore.create"""
# Test with string token
store1 = Settings(
secrets_store=SecretStore(
provider_tokens={
ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token'))
}
)
)
assert (
store1.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'test_token'
)
assert store1.secrets_store.provider_tokens[ProviderType.GITHUB].user_id is None
# Test with dict token
store2 = SecretStore(
provider_tokens={'github': {'token': 'test_token', 'user_id': 'user1'}}
)
assert (
store2.provider_tokens[ProviderType.GITHUB].token.get_secret_value()
== 'test_token'
)
assert store2.provider_tokens[ProviderType.GITHUB].user_id == 'user1'
# Test with ProviderToken
token = ProviderToken(token=SecretStr('test_token'), user_id='user2')
store3 = SecretStore(provider_tokens={ProviderType.GITHUB: token})
assert (
store3.provider_tokens[ProviderType.GITHUB].token.get_secret_value()
== 'test_token'
)
assert store3.provider_tokens[ProviderType.GITHUB].user_id == 'user2'
store4 = SecretStore(
provider_tokens={
ProviderType.GITHUB: 123 # Invalid type
}
)
assert ProviderType.GITHUB not in store4.provider_tokens
# Test with empty/None token
store5 = SecretStore(provider_tokens={ProviderType.GITHUB: None})
assert ProviderType.GITHUB not in store5.provider_tokens
store6 = SecretStore(
provider_tokens={
'invalid_provider': 'test_token' # Invalid provider type
}
)
assert len(store6.provider_tokens.keys()) == 0

View File

@ -6,7 +6,7 @@ from openhands.core.config.app_config import AppConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
from openhands.server.routes.settings import convert_to_settings
from openhands.server.settings import POSTSettingsModel, Settings
@ -81,10 +81,14 @@ def test_settings_handles_sensitive_data():
llm_api_key='test-key',
llm_base_url='https://test.example.com',
remote_runtime_resource_factor=2,
)
settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
token=SecretStr('test-token'),
user_id=None,
secrets_store=SecretStore(
provider_tokens={
ProviderType.GITHUB: ProviderToken(
token=SecretStr('test-token'),
user_id=None,
)
}
),
)
assert str(settings.llm_api_key) == '**********'