[Feat]: Support Gitlab PAT (#7064)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-03-13 14:44:49 -04:00 committed by GitHub
parent 300bfbdf2d
commit 78d185b102
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 921 additions and 353 deletions

View File

@ -721,7 +721,7 @@ describe("Settings Screen", () => {
expect(saveSettingsSpy).toHaveBeenCalledWith(
expect.objectContaining({
llm_api_key: "", // empty because it's not set previously
github_token: undefined,
provider_tokens: undefined,
language: "no",
}),
);
@ -758,7 +758,7 @@ describe("Settings Screen", () => {
expect(saveSettingsSpy).toHaveBeenCalledWith(
expect.objectContaining({
github_token: undefined,
provider_tokens: undefined,
llm_api_key: "", // empty because it's not set previously
llm_model: "openai/gpt-4o",
}),
@ -801,7 +801,7 @@ describe("Settings Screen", () => {
expect(saveSettingsSpy).toHaveBeenCalledWith({
...mockCopy,
github_token: undefined, // not set
provider_tokens: undefined, // not set
llm_api_key: "", // reset as well
});
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();

View File

@ -8,7 +8,7 @@
* - Please do NOT serve this file on production.
*/
const PACKAGE_VERSION = '2.7.0'
const PACKAGE_VERSION = '2.7.3'
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
const activeClientIds = new Set()

View File

@ -17,7 +17,7 @@ const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
? ""
: settings.LLM_API_KEY?.trim() || undefined,
remote_runtime_resource_factor: settings.REMOTE_RUNTIME_RESOURCE_FACTOR,
github_token: settings.github_token,
provider_tokens: settings.provider_tokens,
unset_github_token: settings.unset_github_token,
enable_default_condenser: settings.ENABLE_DEFAULT_CONDENSER,
enable_sound_notifications: settings.ENABLE_SOUND_NOTIFICATIONS,

View File

@ -21,6 +21,7 @@ const getSettingsQueryFn = async () => {
ENABLE_DEFAULT_CONDENSER: apiSettings.enable_default_condenser,
ENABLE_SOUND_NOTIFICATIONS: apiSettings.enable_sound_notifications,
USER_CONSENTS_TO_ANALYTICS: apiSettings.user_consents_to_analytics,
PROVIDER_TOKENS: apiSettings.provider_tokens,
IS_NEW_USER: false,
};
};

View File

@ -22,6 +22,7 @@ export const MOCK_DEFAULT_USER_SETTINGS: ApiSettings | PostApiSettings = {
enable_default_condenser: DEFAULT_SETTINGS.ENABLE_DEFAULT_CONDENSER,
enable_sound_notifications: DEFAULT_SETTINGS.ENABLE_SOUND_NOTIFICATIONS,
user_consents_to_analytics: DEFAULT_SETTINGS.USER_CONSENTS_TO_ANALYTICS,
provider_tokens: DEFAULT_SETTINGS.PROVIDER_TOKENS,
};
const MOCK_USER_PREFERENCES: {
@ -190,8 +191,8 @@ export const handlers = [
if (!settings) return HttpResponse.json(null, { status: 404 });
// @ts-expect-error - mock types
if (settings.github_token) settings.github_token_is_set = true;
if (Object.keys(settings.provider_tokens).length > 0)
settings.github_token_is_set = true;
return HttpResponse.json(settings);
}),
@ -203,7 +204,7 @@ export const handlers = [
if (typeof body === "object") {
newSettings = { ...body };
if (newSettings.unset_github_token) {
newSettings.github_token = undefined;
newSettings.provider_tokens = { github: "", gitlab: "" };
newSettings.github_token_is_set = false;
delete newSettings.unset_github_token;
}

View File

@ -61,7 +61,10 @@ function AccountSettings() {
if (isSuccess) {
return (
isCustomModel(resources.models, settings.LLM_MODEL) ||
hasAdvancedSettingsSet(settings)
hasAdvancedSettingsSet({
...settings,
PROVIDER_TOKENS: settings.PROVIDER_TOKENS || {},
})
);
}
@ -128,37 +131,42 @@ function AccountSettings() {
: llmBaseUrl;
const finalLlmApiKey = shouldHandleSpecialSaasCase ? undefined : llmApiKey;
saveSettings(
{
github_token:
formData.get("github-token-input")?.toString() || undefined,
LANGUAGE: languageValue,
user_consents_to_analytics: userConsentsToAnalytics,
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
LLM_MODEL: finalLlmModel,
LLM_BASE_URL: finalLlmBaseUrl,
LLM_API_KEY: finalLlmApiKey,
AGENT: formData.get("agent-input")?.toString(),
SECURITY_ANALYZER:
formData.get("security-analyzer-input")?.toString() || "",
REMOTE_RUNTIME_RESOURCE_FACTOR:
remoteRuntimeResourceFactor ||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
CONFIRMATION_MODE: confirmationModeIsEnabled,
const githubToken = formData.get("github-token-input")?.toString();
const newSettings = {
github_token: githubToken,
provider_tokens: githubToken
? {
github: githubToken,
gitlab: "",
}
: undefined,
LANGUAGE: languageValue,
user_consents_to_analytics: userConsentsToAnalytics,
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
LLM_MODEL: finalLlmModel,
LLM_BASE_URL: finalLlmBaseUrl,
LLM_API_KEY: finalLlmApiKey,
AGENT: formData.get("agent-input")?.toString(),
SECURITY_ANALYZER:
formData.get("security-analyzer-input")?.toString() || "",
REMOTE_RUNTIME_RESOURCE_FACTOR:
remoteRuntimeResourceFactor ||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
CONFIRMATION_MODE: confirmationModeIsEnabled,
};
saveSettings(newSettings, {
onSuccess: () => {
handleCaptureConsent(userConsentsToAnalytics);
displaySuccessToast("Settings saved");
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
},
{
onSuccess: () => {
handleCaptureConsent(userConsentsToAnalytics);
displaySuccessToast("Settings saved");
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
},
onError: (error) => {
const errorMessage = retrieveAxiosErrorMessage(error);
displayErrorToast(errorMessage);
},
onError: (error) => {
const errorMessage = retrieveAxiosErrorMessage(error);
displayErrorToast(errorMessage);
},
);
});
};
const handleReset = () => {

View File

@ -15,6 +15,10 @@ export const DEFAULT_SETTINGS: Settings = {
ENABLE_DEFAULT_CONDENSER: true,
ENABLE_SOUND_NOTIFICATIONS: false,
USER_CONSENTS_TO_ANALYTICS: false,
PROVIDER_TOKENS: {
github: "",
gitlab: "",
},
IS_NEW_USER: true,
};

View File

@ -21,7 +21,6 @@ export interface InitConfig {
LLM_MODEL: string;
};
token?: string;
github_token?: string;
latest_event_id?: unknown; // Not sure what this is
}

View File

@ -1,3 +1,5 @@
export type Provider = "github" | "gitlab";
export type Settings = {
LLM_MODEL: string;
LLM_BASE_URL: string;
@ -11,6 +13,7 @@ export type Settings = {
ENABLE_DEFAULT_CONDENSER: boolean;
ENABLE_SOUND_NOTIFICATIONS: boolean;
USER_CONSENTS_TO_ANALYTICS: boolean | null;
PROVIDER_TOKENS: Record<Provider, string>;
IS_NEW_USER?: boolean;
};
@ -27,16 +30,17 @@ export type ApiSettings = {
enable_default_condenser: boolean;
enable_sound_notifications: boolean;
user_consents_to_analytics: boolean | null;
provider_tokens: Record<Provider, string>;
};
export type PostSettings = Settings & {
github_token: string;
provider_tokens: Record<Provider, string>;
unset_github_token: boolean;
user_consents_to_analytics: boolean | null;
};
export type PostApiSettings = ApiSettings & {
github_token: string;
provider_tokens: Record<Provider, string>;
unset_github_token: boolean;
user_consents_to_analytics: boolean | null;
};

View File

@ -59,6 +59,18 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
ENABLE_DEFAULT_CONDENSER,
} = extractAdvancedFormData(formData);
// Extract provider tokens
const githubToken = formData.get("github-token")?.toString();
const gitlabToken = formData.get("gitlab-token")?.toString();
const providerTokens: Record<string, string> = {};
if (githubToken) {
providerTokens.github = githubToken;
}
if (gitlabToken) {
providerTokens.gitlab = gitlabToken;
}
return {
LLM_MODEL: CUSTOM_LLM_MODEL || LLM_MODEL,
LLM_API_KEY,
@ -68,5 +80,6 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
CONFIRMATION_MODE,
SECURITY_ANALYZER,
ENABLE_DEFAULT_CONDENSER,
PROVIDER_TOKENS: providerTokens,
};
};

View File

@ -5,43 +5,43 @@ from typing import Any
import httpx
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_types import (
GhAuthenticationError,
GHUnknownException,
GitHubRepository,
GitHubUser,
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
Repository,
SuggestedTask,
TaskType,
UnknownException,
User,
)
from openhands.utils.import_utils import get_impl
from openhands.core.logger import openhands_logger as logger
class GitHubService:
class GitHubService(GitService):
BASE_URL = 'https://api.github.com'
github_token: SecretStr = SecretStr('')
token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,
user_id: str | None = None,
external_auth_token: SecretStr | None = None,
github_token: SecretStr | None = None,
token: SecretStr | None = None,
external_token_manager: bool = False,
):
self.user_id = user_id
self.external_token_manager = external_token_manager
if github_token:
self.github_token = github_token
if token:
self.token = token
async def _get_github_headers(self) -> dict:
"""Retrieve the GH Token from settings store to construct the headers."""
if self.user_id and not self.github_token:
self.github_token = await self.get_latest_token()
if self.user_id and not self.token:
self.token = await self.get_latest_token()
return {
'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}',
'Authorization': f'Bearer {self.token.get_secret_value() if self.token else ""}',
'Accept': 'application/vnd.github.v3+json',
}
@ -49,7 +49,7 @@ class GitHubService:
return status_code == 401
async def get_latest_token(self) -> SecretStr | None:
return self.github_token
return self.token
async def _fetch_data(
self, url: str, params: dict | None = None
@ -74,20 +74,20 @@ class GitHubService:
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise GhAuthenticationError('Invalid Github token')
raise AuthenticationError('Invalid Github token')
logger.warning(f'Status error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
except httpx.HTTPError as e:
logger.warning(f'HTTP error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
async def get_user(self) -> GitHubUser:
async def get_user(self) -> User:
url = f'{self.BASE_URL}/user'
response, _ = await self._fetch_data(url)
return GitHubUser(
return User(
id=response.get('id'),
login=response.get('login'),
avatar_url=response.get('avatar_url'),
@ -98,7 +98,7 @@ class GitHubService:
async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[GitHubRepository]:
) -> list[Repository]:
params = {'page': str(page), 'per_page': str(per_page)}
if installation_id:
url = f'{self.BASE_URL}/user/installations/{installation_id}/repositories'
@ -111,7 +111,7 @@ class GitHubService:
next_link: str = headers.get('Link', '')
repos = [
GitHubRepository(
Repository(
id=repo.get('id'),
full_name=repo.get('full_name'),
stargazers_count=repo.get('stargazers_count'),
@ -129,7 +129,7 @@ class GitHubService:
async def search_repositories(
self, query: str, per_page: int, sort: str, order: str
) -> list[GitHubRepository]:
) -> list[Repository]:
url = f'{self.BASE_URL}/search/repositories'
params = {'q': query, 'per_page': per_page, 'sort': sort, 'order': order}
@ -137,7 +137,7 @@ class GitHubService:
repos = response.get('items', [])
repos = [
GitHubRepository(
Repository(
id=repo.get('id'),
full_name=repo.get('full_name'),
stargazers_count=repo.get('stargazers_count'),
@ -163,7 +163,7 @@ class GitHubService:
result = response.json()
if 'errors' in result:
raise GHUnknownException(
raise UnknownException(
f"GraphQL query error: {json.dumps(result['errors'])}"
)
@ -171,14 +171,14 @@ class GitHubService:
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise GhAuthenticationError('Invalid Github token')
raise AuthenticationError('Invalid Github token')
logger.warning(f'Status error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
except httpx.HTTPError as e:
logger.warning(f'HTTP error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
async def get_suggested_tasks(self) -> list[SuggestedTask]:
"""Get suggested tasks for the authenticated user across all repositories.

View File

@ -1,46 +0,0 @@
from enum import Enum
from pydantic import BaseModel
class TaskType(str, Enum):
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
FAILING_CHECKS = 'FAILING_CHECKS'
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
OPEN_ISSUE = 'OPEN_ISSUE'
OPEN_PR = 'OPEN_PR'
class SuggestedTask(BaseModel):
task_type: TaskType
repo: str
issue_number: int
title: str
class GitHubUser(BaseModel):
id: int
login: str
avatar_url: str
company: str | None = None
name: str | None = None
email: str | None = None
class GitHubRepository(BaseModel):
id: int
full_name: str
stargazers_count: int | None = None
link_header: str | None = None
class GhAuthenticationError(ValueError):
"""Raised when there is an issue with GitHub authentication."""
pass
class GHUnknownException(ValueError):
"""Raised when there is an issue with GitHub communcation."""
pass

View File

@ -0,0 +1,119 @@
import os
from typing import Any
import httpx
from pydantic import SecretStr
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
Repository,
UnknownException,
User,
)
from openhands.utils.import_utils import get_impl
class GitLabService(GitService):
BASE_URL = 'https://gitlab.com/api/v4'
token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,
user_id: str | None = None,
external_auth_token: SecretStr | None = None,
token: SecretStr | None = None,
external_token_manager: bool = False,
):
self.user_id = user_id
self.external_token_manager = external_token_manager
if token:
self.token = token
async def _get_gitlab_headers(self) -> dict:
"""
Retrieve the GitLab Token to construct the headers
"""
if self.user_id and not self.token:
self.token = await self.get_latest_token()
return {
'Authorization': f'Bearer {self.token.get_secret_value()}',
}
def _has_token_expired(self, status_code: int) -> bool:
return status_code == 401
async def get_latest_token(self) -> SecretStr:
return self.token
async def _fetch_data(
self, url: str, params: dict | None = None
) -> tuple[Any, dict]:
try:
async with httpx.AsyncClient() as client:
gitlab_headers = await self._get_gitlab_headers()
response = await client.get(url, headers=gitlab_headers, params=params)
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
gitlab_headers = await self._get_gitlab_headers()
response = await client.get(
url, headers=gitlab_headers, params=params
)
response.raise_for_status()
headers = {}
if 'Link' in response.headers:
headers['Link'] = response.headers['Link']
return response.json(), headers
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError('Invalid GitLab token')
raise UnknownException('Unknown error')
except httpx.HTTPError:
raise UnknownException('Unknown error')
async def get_user(self) -> User:
url = f'{self.BASE_URL}/user'
response, _ = await self._fetch_data(url)
return User(
id=response.get('id'),
username=response.get('username'),
avatar_url=response.get('avatar_url'),
name=response.get('name'),
email=response.get('email'),
company=response.get('organization'),
login=response.get('username'),
)
async def search_repositories(
self, query: str, per_page: int = 30, sort: str = 'updated', order: str = 'desc'
):
url = f'{self.BASE_URL}/search'
params = {
'scope': 'projects',
'search': query,
'per_page': per_page,
'order_by': sort,
'sort': order,
}
response, headers = await self._fetch_data(url, params)
return response, headers
async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[Repository]:
return []
gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',
'openhands.integrations.gitlab.gitlab_service.GitLabService',
)
GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls)

View File

@ -0,0 +1,143 @@
from enum import Enum
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
from pydantic.json import pydantic_encoder
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
Repository,
User,
)
class ProviderType(Enum):
GITHUB = 'github'
GITLAB = 'gitlab'
class ProviderToken(BaseModel):
token: SecretStr | None
user_id: str | None
PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken]
CUSTOM_SECRETS_TYPE = dict[str, SecretStr]
class SecretStore(BaseModel):
provider_tokens: PROVIDER_TOKEN_TYPE = {}
@classmethod
def _convert_token(
cls, token_value: str | ProviderToken | SecretStr
) -> ProviderToken:
if isinstance(token_value, ProviderToken):
return token_value
elif isinstance(token_value, str):
return ProviderToken(token=SecretStr(token_value), user_id=None)
elif isinstance(token_value, SecretStr):
return ProviderToken(token=token_value, user_id=None)
else:
raise ValueError(f'Invalid token type: {type(token_value)}')
def model_post_init(self, __context) -> None:
# Convert any string tokens to ProviderToken objects
converted_tokens = {}
for token_type, token_value in self.provider_tokens.items():
if token_value: # Only convert non-empty tokens
try:
if isinstance(token_type, str):
token_type = ProviderType(token_type)
converted_tokens[token_type] = self._convert_token(token_value)
except ValueError:
# Skip invalid provider types or tokens
continue
self.provider_tokens = converted_tokens
@field_serializer('provider_tokens')
def provider_tokens_serializer(
self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo
):
tokens = {}
expose_secrets = info.context and info.context.get('expose_secrets', False)
for token_type, provider_token in provider_tokens.items():
if not provider_token or not provider_token.token:
continue
token_type_str = (
token_type.value
if isinstance(token_type, ProviderType)
else str(token_type)
)
tokens[token_type_str] = {
'token': provider_token.token.get_secret_value()
if expose_secrets
else pydantic_encoder(provider_token.token),
'user_id': provider_token.user_id,
}
return tokens
class ProviderHandler:
def __init__(
self,
provider_tokens: PROVIDER_TOKEN_TYPE,
external_auth_token: SecretStr | None = None,
):
self.service_class_map: dict[ProviderType, type[GitService]] = {
ProviderType.GITHUB: GithubServiceImpl,
ProviderType.GITLAB: GitLabServiceImpl,
}
self.provider_tokens = provider_tokens
self.external_auth_token = external_auth_token
def _get_service(self, provider: ProviderType) -> GitService:
"""Helper method to instantiate a service for a given provider"""
token = self.provider_tokens[provider]
service_class = self.service_class_map[provider]
return service_class(
user_id=token.user_id,
external_auth_token=self.external_auth_token,
token=token.token,
)
async def get_user(self) -> User:
"""Get user information from the first available provider"""
for provider in self.provider_tokens:
try:
service = self._get_service(provider)
return await service.get_user()
except Exception:
continue
raise AuthenticationError('Need valid provider token')
async def get_latest_provider_tokens(self) -> dict[ProviderType, SecretStr]:
"""Get latest token from services"""
tokens = {}
for provider in self.provider_tokens:
service = self._get_service(provider)
tokens[provider] = await service.get_latest_token()
return tokens
async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[Repository]:
"""Get repositories from all available providers"""
all_repos = []
for provider in self.provider_tokens:
try:
service = self._get_service(provider)
repos = await service.get_repositories(
page, per_page, sort, installation_id
)
all_repos.extend(repos)
except Exception:
continue
return all_repos

View File

@ -0,0 +1,89 @@
from enum import Enum
from typing import Protocol
from pydantic import BaseModel, SecretStr
class TaskType(str, Enum):
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
FAILING_CHECKS = 'FAILING_CHECKS'
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
OPEN_ISSUE = 'OPEN_ISSUE'
OPEN_PR = 'OPEN_PR'
class SuggestedTask(BaseModel):
task_type: TaskType
repo: str
issue_number: int
title: str
class User(BaseModel):
id: int
login: str
avatar_url: str
company: str | None = None
name: str | None = None
email: str | None = None
class Repository(BaseModel):
id: int
full_name: str
stargazers_count: int | None = None
link_header: str | None = None
class AuthenticationError(ValueError):
"""Raised when there is an issue with GitHub authentication."""
pass
class UnknownException(ValueError):
"""Raised when there is an issue with GitHub communcation."""
pass
class GitService(Protocol):
"""Protocol defining the interface for Git service providers"""
def __init__(
self,
user_id: str | None,
token: SecretStr | None,
external_auth_token: SecretStr | None,
external_token_manager: bool = False,
) -> None:
"""Initialize the service with authentication details"""
...
async def get_latest_token(self) -> SecretStr:
"""Get latest working token of the users"""
...
async def get_user(self) -> User:
"""Get the authenticated user's information"""
...
async def search_repositories(
self,
query: str,
per_page: int,
sort: str,
order: str,
) -> list[Repository]:
"""Search for repositories"""
...
async def get_repositories(
self,
page: int,
per_page: int,
sort: str,
installation_id: int | None,
) -> list[Repository]:
"""Get repositories for the authenticated user"""
...

View File

@ -0,0 +1,37 @@
from pydantic import SecretStr
from openhands.integrations.github.github_service import GitHubService
from openhands.integrations.gitlab.gitlab_service import GitLabService
from openhands.integrations.provider import ProviderType
async def validate_provider_token(token: SecretStr) -> ProviderType | None:
"""
Determine whether a token is for GitHub or GitLab by attempting to get user info
from both services.
Args:
token: The token to check
Returns:
'github' if it's a GitHub token
'gitlab' if it's a GitLab token
None if the token is invalid for both services
"""
# Try GitHub first
try:
github_service = GitHubService(token=token)
await github_service.get_user()
return ProviderType.GITHUB
except Exception:
pass
# Try GitLab next
try:
gitlab_service = GitLabService(token=token)
await gitlab_service.get_user()
return ProviderType.GITLAB
except Exception:
pass
return None

View File

@ -1,6 +1,13 @@
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
"""Get GitHub token from request state. For backward compatibility."""
return getattr(request.state, 'provider_tokens', None)
def get_access_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'access_token', None)
@ -11,8 +18,18 @@ def get_user_id(request: Request) -> str | None:
def get_github_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'github_token', None)
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].token
return None
def get_github_user_id(request: Request) -> str | None:
return getattr(request.state, 'github_user_id', None)
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].user_id
return None

View File

@ -194,10 +194,14 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
settings = await settings_store.load()
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
if getattr(request.state, 'github_token', None) is None:
if settings and settings.github_token:
request.state.github_token = settings.github_token
if getattr(request.state, 'provider_tokens', None) is None:
if (
settings
and settings.secrets_store
and settings.secrets_store.provider_tokens
):
request.state.provider_tokens = settings.secrets_store.provider_tokens
else:
request.state.github_token = None
request.state.provider_tokens = None
return await call_next(request)

View File

@ -3,147 +3,168 @@ from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.github.github_types import (
GhAuthenticationError,
GHUnknownException,
GitHubRepository,
GitHubUser,
SuggestedTask,
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
ProviderType,
)
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
from openhands.integrations.service_types import (
AuthenticationError,
Repository,
SuggestedTask,
UnknownException,
User,
)
from openhands.server.auth import get_access_token, get_provider_tokens
app = APIRouter(prefix='/api/github')
@app.get('/repositories', response_model=list[GitHubRepository])
@app.get('/repositories', response_model=list[Repository])
async def get_github_repositories(
page: int = 1,
per_page: int = 10,
sort: str = 'pushed',
installation_id: int | None = None,
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
repos: list[Repository] = await client.get_repositories(
page, per_page, sort, installation_id
)
return repos
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
repos: list[GitHubRepository] = await client.get_repositories(
page, per_page, sort, installation_id
)
return repos
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/user', response_model=GitHubUser)
@app.get('/user', response_model=User)
async def get_github_user(
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens:
client = ProviderHandler(provider_tokens=provider_tokens, external_auth_token=access_token)
try:
user: User = await client.get_user()
return user
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
user: GitHubUser = await client.get_user()
return user
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/installations', response_model=list[int])
async def get_github_installation_ids(
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
installations_ids: list[int] = await client.get_installation_ids()
return installations_ids
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
installations_ids: list[int] = await client.get_installation_ids()
return installations_ids
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/search/repositories', response_model=list[GitHubRepository])
@app.get('/search/repositories', response_model=list[Repository])
async def search_github_repositories(
query: str,
per_page: int = 5,
sort: str = 'stars',
order: str = 'desc',
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
repos: list[Repository] = await client.search_repositories(
query, per_page, sort, order
)
return repos
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
repos: list[GitHubRepository] = await client.search_repositories(
query, per_page, sort, order
)
return repos
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/suggested-tasks', response_model=list[SuggestedTask])
async def get_suggested_tasks(
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
access_token: SecretStr | None = Depends(get_access_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token)
):
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
@ -151,23 +172,30 @@ async def get_suggested_tasks(
- PRs owned by the user
- Issues assigned to the user.
"""
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
return tasks
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=401,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=500,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
return tasks
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=401,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=500,
)

View File

@ -8,8 +8,9 @@ from pydantic import BaseModel, SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.provider import ProviderType
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@ -136,13 +137,18 @@ async def new_conversation(request: Request, data: InitSessionRequest):
using the returned conversation ID.
"""
logger.info('Initializing new conversation')
user_id = get_github_user_id(request)
gh_client = GithubServiceImpl(
user_id=user_id,
external_auth_token=get_access_token(request),
github_token=get_github_token(request),
)
github_token = await gh_client.get_latest_token()
user_id = None
github_token = None
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
user_id = token.user_id
gh_client = GithubServiceImpl(
user_id=user_id,
external_auth_token=get_access_token(request),
token=token.token,
)
github_token = await gh_client.get_latest_token()
selected_repository = data.selected_repository
selected_branch = data.selected_branch

View File

@ -3,8 +3,9 @@ from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.server.auth import get_github_token, get_user_id
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.utils import validate_provider_token
from openhands.server.auth import get_provider_tokens, get_user_id
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
from openhands.server.shared import SettingsStoreImpl, config
@ -23,14 +24,14 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
content={'error': 'Settings not found'},
)
token_is_set = bool(user_id) or bool(get_github_token(request))
github_token_is_set = bool(user_id) or bool(get_provider_tokens(request))
settings_with_token_data = GETSettingsModel(
**settings.model_dump(),
github_token_is_set=token_is_set,
github_token_is_set=github_token_is_set,
)
settings_with_token_data.llm_api_key = settings.llm_api_key
del settings_with_token_data.github_token
del settings_with_token_data.secrets_store
return settings_with_token_data
except Exception as e:
logger.warning(f'Invalid token: {e}')
@ -45,26 +46,27 @@ async def store_settings(
request: Request,
settings: POSTSettingsModel,
) -> JSONResponse:
# Check if token is valid
if settings.github_token:
try:
# We check if the token is valid by getting the user
# If the token is invalid, this will raise an exception
github = GithubServiceImpl(
user_id=None,
external_auth_token=None,
github_token=SecretStr(settings.github_token),
)
await github.get_user()
# Check provider tokens are valid
if settings.provider_tokens:
# Remove extraneous token types
provider_types = [provider.value for provider in ProviderType]
settings.provider_tokens = {
k: v for k, v in settings.provider_tokens.items() if k in provider_types
}
except Exception as e:
logger.warning(f'Invalid GitHub token: {e}')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': 'Invalid GitHub token. Please make sure it is valid.'
},
)
# Determine whether tokens are valid
for token_type, token_value in settings.provider_tokens.items():
if token_value:
confirmed_token_type = await validate_provider_token(
SecretStr(token_value)
)
if not confirmed_token_type or confirmed_token_type.value != token_type:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': f'Invalid token. Please make sure it is a valid {token_type} token.'
},
)
try:
settings_store = await SettingsStoreImpl.get_instance(
@ -72,32 +74,46 @@ async def store_settings(
)
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
if existing_settings:
# LLM key isn't on the frontend, so we need to keep it if unset
# Keep existing LLM settings if not provided
if settings.llm_api_key is None:
settings.llm_api_key = existing_settings.llm_api_key
if settings.llm_model is None:
settings.llm_model = existing_settings.llm_model
if settings.llm_base_url is None:
settings.llm_base_url = existing_settings.llm_base_url
if settings.github_token is None:
settings.github_token = existing_settings.github_token
# Keep existing analytics consent if not provided
if settings.user_consents_to_analytics is None:
settings.user_consents_to_analytics = (
existing_settings.user_consents_to_analytics
)
if settings.llm_model is None:
settings.llm_model = existing_settings.llm_model
if existing_settings.secrets_store:
existing_providers = [
provider.value
for provider in existing_settings.secrets_store.provider_tokens
]
if settings.llm_base_url is None:
settings.llm_base_url = existing_settings.llm_base_url
# Merge incoming settings store with the existing one
for provider, token_value in settings.provider_tokens.items():
if provider in existing_providers and not token_value:
provider_type = ProviderType(provider)
existing_token = (
existing_settings.secrets_store.provider_tokens.get(
provider_type
)
)
if existing_token and existing_token.token:
settings.provider_tokens[provider] = (
existing_token.token.get_secret_value()
)
response = JSONResponse(
status_code=status.HTTP_200_OK,
content={'message': 'Settings stored'},
)
if settings.unset_github_token:
settings.github_token = None
# Merge provider tokens with existing ones
if settings.unset_github_token: # Only merge if not unsetting tokens
settings.secrets_store.provider_tokens = {}
settings.provider_tokens = {}
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:
@ -106,9 +122,11 @@ async def store_settings(
)
settings = convert_to_settings(settings)
await settings_store.store(settings)
return response
return JSONResponse(
status_code=status.HTTP_200_OK,
content={'message': 'Settings stored'},
)
except Exception as e:
logger.warning(f'Something went wrong storing settings: {e}')
return JSONResponse(
@ -127,8 +145,19 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
if key in Settings.model_fields # Ensures only `Settings` fields are included
}
# Convert the `llm_api_key` and `github_token` to a `SecretStr` instance
# Convert the `llm_api_key` to a `SecretStr` instance
filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key
filtered_settings_data['github_token'] = settings_with_token_data.github_token
return Settings(**filtered_settings_data)
# Create a new Settings instance without provider tokens
settings = Settings(**filtered_settings_data)
# Update provider tokens if any are provided
if settings_with_token_data.provider_tokens:
for token_type, token_value in settings_with_token_data.provider_tokens.items():
if token_value:
provider = ProviderType(token_type)
settings.secrets_store.provider_tokens[provider] = ProviderToken(
token=SecretStr(token_value), user_id=None
)
return settings

View File

@ -1,10 +1,17 @@
from __future__ import annotations
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
from pydantic import (
BaseModel,
SecretStr,
SerializationInfo,
field_serializer,
model_validator,
)
from pydantic.json import pydantic_encoder
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.utils import load_app_config
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
class Settings(BaseModel):
@ -21,7 +28,7 @@ class Settings(BaseModel):
llm_api_key: SecretStr | None = None
llm_base_url: str | None = None
remote_runtime_resource_factor: int | None = None
github_token: SecretStr | None = None
secrets_store: SecretStore = SecretStore()
enable_default_condenser: bool = False
enable_sound_notifications: bool = False
user_consents_to_analytics: bool | None = None
@ -38,22 +45,63 @@ class Settings(BaseModel):
return pydantic_encoder(llm_api_key)
@field_serializer('github_token')
def github_token_serializer(
self, github_token: SecretStr | None, info: SerializationInfo
):
"""Custom serializer for the GitHub token.
@staticmethod
def _convert_token_value(
token_type: ProviderType, token_value: str | dict
) -> ProviderToken | None:
"""Convert a token value to a ProviderToken object."""
if isinstance(token_value, dict):
token_str = token_value.get('token')
if not token_str:
return None
return ProviderToken(
token=SecretStr(token_str),
user_id=token_value.get('user_id'),
)
if isinstance(token_value, str) and token_value:
return ProviderToken(token=SecretStr(token_value), user_id=None)
return None
To serialize the token instead of ********, set expose_secrets to True in the serialization context.
"""
if github_token is None:
return None
@model_validator(mode='before')
@classmethod
def convert_provider_tokens(cls, data: dict | object) -> dict | object:
"""Convert provider tokens from JSON format to SecretStore format."""
if not isinstance(data, dict):
return data
context = info.context
if context and context.get('expose_secrets', False):
return github_token.get_secret_value()
secrets_store = data.get('secrets_store')
if not isinstance(secrets_store, dict):
return data
return pydantic_encoder(github_token)
tokens = secrets_store.get('provider_tokens')
if not isinstance(tokens, dict):
return data
converted_tokens = {}
for token_type_str, token_value in tokens.items():
if not token_value:
continue
try:
token_type = ProviderType(token_type_str)
except ValueError:
continue
provider_token = cls._convert_token_value(token_type, token_value)
if provider_token:
converted_tokens[token_type] = provider_token
data['secrets_store'] = SecretStore(provider_tokens=converted_tokens)
return data
@field_serializer('secrets_store')
def secrets_store_serializer(self, secrets: SecretStore, info: SerializationInfo):
"""Custom serializer for secrets store."""
return {
'provider_tokens': secrets.provider_tokens_serializer(
secrets.provider_tokens, info
)
}
@staticmethod
def from_config() -> Settings | None:
@ -73,7 +121,7 @@ class Settings(BaseModel):
llm_api_key=llm_config.api_key,
llm_base_url=llm_config.base_url,
remote_runtime_resource_factor=app_config.sandbox.remote_runtime_resource_factor,
github_token=None,
provider_tokens={},
)
return settings
@ -84,14 +132,12 @@ class POSTSettingsModel(Settings):
"""
unset_github_token: bool | None = None
github_token: str | None = (
None # This is a string because it's coming from the frontend
)
# Override provider_tokens to accept string tokens from frontend
provider_tokens: dict[str, str] = {}
# Override the serializer for the GitHub token to handle the string input
@field_serializer('github_token')
def github_token_serializer(self, github_token: str | None):
return github_token
@field_serializer('provider_tokens')
def provider_tokens_serializer(self, provider_tokens: dict[str, str]):
return provider_tokens
class GETSettingsModel(Settings):

View File

@ -5,16 +5,16 @@ import pytest
from pydantic import SecretStr
from openhands.integrations.github.github_service import GitHubService
from openhands.integrations.github.github_types import GhAuthenticationError
from openhands.integrations.service_types import AuthenticationError
@pytest.mark.asyncio
async def test_github_service_token_handling():
# Test initialization with SecretStr token
token = SecretStr('test-token')
service = GitHubService(user_id=None, github_token=token)
assert service.github_token == token
assert service.github_token.get_secret_value() == 'test-token'
service = GitHubService(user_id=None, token=token)
assert service.token == token
assert service.token.get_secret_value() == 'test-token'
# Test headers contain the token correctly
headers = await service._get_github_headers()
@ -23,14 +23,14 @@ async def test_github_service_token_handling():
# Test initialization without token
service = GitHubService(user_id='test-user')
assert service.github_token == SecretStr('')
assert service.token == SecretStr('')
@pytest.mark.asyncio
async def test_github_service_token_refresh():
# Test that token refresh is only attempted when refresh=True
token = SecretStr('test-token')
service = GitHubService(user_id=None, github_token=token)
service = GitHubService(user_id=None, token=token)
assert not service.refresh
# Test token expiry detection
@ -58,7 +58,7 @@ async def test_github_service_fetch_data():
mock_client.__aexit__.return_value = None
with patch('httpx.AsyncClient', return_value=mock_client):
service = GitHubService(user_id=None, github_token=SecretStr('test-token'))
service = GitHubService(user_id=None, token=SecretStr('test-token'))
_ = await service._fetch_data('https://api.github.com/user')
# Verify the request was made with correct headers
@ -77,5 +77,5 @@ async def test_github_service_fetch_data():
mock_client.get.reset_mock()
mock_client.get.return_value = mock_response
with pytest.raises(GhAuthenticationError):
with pytest.raises(AuthenticationError):
_ = await service._fetch_data('https://api.github.com/user')

View File

@ -6,6 +6,7 @@ from openhands.core.config.app_config import AppConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.server.routes.settings import convert_to_settings
from openhands.server.settings import POSTSettingsModel, Settings
@ -43,7 +44,7 @@ def test_settings_from_config():
assert settings.llm_api_key.get_secret_value() == 'test-key'
assert settings.llm_base_url == 'https://test.example.com'
assert settings.remote_runtime_resource_factor == 2
assert settings.github_token is None
assert not settings.secrets_store.provider_tokens
def test_settings_from_config_no_api_key():
@ -80,23 +81,41 @@ def test_settings_handles_sensitive_data():
llm_api_key='test-key',
llm_base_url='https://test.example.com',
remote_runtime_resource_factor=2,
github_token='test-token',
)
settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
token=SecretStr('test-token'),
user_id=None,
)
assert str(settings.llm_api_key) == '**********'
assert str(settings.github_token) == '**********'
assert (
str(settings.secrets_store.provider_tokens[ProviderType.GITHUB].token)
== '**********'
)
assert settings.llm_api_key.get_secret_value() == 'test-key'
assert settings.github_token.get_secret_value() == 'test-token'
assert (
settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'test-token'
)
def test_convert_to_settings():
settings_with_token_data = POSTSettingsModel(
llm_api_key='test-key',
github_token='test-token',
provider_tokens={
'github': 'test-token',
},
)
settings = convert_to_settings(settings_with_token_data)
assert settings.llm_api_key.get_secret_value() == 'test-key'
assert settings.github_token.get_secret_value() == 'test-token'
assert (
settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'test-token'
)

View File

@ -5,6 +5,7 @@ from fastapi.testclient import TestClient
from pydantic import SecretStr
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.integrations.provider import ProviderType, SecretStore
from openhands.server.app import app
from openhands.server.settings import Settings
@ -19,6 +20,24 @@ def mock_settings_store():
yield store_instance
@pytest.fixture
def mock_get_user_id():
with patch('openhands.server.routes.settings.get_user_id') as mock:
mock.return_value = 'test-user'
yield mock
@pytest.fixture
def mock_validate_provider_token():
with patch('openhands.server.routes.settings.validate_provider_token') as mock:
async def mock_determine(*args, **kwargs):
return ProviderType.GITHUB
mock.side_effect = mock_determine
yield mock
@pytest.fixture
def test_client(mock_settings_store):
# Mock the middleware that adds github_token
@ -28,9 +47,15 @@ def test_client(mock_settings_store):
async def __call__(self, scope, receive, send):
settings = mock_settings_store.load.return_value
token = settings.github_token if settings else None
token = None
if settings and settings.secrets_store.provider_tokens.get(
ProviderType.GITHUB
):
token = settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token
if scope['type'] == 'http':
scope['state'] = {'github_token': token}
scope['state'] = {'token': token}
await self.app(scope, receive, send)
# Replace the middleware
@ -47,7 +72,9 @@ def mock_github_service():
@pytest.mark.asyncio
async def test_settings_api_runtime_factor(test_client, mock_settings_store):
async def test_settings_api_runtime_factor(
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
):
# Mock the settings store to return None initially (no existing settings)
mock_settings_store.load.return_value = None
@ -62,6 +89,7 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
'llm_api_key': 'test-key',
'llm_base_url': 'https://test.com',
'remote_runtime_resource_factor': 2,
'provider_tokens': {'github': 'test-token'},
}
# The test_client fixture already handles authentication
@ -98,12 +126,17 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
@pytest.mark.asyncio
async def test_settings_llm_api_key(test_client, mock_settings_store):
async def test_settings_llm_api_key(
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
):
# Mock the settings store to return None initially (no existing settings)
mock_settings_store.load.return_value = None
# Test data with remote_runtime_resource_factor
settings_data = {'llm_api_key': 'test-key'}
settings_data = {
'llm_api_key': 'test-key',
'provider_tokens': {'github': 'test-token'},
}
# The test_client fixture already handles authentication
@ -132,9 +165,13 @@ async def test_settings_llm_api_key(test_client, mock_settings_store):
)
@pytest.mark.asyncio
async def test_settings_api_set_github_token(
mock_github_service, test_client, mock_settings_store
mock_github_service,
test_client,
mock_settings_store,
mock_get_user_id,
mock_validate_provider_token,
):
# Test data with github_token set
# Test data with provider token set
settings_data = {
'language': 'en',
'agent': 'test-agent',
@ -144,16 +181,21 @@ async def test_settings_api_set_github_token(
'llm_model': 'test-model',
'llm_api_key': 'test-key',
'llm_base_url': 'https://test.com',
'github_token': 'test-token',
'provider_tokens': {'github': 'test-token'},
}
# Make the POST request to store settings
response = test_client.post('/api/settings', json=settings_data)
assert response.status_code == 200
# Verify the settings were stored with the github_token
# Verify the settings were stored with the provider token
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.github_token == 'test-token'
assert (
stored_settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'test-token'
)
# Mock settings store to return our settings for the GET request
mock_settings_store.load.return_value = Settings(**settings_data)
@ -163,17 +205,21 @@ async def test_settings_api_set_github_token(
data = response.json()
assert response.status_code == 200
assert data.get('github_token') is None
assert data['github_token_is_set'] is True
assert data.get('token') is None
assert data['token_is_set'] is True
@pytest.mark.skip(
reason='Mock middleware does not seem to properly set the github_token'
)
async def test_settings_unset_github_token(
mock_github_service, test_client, mock_settings_store
mock_github_service,
test_client,
mock_settings_store,
mock_get_user_id,
mock_validate_provider_token,
):
# Test data with unset_github_token set to True
# Test data with unset_token set to True
settings_data = {
'language': 'en',
'agent': 'test-agent',
@ -183,7 +229,7 @@ async def test_settings_unset_github_token(
'llm_model': 'test-model',
'llm_api_key': 'test-key',
'llm_base_url': 'https://test.com',
'github_token': 'test-token',
'provider_tokens': {'github': 'test-token'},
}
# Mock settings store to return our settings for the GET request
@ -191,23 +237,23 @@ async def test_settings_unset_github_token(
response = test_client.get('/api/settings')
assert response.status_code == 200
assert response.json()['github_token_is_set'] is True
assert response.json()['token_is_set'] is True
settings_data['unset_github_token'] = True
settings_data['unset_token'] = True
# Make the POST request to store settings
response = test_client.post('/api/settings', json=settings_data)
assert response.status_code == 200
# Verify the settings were stored with the github_token unset
# Verify the settings were stored with the provider token unset
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.github_token is None
assert not stored_settings.secrets_store.provider_tokens
mock_settings_store.load.return_value = Settings(**stored_settings.dict())
# Make a GET request to retrieve settings
response = test_client.get('/api/settings')
assert response.status_code == 200
assert response.json()['github_token_is_set'] is False
assert response.json()['token_is_set'] is False
@pytest.mark.asyncio
@ -222,6 +268,7 @@ async def test_settings_preserve_llm_fields_when_none(test_client, mock_settings
llm_model='existing-model',
llm_api_key=SecretStr('existing-key'),
llm_base_url='https://existing.com',
secrets_store=SecretStore(),
)
# Mock the settings store to return our initial settings

View File

@ -3,13 +3,13 @@ from unittest.mock import AsyncMock
import pytest
from openhands.integrations.github.github_service import GitHubService
from openhands.integrations.github.github_types import GitHubUser, TaskType
from openhands.integrations.service_types import TaskType, User
@pytest.mark.asyncio
async def test_get_suggested_tasks():
# Mock responses
mock_user = GitHubUser(
mock_user = User(
id=1,
login='test-user',
avatar_url='https://example.com/avatar.jpg',