[Bug fix]: Standardize SecretStr use (#6660)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-02-10 08:03:56 -05:00 committed by GitHub
parent 707cb07f4f
commit 4a5891cbea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 166 additions and 71 deletions

View File

@ -24,3 +24,6 @@ inline-quotes = "single"
[format]
quote-style = "single"
[lint.flake8-bugbear]
extend-immutable-calls = ["Depends", "fastapi.Depends", "fastapi.params.Depends"]

View File

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Callable
from zipfile import ZipFile
from pydantic import SecretStr
from requests.exceptions import ConnectionError
from openhands.core.config import AppConfig, SandboxConfig
@ -234,12 +235,12 @@ class Runtime(FileEditRuntimeMixin):
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
def clone_repo(self, github_token: str, selected_repository: str) -> str:
def clone_repo(self, github_token: SecretStr, selected_repository: str) -> str:
if not github_token or not selected_repository:
raise ValueError(
'github_token and selected_repository must be provided to clone a repository'
)
url = f'https://{github_token}@github.com/{selected_repository}.git'
url = f'https://{github_token.get_secret_value()}@github.com/{selected_repository}.git'
dir_name = selected_repository.split('/')[1]
# add random branch name to avoid conflicts
random_str = ''.join(

View File

@ -1,7 +1,8 @@
from fastapi import Request
from pydantic import SecretStr
def get_github_token(request: Request) -> str | None:
def get_github_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'github_token', None)

View File

@ -18,7 +18,7 @@ class ServerConfig(ServerConfigInterface):
)
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
github_service_class: str = 'openhands.server.services.github_service.GitHubService'
github_service_class: str = 'openhands.services.github.github_service.GitHubService'
def verify_config(self):
if self.config_cls:

View File

@ -196,7 +196,7 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
if getattr(request.state, 'github_token', None) is None:
if settings and settings.github_token:
request.state.github_token = settings.github_token.get_secret_value()
request.state.github_token = settings.github_token
else:
request.state.github_token = None

View File

