mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Bug fix]: Standardize SecretStr use (#6660)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
707cb07f4f
commit
4a5891cbea
@ -24,3 +24,6 @@ inline-quotes = "single"
|
||||
|
||||
[format]
|
||||
quote-style = "single"
|
||||
|
||||
[lint.flake8-bugbear]
|
||||
extend-immutable-calls = ["Depends", "fastapi.Depends", "fastapi.params.Depends"]
|
||||
|
||||
@ -12,6 +12,7 @@ from pathlib import Path
|
||||
from typing import Callable
|
||||
from zipfile import ZipFile
|
||||
|
||||
from pydantic import SecretStr
|
||||
from requests.exceptions import ConnectionError
|
||||
|
||||
from openhands.core.config import AppConfig, SandboxConfig
|
||||
@ -234,12 +235,12 @@ class Runtime(FileEditRuntimeMixin):
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
def clone_repo(self, github_token: str, selected_repository: str) -> str:
|
||||
def clone_repo(self, github_token: SecretStr, selected_repository: str) -> str:
|
||||
if not github_token or not selected_repository:
|
||||
raise ValueError(
|
||||
'github_token and selected_repository must be provided to clone a repository'
|
||||
)
|
||||
url = f'https://{github_token}@github.com/{selected_repository}.git'
|
||||
url = f'https://{github_token.get_secret_value()}@github.com/{selected_repository}.git'
|
||||
dir_name = selected_repository.split('/')[1]
|
||||
# add random branch name to avoid conflicts
|
||||
random_str = ''.join(
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
def get_github_token(request: Request) -> str | None:
|
||||
def get_github_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'github_token', None)
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ class ServerConfig(ServerConfigInterface):
|
||||
)
|
||||
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
|
||||
github_service_class: str = 'openhands.server.services.github_service.GitHubService'
|
||||
github_service_class: str = 'openhands.services.github.github_service.GitHubService'
|
||||
|
||||
def verify_config(self):
|
||||
if self.config_cls:
|
||||
|
||||
@ -196,7 +196,7 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
|
||||
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
|
||||
if getattr(request.state, 'github_token', None) is None:
|
||||
if settings and settings.github_token:
|
||||
request.state.github_token = settings.github_token.get_secret_value()
|
||||
request.state.github_token = settings.github_token
|
||||
else:
|
||||
request.state.github_token = None
|
||||
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser
|
||||
from openhands.server.services.github_service import GitHubService
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import GhAuthenticationError, GHUnknownException
|
||||
from openhands.services.github.github_service import (
|
||||
GhAuthenticationError,
|
||||
GHUnknownException,
|
||||
GitHubService,
|
||||
)
|
||||
from openhands.services.github.github_types import GitHubRepository, GitHubUser
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
app = APIRouter(prefix='/api/github')
|
||||
@ -20,8 +24,9 @@ async def get_github_repositories(
|
||||
sort: str = 'pushed',
|
||||
installation_id: int | None = None,
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
):
|
||||
client = GithubServiceImpl(github_user_id)
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
@ -44,8 +49,9 @@ async def get_github_repositories(
|
||||
@app.get('/user')
|
||||
async def get_github_user(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
):
|
||||
client = GithubServiceImpl(github_user_id)
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
try:
|
||||
user: GitHubUser = await client.get_user()
|
||||
return user
|
||||
@ -66,8 +72,9 @@ async def get_github_user(
|
||||
@app.get('/installations')
|
||||
async def get_github_installation_ids(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
):
|
||||
client = GithubServiceImpl(github_user_id)
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
try:
|
||||
installations_ids: list[int] = await client.get_installation_ids()
|
||||
return installations_ids
|
||||
@ -92,8 +99,9 @@ async def search_github_repositories(
|
||||
sort: str = 'stars',
|
||||
order: str = 'desc',
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
):
|
||||
client = GithubServiceImpl(github_user_id)
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.search_repositories(
|
||||
query, per_page, sort, order
|
||||
|
||||
@ -4,13 +4,13 @@ from typing import Callable
|
||||
|
||||
from fastapi import APIRouter, Body, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.server.routes.github import GithubServiceImpl
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import (
|
||||
@ -44,7 +44,7 @@ class InitSessionRequest(BaseModel):
|
||||
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
token: str | None,
|
||||
token: SecretStr | None,
|
||||
selected_repository: str | None,
|
||||
initial_user_msg: str | None,
|
||||
image_urls: list[str] | None,
|
||||
@ -72,7 +72,7 @@ async def _create_new_conversation(
|
||||
logger.warn('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['github_token'] = token or ''
|
||||
session_init_args['github_token'] = token or SecretStr('')
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
@ -131,7 +131,9 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
user_id = get_user_id(request)
|
||||
github_token = GithubServiceImpl.get_gh_token(request)
|
||||
github_service = GithubServiceImpl(user_id=user_id, token=get_github_token(request))
|
||||
github_token = await github_service.get_latest_token()
|
||||
|
||||
selected_repository = data.selected_repository
|
||||
initial_user_msg = data.initial_user_msg
|
||||
image_urls = data.image_urls or []
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from fastapi import APIRouter, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.services.github_service import GitHubService
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
|
||||
from openhands.server.shared import SettingsStoreImpl, config
|
||||
from openhands.services.github.github_service import GitHubService
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
@ -22,7 +23,7 @@ async def load_settings(request: Request) -> GETSettingsModel | None:
|
||||
content={'error': 'Settings not found'},
|
||||
)
|
||||
|
||||
token_is_set = bool(user_id) or bool(request.state.github_token)
|
||||
token_is_set = bool(user_id) or bool(get_github_token(request))
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(),
|
||||
github_token_is_set=token_is_set,
|
||||
@ -50,8 +51,8 @@ async def store_settings(
|
||||
try:
|
||||
# We check if the token is valid by getting the user
|
||||
# If the token is invalid, this will raise an exception
|
||||
github = GitHubService(None)
|
||||
await github.validate_user(settings.github_token)
|
||||
github = GitHubService(user_id=None, token=SecretStr(settings.github_token))
|
||||
await github.get_user()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid GitHub token: {e}')
|
||||
|
||||
@ -2,6 +2,8 @@ import asyncio
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
@ -69,7 +71,7 @@ class AgentSession:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
github_token: str | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
selected_repository: str | None = None,
|
||||
initial_message: MessageAction | None = None,
|
||||
):
|
||||
@ -113,7 +115,7 @@ class AgentSession:
|
||||
if github_token:
|
||||
self.event_stream.set_secrets(
|
||||
{
|
||||
'github_token': github_token,
|
||||
'github_token': github_token.get_secret_value(),
|
||||
}
|
||||
)
|
||||
if initial_message:
|
||||
@ -177,7 +179,7 @@ class AgentSession:
|
||||
runtime_name: str,
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
github_token: str | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
selected_repository: str | None = None,
|
||||
):
|
||||
"""Creates a runtime instance
|
||||
@ -195,7 +197,7 @@ class AgentSession:
|
||||
runtime_cls = get_runtime_cls(runtime_name)
|
||||
env_vars = (
|
||||
{
|
||||
'GITHUB_TOKEN': github_token,
|
||||
'GITHUB_TOKEN': github_token.get_secret_value(),
|
||||
}
|
||||
if github_token
|
||||
else None
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pydantic import Field
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
@ -8,5 +8,5 @@ class ConversationInitData(Settings):
|
||||
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
||||
"""
|
||||
|
||||
github_token: str | None = Field(default=None)
|
||||
github_token: SecretStr | None = Field(default=None)
|
||||
selected_repository: str | None = Field(default=None)
|
||||
|
||||
@ -42,15 +42,3 @@ class LLMAuthenticationError(ValueError):
|
||||
"""Raised when there is an issue with LLM authentication."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GhAuthenticationError(ValueError):
|
||||
"""Raised when there is an issue with LLM authentication."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GHUnknownException(ValueError):
|
||||
"""Raised when there is an issue with LLM authentication."""
|
||||
|
||||
pass
|
||||
|
||||
@ -1,41 +1,45 @@
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.server.auth import get_github_token
|
||||
from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser
|
||||
from openhands.server.shared import SettingsStoreImpl, config, server_config
|
||||
from openhands.server.types import AppMode, GhAuthenticationError, GHUnknownException
|
||||
from openhands.services.github.github_types import (
|
||||
GhAuthenticationError,
|
||||
GHUnknownException,
|
||||
GitHubRepository,
|
||||
GitHubUser,
|
||||
)
|
||||
|
||||
|
||||
class GitHubService:
|
||||
BASE_URL = 'https://api.github.com'
|
||||
token: str = ''
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(self, user_id: str | None):
|
||||
def __init__(self, user_id: str | None = None, token: SecretStr | None = None):
|
||||
self.user_id = user_id
|
||||
|
||||
async def _get_github_headers(self):
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
"""
|
||||
Retrieve the GH Token from settings store to construct the headers
|
||||
"""
|
||||
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, self.user_id)
|
||||
settings = await settings_store.load()
|
||||
if settings and settings.github_token:
|
||||
self.token = settings.github_token.get_secret_value()
|
||||
if self.user_id and not self.token:
|
||||
self.token = await self.get_latest_token()
|
||||
|
||||
return {
|
||||
'Authorization': f'Bearer {self.token}',
|
||||
'Authorization': f'Bearer {self.token.get_secret_value()}',
|
||||
'Accept': 'application/vnd.github.v3+json',
|
||||
}
|
||||
|
||||
def _has_token_expired(self, status_code: int):
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
async def _get_latest_token(self):
|
||||
pass
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
@ -44,10 +48,8 @@ class GitHubService:
|
||||
async with httpx.AsyncClient() as client:
|
||||
github_headers = await self._get_github_headers()
|
||||
response = await client.get(url, headers=github_headers, params=params)
|
||||
if server_config.app_mode == AppMode.SAAS and self._has_token_expired(
|
||||
response.status_code
|
||||
):
|
||||
await self._get_latest_token()
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
github_headers = await self._get_github_headers()
|
||||
response = await client.get(
|
||||
url, headers=github_headers, params=params
|
||||
@ -60,8 +62,10 @@ class GitHubService:
|
||||
|
||||
return response.json(), headers
|
||||
|
||||
except httpx.HTTPStatusError:
|
||||
raise GhAuthenticationError('Invalid Github token')
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise GhAuthenticationError('Invalid Github token')
|
||||
raise GHUnknownException('Unknown error')
|
||||
|
||||
except httpx.HTTPError:
|
||||
raise GHUnknownException('Unknown error')
|
||||
@ -79,10 +83,6 @@ class GitHubService:
|
||||
email=response.get('email'),
|
||||
)
|
||||
|
||||
async def validate_user(self, token) -> GitHubUser:
|
||||
self.token = token
|
||||
return await self.get_user()
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
) -> list[GitHubRepository]:
|
||||
@ -133,7 +133,3 @@ class GitHubService:
|
||||
]
|
||||
|
||||
return repos
|
||||
|
||||
@classmethod
|
||||
def get_gh_token(cls, request: Request) -> str | None:
|
||||
return get_github_token(request)
|
||||
@ -15,3 +15,15 @@ class GitHubRepository(BaseModel):
|
||||
full_name: str
|
||||
stargazers_count: int | None = None
|
||||
link_header: str | None = None
|
||||
|
||||
|
||||
class GhAuthenticationError(ValueError):
|
||||
"""Raised when there is an issue with GitHub authentication."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GHUnknownException(ValueError):
|
||||
"""Raised when there is an issue with GitHub communcation."""
|
||||
|
||||
pass
|
||||
81
tests/unit/test_github_service.py
Normal file
81
tests/unit/test_github_service.py
Normal file
@ -0,0 +1,81 @@
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.services.github.github_service import GitHubService
|
||||
from openhands.services.github.github_types import GhAuthenticationError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_token_handling():
|
||||
# Test initialization with SecretStr token
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
assert service.token == token
|
||||
assert service.token.get_secret_value() == 'test-token'
|
||||
|
||||
# Test headers contain the token correctly
|
||||
headers = await service._get_github_headers()
|
||||
assert headers['Authorization'] == 'Bearer test-token'
|
||||
assert headers['Accept'] == 'application/vnd.github.v3+json'
|
||||
|
||||
# Test initialization without token
|
||||
service = GitHubService(user_id='test-user')
|
||||
assert service.token == SecretStr('')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_token_refresh():
|
||||
# Test that token refresh is only attempted when refresh=True
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
assert not service.refresh
|
||||
|
||||
# Test token expiry detection
|
||||
assert service._has_token_expired(401)
|
||||
assert not service._has_token_expired(200)
|
||||
assert not service._has_token_expired(404)
|
||||
|
||||
# Test get_latest_token returns a copy of the current token
|
||||
latest_token = await service.get_latest_token()
|
||||
assert isinstance(latest_token, SecretStr)
|
||||
assert latest_token.get_secret_value() == 'test-token' # Compare with known value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_fetch_data():
|
||||
# Mock httpx.AsyncClient for testing API calls
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {'login': 'test-user'}
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
service = GitHubService(user_id=None, token=SecretStr('test-token'))
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
|
||||
# Verify the request was made with correct headers
|
||||
mock_client.get.assert_called_once()
|
||||
call_args = mock_client.get.call_args
|
||||
headers = call_args[1]['headers']
|
||||
assert headers['Authorization'] == 'Bearer test-token'
|
||||
|
||||
# Test error handling with 401 status code
|
||||
mock_response.status_code = 401
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
message='401 Unauthorized', request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
# Reset the mock to test error handling
|
||||
mock_client.get.reset_mock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GhAuthenticationError):
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
Loading…
x
Reference in New Issue
Block a user