@ -1,11 +1,15 @@
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.server.auth import get_user_id
from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser
from openhands.server.services.github_service import GitHubService
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.shared import server_config
from openhands.server.types import GhAuthenticationError, GHUnknownException
from openhands.services.github.github_service import (
GhAuthenticationError,
GHUnknownException,
GitHubService,
)
from openhands.services.github.github_types import GitHubRepository, GitHubUser
from openhands.utils.import_utils import get_impl
app = APIRouter(prefix='/api/github')
@ -20,8 +24,9 @@ async def get_github_repositories(
sort: str = 'pushed',
installation_id: int | None = None,
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
repos: list[GitHubRepository] = await client.get_repositories(
page, per_page, sort, installation_id
@ -44,8 +49,9 @@ async def get_github_repositories(
@app.get('/user')
async def get_github_user(
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
user: GitHubUser = await client.get_user()
return user
@ -66,8 +72,9 @@ async def get_github_user(
@app.get('/installations')
async def get_github_installation_ids(
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
installations_ids: list[int] = await client.get_installation_ids()
return installations_ids
@ -92,8 +99,9 @@ async def search_github_repositories(
sort: str = 'stars',
order: str = 'desc',
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
repos: list[GitHubRepository] = await client.search_repositories(
query, per_page, sort, order

View File

@ -4,13 +4,13 @@ from typing import Callable
from fastapi import APIRouter, Body, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic import BaseModel, SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.events.stream import EventStreamSubscriber
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_user_id
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.routes.github import GithubServiceImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import (
@ -44,7 +44,7 @@ class InitSessionRequest(BaseModel):
async def _create_new_conversation(
user_id: str | None,
token: str | None,
token: SecretStr | None,
selected_repository: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
@ -72,7 +72,7 @@ async def _create_new_conversation(
logger.warn('Settings not present, not starting conversation')
raise MissingSettingsError('Settings not found')
session_init_args['github_token'] = token or ''
session_init_args['github_token'] = token or SecretStr('')
session_init_args['selected_repository'] = selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
@ -131,7 +131,9 @@ async def new_conversation(request: Request, data: InitSessionRequest):
"""
logger.info('Initializing new conversation')
user_id = get_user_id(request)
github_token = GithubServiceImpl.get_gh_token(request)
github_service = GithubServiceImpl(user_id=user_id, token=get_github_token(request))
github_token = await github_service.get_latest_token()
selected_repository = data.selected_repository
initial_user_msg = data.initial_user_msg
image_urls = data.image_urls or []

View File

@ -1,11 +1,12 @@
from fastapi import APIRouter, Request, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import get_user_id
from openhands.server.services.github_service import GitHubService
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
from openhands.server.shared import SettingsStoreImpl, config
from openhands.services.github.github_service import GitHubService
app = APIRouter(prefix='/api')
@ -22,7 +23,7 @@ async def load_settings(request: Request) -> GETSettingsModel | None:
content={'error': 'Settings not found'},
)
token_is_set = bool(user_id) or bool(request.state.github_token)
token_is_set = bool(user_id) or bool(get_github_token(request))
settings_with_token_data = GETSettingsModel(
**settings.model_dump(),
github_token_is_set=token_is_set,
@ -50,8 +51,8 @@ async def store_settings(
try:
# We check if the token is valid by getting the user
# If the token is invalid, this will raise an exception
github = GitHubService(None)
await github.validate_user(settings.github_token)
github = GitHubService(user_id=None, token=SecretStr(settings.github_token))
await github.get_user()
except Exception as e:
logger.warning(f'Invalid GitHub token: {e}')

View File

@ -2,6 +2,8 @@ import asyncio
import time
from typing import Callable, Optional
from pydantic import SecretStr
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
@ -69,7 +71,7 @@ class AgentSession:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
github_token: SecretStr | None = None,
selected_repository: str | None = None,
initial_message: MessageAction | None = None,
):
@ -113,7 +115,7 @@ class AgentSession:
if github_token:
self.event_stream.set_secrets(
{
'github_token': github_token,
'github_token': github_token.get_secret_value(),
}
)
if initial_message:
@ -177,7 +179,7 @@ class AgentSession:
runtime_name: str,
config: AppConfig,
agent: Agent,
github_token: str | None = None,
github_token: SecretStr | None = None,
selected_repository: str | None = None,
):
"""Creates a runtime instance
@ -195,7 +197,7 @@ class AgentSession:
runtime_cls = get_runtime_cls(runtime_name)
env_vars = (
{
'GITHUB_TOKEN': github_token,
'GITHUB_TOKEN': github_token.get_secret_value(),
}
if github_token
else None

View File

@ -1,4 +1,4 @@
from pydantic import Field
from pydantic import Field, SecretStr
from openhands.server.settings import Settings
@ -8,5 +8,5 @@ class ConversationInitData(Settings):
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
github_token: str | None = Field(default=None)
github_token: SecretStr | None = Field(default=None)
selected_repository: str | None = Field(default=None)

View File

@ -42,15 +42,3 @@ class LLMAuthenticationError(ValueError):
"""Raised when there is an issue with LLM authentication."""
pass
class GhAuthenticationError(ValueError):
"""Raised when there is an issue with LLM authentication."""
pass
class GHUnknownException(ValueError):
"""Raised when there is an issue with LLM authentication."""
pass

View File

@ -1,41 +1,45 @@
from typing import Any
import httpx
from fastapi import Request
from pydantic import SecretStr
from openhands.server.auth import get_github_token
from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser
from openhands.server.shared import SettingsStoreImpl, config, server_config
from openhands.server.types import AppMode, GhAuthenticationError, GHUnknownException
from openhands.services.github.github_types import (
GhAuthenticationError,
GHUnknownException,
GitHubRepository,
GitHubUser,
)
class GitHubService:
BASE_URL = 'https://api.github.com'
token: str = ''
token: SecretStr = SecretStr('')
refresh = False
def __init__(self, user_id: str | None):
def __init__(self, user_id: str | None = None, token: SecretStr | None = None):
self.user_id = user_id
async def _get_github_headers(self):
if token:
self.token = token
async def _get_github_headers(self) -> dict:
"""
Retrieve the GH Token from settings store to construct the headers
"""
settings_store = await SettingsStoreImpl.get_instance(config, self.user_id)
settings = await settings_store.load()
if settings and settings.github_token:
self.token = settings.github_token.get_secret_value()
if self.user_id and not self.token:
self.token = await self.get_latest_token()
return {
'Authorization': f'Bearer {self.token}',
'Authorization': f'Bearer {self.token.get_secret_value()}',
'Accept': 'application/vnd.github.v3+json',
}
def _has_token_expired(self, status_code: int):
def _has_token_expired(self, status_code: int) -> bool:
return status_code == 401
async def _get_latest_token(self):
pass
async def get_latest_token(self) -> SecretStr:
return self.token
async def _fetch_data(
self, url: str, params: dict | None = None
@ -44,10 +48,8 @@ class GitHubService:
async with httpx.AsyncClient() as client:
github_headers = await self._get_github_headers()
response = await client.get(url, headers=github_headers, params=params)
if server_config.app_mode == AppMode.SAAS and self._has_token_expired(
response.status_code
):
await self._get_latest_token()
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
github_headers = await self._get_github_headers()
response = await client.get(
url, headers=github_headers, params=params
@ -60,8 +62,10 @@ class GitHubService:
return response.json(), headers
except httpx.HTTPStatusError:
raise GhAuthenticationError('Invalid Github token')
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise GhAuthenticationError('Invalid Github token')
raise GHUnknownException('Unknown error')
except httpx.HTTPError:
raise GHUnknownException('Unknown error')
@ -79,10 +83,6 @@ class GitHubService:
email=response.get('email'),
)
async def validate_user(self, token) -> GitHubUser:
self.token = token
return await self.get_user()
async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[GitHubRepository]:
@ -133,7 +133,3 @@ class GitHubService:
]
return repos
@classmethod
def get_gh_token(cls, request: Request) -> str | None:
return get_github_token(request)

View File

@ -15,3 +15,15 @@ class GitHubRepository(BaseModel):
full_name: str
stargazers_count: int | None = None
link_header: str | None = None
class GhAuthenticationError(ValueError):
"""Raised when there is an issue with GitHub authentication."""
pass
class GHUnknownException(ValueError):
"""Raised when there is an issue with GitHub communcation."""
pass

View File

@ -0,0 +1,81 @@
from unittest.mock import AsyncMock, Mock, patch
import httpx
import pytest
from pydantic import SecretStr
from openhands.services.github.github_service import GitHubService
from openhands.services.github.github_types import GhAuthenticationError
@pytest.mark.asyncio
async def test_github_service_token_handling():
# Test initialization with SecretStr token
token = SecretStr('test-token')
service = GitHubService(user_id=None, token=token)
assert service.token == token
assert service.token.get_secret_value() == 'test-token'
# Test headers contain the token correctly
headers = await service._get_github_headers()
assert headers['Authorization'] == 'Bearer test-token'
assert headers['Accept'] == 'application/vnd.github.v3+json'
# Test initialization without token
service = GitHubService(user_id='test-user')
assert service.token == SecretStr('')
@pytest.mark.asyncio
async def test_github_service_token_refresh():
# Test that token refresh is only attempted when refresh=True
token = SecretStr('test-token')
service = GitHubService(user_id=None, token=token)
assert not service.refresh
# Test token expiry detection
assert service._has_token_expired(401)
assert not service._has_token_expired(200)
assert not service._has_token_expired(404)
# Test get_latest_token returns a copy of the current token
latest_token = await service.get_latest_token()
assert isinstance(latest_token, SecretStr)
assert latest_token.get_secret_value() == 'test-token' # Compare with known value
@pytest.mark.asyncio
async def test_github_service_fetch_data():
# Mock httpx.AsyncClient for testing API calls
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json.return_value = {'login': 'test-user'}
mock_response.raise_for_status = Mock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
with patch('httpx.AsyncClient', return_value=mock_client):
service = GitHubService(user_id=None, token=SecretStr('test-token'))
_ = await service._fetch_data('https://api.github.com/user')
# Verify the request was made with correct headers
mock_client.get.assert_called_once()
call_args = mock_client.get.call_args
headers = call_args[1]['headers']
assert headers['Authorization'] == 'Bearer test-token'
# Test error handling with 401 status code
mock_response.status_code = 401
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
message='401 Unauthorized', request=Mock(), response=mock_response
)
# Reset the mock to test error handling
mock_client.get.reset_mock()
mock_client.get.return_value = mock_response
with pytest.raises(GhAuthenticationError):
_ = await service._fetch_data('https://api.github.com/user')