mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Feat]: Support Gitlab PAT (#7064)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
300bfbdf2d
commit
78d185b102
@ -721,7 +721,7 @@ describe("Settings Screen", () => {
|
||||
expect(saveSettingsSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
llm_api_key: "", // empty because it's not set previously
|
||||
github_token: undefined,
|
||||
provider_tokens: undefined,
|
||||
language: "no",
|
||||
}),
|
||||
);
|
||||
@ -758,7 +758,7 @@ describe("Settings Screen", () => {
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
github_token: undefined,
|
||||
provider_tokens: undefined,
|
||||
llm_api_key: "", // empty because it's not set previously
|
||||
llm_model: "openai/gpt-4o",
|
||||
}),
|
||||
@ -801,7 +801,7 @@ describe("Settings Screen", () => {
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalledWith({
|
||||
...mockCopy,
|
||||
github_token: undefined, // not set
|
||||
provider_tokens: undefined, // not set
|
||||
llm_api_key: "", // reset as well
|
||||
});
|
||||
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
* - Please do NOT serve this file on production.
|
||||
*/
|
||||
|
||||
const PACKAGE_VERSION = '2.7.0'
|
||||
const PACKAGE_VERSION = '2.7.3'
|
||||
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
|
||||
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
|
||||
const activeClientIds = new Set()
|
||||
|
||||
@ -17,7 +17,7 @@ const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
|
||||
? ""
|
||||
: settings.LLM_API_KEY?.trim() || undefined,
|
||||
remote_runtime_resource_factor: settings.REMOTE_RUNTIME_RESOURCE_FACTOR,
|
||||
github_token: settings.github_token,
|
||||
provider_tokens: settings.provider_tokens,
|
||||
unset_github_token: settings.unset_github_token,
|
||||
enable_default_condenser: settings.ENABLE_DEFAULT_CONDENSER,
|
||||
enable_sound_notifications: settings.ENABLE_SOUND_NOTIFICATIONS,
|
||||
|
||||
@ -21,6 +21,7 @@ const getSettingsQueryFn = async () => {
|
||||
ENABLE_DEFAULT_CONDENSER: apiSettings.enable_default_condenser,
|
||||
ENABLE_SOUND_NOTIFICATIONS: apiSettings.enable_sound_notifications,
|
||||
USER_CONSENTS_TO_ANALYTICS: apiSettings.user_consents_to_analytics,
|
||||
PROVIDER_TOKENS: apiSettings.provider_tokens,
|
||||
IS_NEW_USER: false,
|
||||
};
|
||||
};
|
||||
|
||||
@ -22,6 +22,7 @@ export const MOCK_DEFAULT_USER_SETTINGS: ApiSettings | PostApiSettings = {
|
||||
enable_default_condenser: DEFAULT_SETTINGS.ENABLE_DEFAULT_CONDENSER,
|
||||
enable_sound_notifications: DEFAULT_SETTINGS.ENABLE_SOUND_NOTIFICATIONS,
|
||||
user_consents_to_analytics: DEFAULT_SETTINGS.USER_CONSENTS_TO_ANALYTICS,
|
||||
provider_tokens: DEFAULT_SETTINGS.PROVIDER_TOKENS,
|
||||
};
|
||||
|
||||
const MOCK_USER_PREFERENCES: {
|
||||
@ -190,8 +191,8 @@ export const handlers = [
|
||||
|
||||
if (!settings) return HttpResponse.json(null, { status: 404 });
|
||||
|
||||
// @ts-expect-error - mock types
|
||||
if (settings.github_token) settings.github_token_is_set = true;
|
||||
if (Object.keys(settings.provider_tokens).length > 0)
|
||||
settings.github_token_is_set = true;
|
||||
|
||||
return HttpResponse.json(settings);
|
||||
}),
|
||||
@ -203,7 +204,7 @@ export const handlers = [
|
||||
if (typeof body === "object") {
|
||||
newSettings = { ...body };
|
||||
if (newSettings.unset_github_token) {
|
||||
newSettings.github_token = undefined;
|
||||
newSettings.provider_tokens = { github: "", gitlab: "" };
|
||||
newSettings.github_token_is_set = false;
|
||||
delete newSettings.unset_github_token;
|
||||
}
|
||||
|
||||
@ -61,7 +61,10 @@ function AccountSettings() {
|
||||
if (isSuccess) {
|
||||
return (
|
||||
isCustomModel(resources.models, settings.LLM_MODEL) ||
|
||||
hasAdvancedSettingsSet(settings)
|
||||
hasAdvancedSettingsSet({
|
||||
...settings,
|
||||
PROVIDER_TOKENS: settings.PROVIDER_TOKENS || {},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@ -128,37 +131,42 @@ function AccountSettings() {
|
||||
: llmBaseUrl;
|
||||
const finalLlmApiKey = shouldHandleSpecialSaasCase ? undefined : llmApiKey;
|
||||
|
||||
saveSettings(
|
||||
{
|
||||
github_token:
|
||||
formData.get("github-token-input")?.toString() || undefined,
|
||||
LANGUAGE: languageValue,
|
||||
user_consents_to_analytics: userConsentsToAnalytics,
|
||||
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
|
||||
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
|
||||
LLM_MODEL: finalLlmModel,
|
||||
LLM_BASE_URL: finalLlmBaseUrl,
|
||||
LLM_API_KEY: finalLlmApiKey,
|
||||
AGENT: formData.get("agent-input")?.toString(),
|
||||
SECURITY_ANALYZER:
|
||||
formData.get("security-analyzer-input")?.toString() || "",
|
||||
REMOTE_RUNTIME_RESOURCE_FACTOR:
|
||||
remoteRuntimeResourceFactor ||
|
||||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
|
||||
CONFIRMATION_MODE: confirmationModeIsEnabled,
|
||||
const githubToken = formData.get("github-token-input")?.toString();
|
||||
const newSettings = {
|
||||
github_token: githubToken,
|
||||
provider_tokens: githubToken
|
||||
? {
|
||||
github: githubToken,
|
||||
gitlab: "",
|
||||
}
|
||||
: undefined,
|
||||
LANGUAGE: languageValue,
|
||||
user_consents_to_analytics: userConsentsToAnalytics,
|
||||
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
|
||||
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
|
||||
LLM_MODEL: finalLlmModel,
|
||||
LLM_BASE_URL: finalLlmBaseUrl,
|
||||
LLM_API_KEY: finalLlmApiKey,
|
||||
AGENT: formData.get("agent-input")?.toString(),
|
||||
SECURITY_ANALYZER:
|
||||
formData.get("security-analyzer-input")?.toString() || "",
|
||||
REMOTE_RUNTIME_RESOURCE_FACTOR:
|
||||
remoteRuntimeResourceFactor ||
|
||||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
|
||||
CONFIRMATION_MODE: confirmationModeIsEnabled,
|
||||
};
|
||||
|
||||
saveSettings(newSettings, {
|
||||
onSuccess: () => {
|
||||
handleCaptureConsent(userConsentsToAnalytics);
|
||||
displaySuccessToast("Settings saved");
|
||||
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
|
||||
},
|
||||
{
|
||||
onSuccess: () => {
|
||||
handleCaptureConsent(userConsentsToAnalytics);
|
||||
displaySuccessToast("Settings saved");
|
||||
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
|
||||
},
|
||||
onError: (error) => {
|
||||
const errorMessage = retrieveAxiosErrorMessage(error);
|
||||
displayErrorToast(errorMessage);
|
||||
},
|
||||
onError: (error) => {
|
||||
const errorMessage = retrieveAxiosErrorMessage(error);
|
||||
displayErrorToast(errorMessage);
|
||||
},
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
const handleReset = () => {
|
||||
|
||||
@ -15,6 +15,10 @@ export const DEFAULT_SETTINGS: Settings = {
|
||||
ENABLE_DEFAULT_CONDENSER: true,
|
||||
ENABLE_SOUND_NOTIFICATIONS: false,
|
||||
USER_CONSENTS_TO_ANALYTICS: false,
|
||||
PROVIDER_TOKENS: {
|
||||
github: "",
|
||||
gitlab: "",
|
||||
},
|
||||
IS_NEW_USER: true,
|
||||
};
|
||||
|
||||
|
||||
@ -21,7 +21,6 @@ export interface InitConfig {
|
||||
LLM_MODEL: string;
|
||||
};
|
||||
token?: string;
|
||||
github_token?: string;
|
||||
latest_event_id?: unknown; // Not sure what this is
|
||||
}
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
export type Provider = "github" | "gitlab";
|
||||
|
||||
export type Settings = {
|
||||
LLM_MODEL: string;
|
||||
LLM_BASE_URL: string;
|
||||
@ -11,6 +13,7 @@ export type Settings = {
|
||||
ENABLE_DEFAULT_CONDENSER: boolean;
|
||||
ENABLE_SOUND_NOTIFICATIONS: boolean;
|
||||
USER_CONSENTS_TO_ANALYTICS: boolean | null;
|
||||
PROVIDER_TOKENS: Record<Provider, string>;
|
||||
IS_NEW_USER?: boolean;
|
||||
};
|
||||
|
||||
@ -27,16 +30,17 @@ export type ApiSettings = {
|
||||
enable_default_condenser: boolean;
|
||||
enable_sound_notifications: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
};
|
||||
|
||||
export type PostSettings = Settings & {
|
||||
github_token: string;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
unset_github_token: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
};
|
||||
|
||||
export type PostApiSettings = ApiSettings & {
|
||||
github_token: string;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
unset_github_token: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
};
|
||||
|
||||
@ -59,6 +59,18 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
|
||||
ENABLE_DEFAULT_CONDENSER,
|
||||
} = extractAdvancedFormData(formData);
|
||||
|
||||
// Extract provider tokens
|
||||
const githubToken = formData.get("github-token")?.toString();
|
||||
const gitlabToken = formData.get("gitlab-token")?.toString();
|
||||
const providerTokens: Record<string, string> = {};
|
||||
|
||||
if (githubToken) {
|
||||
providerTokens.github = githubToken;
|
||||
}
|
||||
if (gitlabToken) {
|
||||
providerTokens.gitlab = gitlabToken;
|
||||
}
|
||||
|
||||
return {
|
||||
LLM_MODEL: CUSTOM_LLM_MODEL || LLM_MODEL,
|
||||
LLM_API_KEY,
|
||||
@ -68,5 +80,6 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
|
||||
CONFIRMATION_MODE,
|
||||
SECURITY_ANALYZER,
|
||||
ENABLE_DEFAULT_CONDENSER,
|
||||
PROVIDER_TOKENS: providerTokens,
|
||||
};
|
||||
};
|
||||
|
||||
@ -5,43 +5,43 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_types import (
|
||||
GhAuthenticationError,
|
||||
GHUnknownException,
|
||||
GitHubRepository,
|
||||
GitHubUser,
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
TaskType,
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class GitHubService:
|
||||
class GitHubService(GitService):
|
||||
BASE_URL = 'https://api.github.com'
|
||||
github_token: SecretStr = SecretStr('')
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.external_token_manager = external_token_manager
|
||||
|
||||
if github_token:
|
||||
self.github_token = github_token
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
"""Retrieve the GH Token from settings store to construct the headers."""
|
||||
if self.user_id and not self.github_token:
|
||||
self.github_token = await self.get_latest_token()
|
||||
if self.user_id and not self.token:
|
||||
self.token = await self.get_latest_token()
|
||||
|
||||
return {
|
||||
'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}',
|
||||
'Authorization': f'Bearer {self.token.get_secret_value() if self.token else ""}',
|
||||
'Accept': 'application/vnd.github.v3+json',
|
||||
}
|
||||
|
||||
@ -49,7 +49,7 @@ class GitHubService:
|
||||
return status_code == 401
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
return self.github_token
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
@ -74,20 +74,20 @@ class GitHubService:
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise GhAuthenticationError('Invalid Github token')
|
||||
raise AuthenticationError('Invalid Github token')
|
||||
|
||||
logger.warning(f'Status error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f'HTTP error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
async def get_user(self) -> GitHubUser:
|
||||
async def get_user(self) -> User:
|
||||
url = f'{self.BASE_URL}/user'
|
||||
response, _ = await self._fetch_data(url)
|
||||
|
||||
return GitHubUser(
|
||||
return User(
|
||||
id=response.get('id'),
|
||||
login=response.get('login'),
|
||||
avatar_url=response.get('avatar_url'),
|
||||
@ -98,7 +98,7 @@ class GitHubService:
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
) -> list[GitHubRepository]:
|
||||
) -> list[Repository]:
|
||||
params = {'page': str(page), 'per_page': str(per_page)}
|
||||
if installation_id:
|
||||
url = f'{self.BASE_URL}/user/installations/{installation_id}/repositories'
|
||||
@ -111,7 +111,7 @@ class GitHubService:
|
||||
|
||||
next_link: str = headers.get('Link', '')
|
||||
repos = [
|
||||
GitHubRepository(
|
||||
Repository(
|
||||
id=repo.get('id'),
|
||||
full_name=repo.get('full_name'),
|
||||
stargazers_count=repo.get('stargazers_count'),
|
||||
@ -129,7 +129,7 @@ class GitHubService:
|
||||
|
||||
async def search_repositories(
|
||||
self, query: str, per_page: int, sort: str, order: str
|
||||
) -> list[GitHubRepository]:
|
||||
) -> list[Repository]:
|
||||
url = f'{self.BASE_URL}/search/repositories'
|
||||
params = {'q': query, 'per_page': per_page, 'sort': sort, 'order': order}
|
||||
|
||||
@ -137,7 +137,7 @@ class GitHubService:
|
||||
repos = response.get('items', [])
|
||||
|
||||
repos = [
|
||||
GitHubRepository(
|
||||
Repository(
|
||||
id=repo.get('id'),
|
||||
full_name=repo.get('full_name'),
|
||||
stargazers_count=repo.get('stargazers_count'),
|
||||
@ -163,7 +163,7 @@ class GitHubService:
|
||||
|
||||
result = response.json()
|
||||
if 'errors' in result:
|
||||
raise GHUnknownException(
|
||||
raise UnknownException(
|
||||
f"GraphQL query error: {json.dumps(result['errors'])}"
|
||||
)
|
||||
|
||||
@ -171,14 +171,14 @@ class GitHubService:
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise GhAuthenticationError('Invalid Github token')
|
||||
raise AuthenticationError('Invalid Github token')
|
||||
|
||||
logger.warning(f'Status error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f'HTTP error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
async def get_suggested_tasks(self) -> list[SuggestedTask]:
|
||||
"""Get suggested tasks for the authenticated user across all repositories.
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
|
||||
FAILING_CHECKS = 'FAILING_CHECKS'
|
||||
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
|
||||
OPEN_ISSUE = 'OPEN_ISSUE'
|
||||
OPEN_PR = 'OPEN_PR'
|
||||
|
||||
|
||||
class SuggestedTask(BaseModel):
|
||||
task_type: TaskType
|
||||
repo: str
|
||||
issue_number: int
|
||||
title: str
|
||||
|
||||
|
||||
class GitHubUser(BaseModel):
|
||||
id: int
|
||||
login: str
|
||||
avatar_url: str
|
||||
company: str | None = None
|
||||
name: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class GitHubRepository(BaseModel):
|
||||
id: int
|
||||
full_name: str
|
||||
stargazers_count: int | None = None
|
||||
link_header: str | None = None
|
||||
|
||||
|
||||
class GhAuthenticationError(ValueError):
|
||||
"""Raised when there is an issue with GitHub authentication."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GHUnknownException(ValueError):
|
||||
"""Raised when there is an issue with GitHub communcation."""
|
||||
|
||||
pass
|
||||
119
openhands/integrations/gitlab/gitlab_service.py
Normal file
119
openhands/integrations/gitlab/gitlab_service.py
Normal file
@ -0,0 +1,119 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
Repository,
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class GitLabService(GitService):
|
||||
BASE_URL = 'https://gitlab.com/api/v4'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.external_token_manager = external_token_manager
|
||||
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
async def _get_gitlab_headers(self) -> dict:
|
||||
"""
|
||||
Retrieve the GitLab Token to construct the headers
|
||||
"""
|
||||
if self.user_id and not self.token:
|
||||
self.token = await self.get_latest_token()
|
||||
|
||||
return {
|
||||
'Authorization': f'Bearer {self.token.get_secret_value()}',
|
||||
}
|
||||
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
) -> tuple[Any, dict]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
response = await client.get(url, headers=gitlab_headers, params=params)
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
response = await client.get(
|
||||
url, headers=gitlab_headers, params=params
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
headers = {}
|
||||
if 'Link' in response.headers:
|
||||
headers['Link'] = response.headers['Link']
|
||||
|
||||
return response.json(), headers
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError('Invalid GitLab token')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
except httpx.HTTPError:
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
async def get_user(self) -> User:
|
||||
url = f'{self.BASE_URL}/user'
|
||||
response, _ = await self._fetch_data(url)
|
||||
|
||||
return User(
|
||||
id=response.get('id'),
|
||||
username=response.get('username'),
|
||||
avatar_url=response.get('avatar_url'),
|
||||
name=response.get('name'),
|
||||
email=response.get('email'),
|
||||
company=response.get('organization'),
|
||||
login=response.get('username'),
|
||||
)
|
||||
|
||||
async def search_repositories(
|
||||
self, query: str, per_page: int = 30, sort: str = 'updated', order: str = 'desc'
|
||||
):
|
||||
url = f'{self.BASE_URL}/search'
|
||||
params = {
|
||||
'scope': 'projects',
|
||||
'search': query,
|
||||
'per_page': per_page,
|
||||
'order_by': sort,
|
||||
'sort': order,
|
||||
}
|
||||
response, headers = await self._fetch_data(url, params)
|
||||
return response, headers
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
) -> list[Repository]:
|
||||
return []
|
||||
|
||||
|
||||
gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
'openhands.integrations.gitlab.gitlab_service.GitLabService',
|
||||
)
|
||||
GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls)
|
||||
143
openhands/integrations/provider.py
Normal file
143
openhands/integrations/provider.py
Normal file
@ -0,0 +1,143 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
Repository,
|
||||
User,
|
||||
)
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
GITHUB = 'github'
|
||||
GITLAB = 'gitlab'
|
||||
|
||||
|
||||
class ProviderToken(BaseModel):
|
||||
token: SecretStr | None
|
||||
user_id: str | None
|
||||
|
||||
|
||||
PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken]
|
||||
CUSTOM_SECRETS_TYPE = dict[str, SecretStr]
|
||||
|
||||
|
||||
class SecretStore(BaseModel):
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = {}
|
||||
|
||||
@classmethod
|
||||
def _convert_token(
|
||||
cls, token_value: str | ProviderToken | SecretStr
|
||||
) -> ProviderToken:
|
||||
if isinstance(token_value, ProviderToken):
|
||||
return token_value
|
||||
elif isinstance(token_value, str):
|
||||
return ProviderToken(token=SecretStr(token_value), user_id=None)
|
||||
elif isinstance(token_value, SecretStr):
|
||||
return ProviderToken(token=token_value, user_id=None)
|
||||
else:
|
||||
raise ValueError(f'Invalid token type: {type(token_value)}')
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
# Convert any string tokens to ProviderToken objects
|
||||
converted_tokens = {}
|
||||
for token_type, token_value in self.provider_tokens.items():
|
||||
if token_value: # Only convert non-empty tokens
|
||||
try:
|
||||
if isinstance(token_type, str):
|
||||
token_type = ProviderType(token_type)
|
||||
converted_tokens[token_type] = self._convert_token(token_value)
|
||||
except ValueError:
|
||||
# Skip invalid provider types or tokens
|
||||
continue
|
||||
self.provider_tokens = converted_tokens
|
||||
|
||||
@field_serializer('provider_tokens')
|
||||
def provider_tokens_serializer(
|
||||
self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo
|
||||
):
|
||||
tokens = {}
|
||||
expose_secrets = info.context and info.context.get('expose_secrets', False)
|
||||
|
||||
for token_type, provider_token in provider_tokens.items():
|
||||
if not provider_token or not provider_token.token:
|
||||
continue
|
||||
|
||||
token_type_str = (
|
||||
token_type.value
|
||||
if isinstance(token_type, ProviderType)
|
||||
else str(token_type)
|
||||
)
|
||||
tokens[token_type_str] = {
|
||||
'token': provider_token.token.get_secret_value()
|
||||
if expose_secrets
|
||||
else pydantic_encoder(provider_token.token),
|
||||
'user_id': provider_token.user_id,
|
||||
}
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
class ProviderHandler:
|
||||
def __init__(
|
||||
self,
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
):
|
||||
self.service_class_map: dict[ProviderType, type[GitService]] = {
|
||||
ProviderType.GITHUB: GithubServiceImpl,
|
||||
ProviderType.GITLAB: GitLabServiceImpl,
|
||||
}
|
||||
|
||||
self.provider_tokens = provider_tokens
|
||||
self.external_auth_token = external_auth_token
|
||||
|
||||
def _get_service(self, provider: ProviderType) -> GitService:
|
||||
"""Helper method to instantiate a service for a given provider"""
|
||||
token = self.provider_tokens[provider]
|
||||
service_class = self.service_class_map[provider]
|
||||
return service_class(
|
||||
user_id=token.user_id,
|
||||
external_auth_token=self.external_auth_token,
|
||||
token=token.token,
|
||||
)
|
||||
|
||||
async def get_user(self) -> User:
|
||||
"""Get user information from the first available provider"""
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self._get_service(provider)
|
||||
return await service.get_user()
|
||||
except Exception:
|
||||
continue
|
||||
raise AuthenticationError('Need valid provider token')
|
||||
|
||||
async def get_latest_provider_tokens(self) -> dict[ProviderType, SecretStr]:
|
||||
"""Get latest token from services"""
|
||||
tokens = {}
|
||||
for provider in self.provider_tokens:
|
||||
service = self._get_service(provider)
|
||||
tokens[provider] = await service.get_latest_token()
|
||||
|
||||
return tokens
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
) -> list[Repository]:
|
||||
"""Get repositories from all available providers"""
|
||||
all_repos = []
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self._get_service(provider)
|
||||
repos = await service.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
all_repos.extend(repos)
|
||||
except Exception:
|
||||
continue
|
||||
return all_repos
|
||||
89
openhands/integrations/service_types.py
Normal file
89
openhands/integrations/service_types.py
Normal file
@ -0,0 +1,89 @@
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
|
||||
FAILING_CHECKS = 'FAILING_CHECKS'
|
||||
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
|
||||
OPEN_ISSUE = 'OPEN_ISSUE'
|
||||
OPEN_PR = 'OPEN_PR'
|
||||
|
||||
|
||||
class SuggestedTask(BaseModel):
|
||||
task_type: TaskType
|
||||
repo: str
|
||||
issue_number: int
|
||||
title: str
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: int
|
||||
login: str
|
||||
avatar_url: str
|
||||
company: str | None = None
|
||||
name: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Repository(BaseModel):
|
||||
id: int
|
||||
full_name: str
|
||||
stargazers_count: int | None = None
|
||||
link_header: str | None = None
|
||||
|
||||
|
||||
class AuthenticationError(ValueError):
|
||||
"""Raised when there is an issue with GitHub authentication."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnknownException(ValueError):
|
||||
"""Raised when there is an issue with GitHub communcation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GitService(Protocol):
|
||||
"""Protocol defining the interface for Git service providers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None,
|
||||
token: SecretStr | None,
|
||||
external_auth_token: SecretStr | None,
|
||||
external_token_manager: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the service with authentication details"""
|
||||
...
|
||||
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
"""Get latest working token of the users"""
|
||||
...
|
||||
|
||||
async def get_user(self) -> User:
|
||||
"""Get the authenticated user's information"""
|
||||
...
|
||||
|
||||
async def search_repositories(
|
||||
self,
|
||||
query: str,
|
||||
per_page: int,
|
||||
sort: str,
|
||||
order: str,
|
||||
) -> list[Repository]:
|
||||
"""Search for repositories"""
|
||||
...
|
||||
|
||||
async def get_repositories(
|
||||
self,
|
||||
page: int,
|
||||
per_page: int,
|
||||
sort: str,
|
||||
installation_id: int | None,
|
||||
) -> list[Repository]:
|
||||
"""Get repositories for the authenticated user"""
|
||||
...
|
||||
37
openhands/integrations/utils.py
Normal file
37
openhands/integrations/utils.py
Normal file
@ -0,0 +1,37 @@
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabService
|
||||
from openhands.integrations.provider import ProviderType
|
||||
|
||||
|
||||
async def validate_provider_token(token: SecretStr) -> ProviderType | None:
|
||||
"""
|
||||
Determine whether a token is for GitHub or GitLab by attempting to get user info
|
||||
from both services.
|
||||
|
||||
Args:
|
||||
token: The token to check
|
||||
|
||||
Returns:
|
||||
'github' if it's a GitHub token
|
||||
'gitlab' if it's a GitLab token
|
||||
None if the token is invalid for both services
|
||||
"""
|
||||
# Try GitHub first
|
||||
try:
|
||||
github_service = GitHubService(token=token)
|
||||
await github_service.get_user()
|
||||
return ProviderType.GITHUB
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try GitLab next
|
||||
try:
|
||||
gitlab_service = GitLabService(token=token)
|
||||
await gitlab_service.get_user()
|
||||
return ProviderType.GITLAB
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
@ -1,6 +1,13 @@
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
|
||||
|
||||
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
|
||||
"""Get GitHub token from request state. For backward compatibility."""
|
||||
return getattr(request.state, 'provider_tokens', None)
|
||||
|
||||
|
||||
def get_access_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'access_token', None)
|
||||
@ -11,8 +18,18 @@ def get_user_id(request: Request) -> str | None:
|
||||
|
||||
|
||||
def get_github_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'github_token', None)
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_github_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'github_user_id', None)
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].user_id
|
||||
|
||||
return None
|
||||
|
||||
@ -194,10 +194,14 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
|
||||
settings = await settings_store.load()
|
||||
|
||||
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
|
||||
if getattr(request.state, 'github_token', None) is None:
|
||||
if settings and settings.github_token:
|
||||
request.state.github_token = settings.github_token
|
||||
if getattr(request.state, 'provider_tokens', None) is None:
|
||||
if (
|
||||
settings
|
||||
and settings.secrets_store
|
||||
and settings.secrets_store.provider_tokens
|
||||
):
|
||||
request.state.provider_tokens = settings.secrets_store.provider_tokens
|
||||
else:
|
||||
request.state.github_token = None
|
||||
request.state.provider_tokens = None
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@ -3,147 +3,168 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.github.github_types import (
|
||||
GhAuthenticationError,
|
||||
GHUnknownException,
|
||||
GitHubRepository,
|
||||
GitHubUser,
|
||||
SuggestedTask,
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
ProviderType,
|
||||
)
|
||||
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.server.auth import get_access_token, get_provider_tokens
|
||||
|
||||
app = APIRouter(prefix='/api/github')
|
||||
|
||||
|
||||
@app.get('/repositories', response_model=list[GitHubRepository])
|
||||
@app.get('/repositories', response_model=list[Repository])
|
||||
async def get_github_repositories(
|
||||
page: int = 1,
|
||||
per_page: int = 10,
|
||||
sort: str = 'pushed',
|
||||
installation_id: int | None = None,
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
|
||||
try:
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
return repos
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
return repos
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/user', response_model=GitHubUser)
|
||||
@app.get('/user', response_model=User)
|
||||
async def get_github_user(
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(provider_tokens=provider_tokens, external_auth_token=access_token)
|
||||
|
||||
try:
|
||||
user: User = await client.get_user()
|
||||
return user
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
user: GitHubUser = await client.get_user()
|
||||
return user
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/installations', response_model=list[int])
|
||||
async def get_github_installation_ids(
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
try:
|
||||
installations_ids: list[int] = await client.get_installation_ids()
|
||||
return installations_ids
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
installations_ids: list[int] = await client.get_installation_ids()
|
||||
return installations_ids
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/search/repositories', response_model=list[GitHubRepository])
|
||||
@app.get('/search/repositories', response_model=list[Repository])
|
||||
async def search_github_repositories(
|
||||
query: str,
|
||||
per_page: int = 5,
|
||||
sort: str = 'stars',
|
||||
order: str = 'desc',
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
try:
|
||||
repos: list[Repository] = await client.search_repositories(
|
||||
query, per_page, sort, order
|
||||
)
|
||||
return repos
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.search_repositories(
|
||||
query, per_page, sort, order
|
||||
)
|
||||
return repos
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/suggested-tasks', response_model=list[SuggestedTask])
|
||||
async def get_suggested_tasks(
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token)
|
||||
):
|
||||
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
|
||||
|
||||
@ -151,23 +172,30 @@ async def get_suggested_tasks(
|
||||
- PRs owned by the user
|
||||
- Issues assigned to the user.
|
||||
"""
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
try:
|
||||
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
|
||||
return tasks
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
|
||||
return tasks
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
@ -8,8 +8,9 @@ from pydantic import BaseModel, SecretStr
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
|
||||
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@ -136,13 +137,18 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
user_id = get_github_user_id(request)
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=user_id,
|
||||
external_auth_token=get_access_token(request),
|
||||
github_token=get_github_token(request),
|
||||
)
|
||||
github_token = await gh_client.get_latest_token()
|
||||
user_id = None
|
||||
github_token = None
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
user_id = token.user_id
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=user_id,
|
||||
external_auth_token=get_access_token(request),
|
||||
token=token.token,
|
||||
)
|
||||
github_token = await gh_client.get_latest_token()
|
||||
|
||||
selected_repository = data.selected_repository
|
||||
selected_branch = data.selected_branch
|
||||
|
||||
@ -3,8 +3,9 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
from openhands.server.auth import get_provider_tokens, get_user_id
|
||||
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
|
||||
from openhands.server.shared import SettingsStoreImpl, config
|
||||
|
||||
@ -23,14 +24,14 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
content={'error': 'Settings not found'},
|
||||
)
|
||||
|
||||
token_is_set = bool(user_id) or bool(get_github_token(request))
|
||||
github_token_is_set = bool(user_id) or bool(get_provider_tokens(request))
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(),
|
||||
github_token_is_set=token_is_set,
|
||||
github_token_is_set=github_token_is_set,
|
||||
)
|
||||
settings_with_token_data.llm_api_key = settings.llm_api_key
|
||||
|
||||
del settings_with_token_data.github_token
|
||||
del settings_with_token_data.secrets_store
|
||||
return settings_with_token_data
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid token: {e}')
|
||||
@ -45,26 +46,27 @@ async def store_settings(
|
||||
request: Request,
|
||||
settings: POSTSettingsModel,
|
||||
) -> JSONResponse:
|
||||
# Check if token is valid
|
||||
if settings.github_token:
|
||||
try:
|
||||
# We check if the token is valid by getting the user
|
||||
# If the token is invalid, this will raise an exception
|
||||
github = GithubServiceImpl(
|
||||
user_id=None,
|
||||
external_auth_token=None,
|
||||
github_token=SecretStr(settings.github_token),
|
||||
)
|
||||
await github.get_user()
|
||||
# Check provider tokens are valid
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
settings.provider_tokens = {
|
||||
k: v for k, v in settings.provider_tokens.items() if k in provider_types
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid GitHub token: {e}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': 'Invalid GitHub token. Please make sure it is valid.'
|
||||
},
|
||||
)
|
||||
# Determine whether tokens are valid
|
||||
for token_type, token_value in settings.provider_tokens.items():
|
||||
if token_value:
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
SecretStr(token_value)
|
||||
)
|
||||
if not confirmed_token_type or confirmed_token_type.value != token_type:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Invalid token. Please make sure it is a valid {token_type} token.'
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
@ -72,32 +74,46 @@ async def store_settings(
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
# LLM key isn't on the frontend, so we need to keep it if unset
|
||||
# Keep existing LLM settings if not provided
|
||||
if settings.llm_api_key is None:
|
||||
settings.llm_api_key = existing_settings.llm_api_key
|
||||
if settings.llm_model is None:
|
||||
settings.llm_model = existing_settings.llm_model
|
||||
if settings.llm_base_url is None:
|
||||
settings.llm_base_url = existing_settings.llm_base_url
|
||||
|
||||
if settings.github_token is None:
|
||||
settings.github_token = existing_settings.github_token
|
||||
|
||||
# Keep existing analytics consent if not provided
|
||||
if settings.user_consents_to_analytics is None:
|
||||
settings.user_consents_to_analytics = (
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
if settings.llm_model is None:
|
||||
settings.llm_model = existing_settings.llm_model
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
if settings.llm_base_url is None:
|
||||
settings.llm_base_url = existing_settings.llm_base_url
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Settings stored'},
|
||||
)
|
||||
|
||||
if settings.unset_github_token:
|
||||
settings.github_token = None
|
||||
# Merge provider tokens with existing ones
|
||||
if settings.unset_github_token: # Only merge if not unsetting tokens
|
||||
settings.secrets_store.provider_tokens = {}
|
||||
settings.provider_tokens = {}
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
@ -106,9 +122,11 @@ async def store_settings(
|
||||
)
|
||||
|
||||
settings = convert_to_settings(settings)
|
||||
|
||||
await settings_store.store(settings)
|
||||
return response
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Settings stored'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Something went wrong storing settings: {e}')
|
||||
return JSONResponse(
|
||||
@ -127,8 +145,19 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
|
||||
if key in Settings.model_fields # Ensures only `Settings` fields are included
|
||||
}
|
||||
|
||||
# Convert the `llm_api_key` and `github_token` to a `SecretStr` instance
|
||||
# Convert the `llm_api_key` to a `SecretStr` instance
|
||||
filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key
|
||||
filtered_settings_data['github_token'] = settings_with_token_data.github_token
|
||||
|
||||
return Settings(**filtered_settings_data)
|
||||
# Create a new Settings instance without provider tokens
|
||||
settings = Settings(**filtered_settings_data)
|
||||
|
||||
# Update provider tokens if any are provided
|
||||
if settings_with_token_data.provider_tokens:
|
||||
for token_type, token_value in settings_with_token_data.provider_tokens.items():
|
||||
if token_value:
|
||||
provider = ProviderType(token_type)
|
||||
settings.secrets_store.provider_tokens[provider] = ProviderToken(
|
||||
token=SecretStr(token_value), user_id=None
|
||||
)
|
||||
|
||||
return settings
|
||||
|
||||
@ -1,10 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
SecretStr,
|
||||
SerializationInfo,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.utils import load_app_config
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
@ -21,7 +28,7 @@ class Settings(BaseModel):
|
||||
llm_api_key: SecretStr | None = None
|
||||
llm_base_url: str | None = None
|
||||
remote_runtime_resource_factor: int | None = None
|
||||
github_token: SecretStr | None = None
|
||||
secrets_store: SecretStore = SecretStore()
|
||||
enable_default_condenser: bool = False
|
||||
enable_sound_notifications: bool = False
|
||||
user_consents_to_analytics: bool | None = None
|
||||
@ -38,22 +45,63 @@ class Settings(BaseModel):
|
||||
|
||||
return pydantic_encoder(llm_api_key)
|
||||
|
||||
@field_serializer('github_token')
|
||||
def github_token_serializer(
|
||||
self, github_token: SecretStr | None, info: SerializationInfo
|
||||
):
|
||||
"""Custom serializer for the GitHub token.
|
||||
@staticmethod
|
||||
def _convert_token_value(
|
||||
token_type: ProviderType, token_value: str | dict
|
||||
) -> ProviderToken | None:
|
||||
"""Convert a token value to a ProviderToken object."""
|
||||
if isinstance(token_value, dict):
|
||||
token_str = token_value.get('token')
|
||||
if not token_str:
|
||||
return None
|
||||
return ProviderToken(
|
||||
token=SecretStr(token_str),
|
||||
user_id=token_value.get('user_id'),
|
||||
)
|
||||
if isinstance(token_value, str) and token_value:
|
||||
return ProviderToken(token=SecretStr(token_value), user_id=None)
|
||||
return None
|
||||
|
||||
To serialize the token instead of ********, set expose_secrets to True in the serialization context.
|
||||
"""
|
||||
if github_token is None:
|
||||
return None
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def convert_provider_tokens(cls, data: dict | object) -> dict | object:
|
||||
"""Convert provider tokens from JSON format to SecretStore format."""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
context = info.context
|
||||
if context and context.get('expose_secrets', False):
|
||||
return github_token.get_secret_value()
|
||||
secrets_store = data.get('secrets_store')
|
||||
if not isinstance(secrets_store, dict):
|
||||
return data
|
||||
|
||||
return pydantic_encoder(github_token)
|
||||
tokens = secrets_store.get('provider_tokens')
|
||||
if not isinstance(tokens, dict):
|
||||
return data
|
||||
|
||||
converted_tokens = {}
|
||||
for token_type_str, token_value in tokens.items():
|
||||
if not token_value:
|
||||
continue
|
||||
|
||||
try:
|
||||
token_type = ProviderType(token_type_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
provider_token = cls._convert_token_value(token_type, token_value)
|
||||
if provider_token:
|
||||
converted_tokens[token_type] = provider_token
|
||||
|
||||
data['secrets_store'] = SecretStore(provider_tokens=converted_tokens)
|
||||
return data
|
||||
|
||||
@field_serializer('secrets_store')
|
||||
def secrets_store_serializer(self, secrets: SecretStore, info: SerializationInfo):
|
||||
"""Custom serializer for secrets store."""
|
||||
return {
|
||||
'provider_tokens': secrets.provider_tokens_serializer(
|
||||
secrets.provider_tokens, info
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_config() -> Settings | None:
|
||||
@ -73,7 +121,7 @@ class Settings(BaseModel):
|
||||
llm_api_key=llm_config.api_key,
|
||||
llm_base_url=llm_config.base_url,
|
||||
remote_runtime_resource_factor=app_config.sandbox.remote_runtime_resource_factor,
|
||||
github_token=None,
|
||||
provider_tokens={},
|
||||
)
|
||||
return settings
|
||||
|
||||
@ -84,14 +132,12 @@ class POSTSettingsModel(Settings):
|
||||
"""
|
||||
|
||||
unset_github_token: bool | None = None
|
||||
github_token: str | None = (
|
||||
None # This is a string because it's coming from the frontend
|
||||
)
|
||||
# Override provider_tokens to accept string tokens from frontend
|
||||
provider_tokens: dict[str, str] = {}
|
||||
|
||||
# Override the serializer for the GitHub token to handle the string input
|
||||
@field_serializer('github_token')
|
||||
def github_token_serializer(self, github_token: str | None):
|
||||
return github_token
|
||||
@field_serializer('provider_tokens')
|
||||
def provider_tokens_serializer(self, provider_tokens: dict[str, str]):
|
||||
return provider_tokens
|
||||
|
||||
|
||||
class GETSettingsModel(Settings):
|
||||
|
||||
@ -5,16 +5,16 @@ import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.github.github_types import GhAuthenticationError
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_token_handling():
|
||||
# Test initialization with SecretStr token
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, github_token=token)
|
||||
assert service.github_token == token
|
||||
assert service.github_token.get_secret_value() == 'test-token'
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
assert service.token == token
|
||||
assert service.token.get_secret_value() == 'test-token'
|
||||
|
||||
# Test headers contain the token correctly
|
||||
headers = await service._get_github_headers()
|
||||
@ -23,14 +23,14 @@ async def test_github_service_token_handling():
|
||||
|
||||
# Test initialization without token
|
||||
service = GitHubService(user_id='test-user')
|
||||
assert service.github_token == SecretStr('')
|
||||
assert service.token == SecretStr('')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_token_refresh():
|
||||
# Test that token refresh is only attempted when refresh=True
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, github_token=token)
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
assert not service.refresh
|
||||
|
||||
# Test token expiry detection
|
||||
@ -58,7 +58,7 @@ async def test_github_service_fetch_data():
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
service = GitHubService(user_id=None, github_token=SecretStr('test-token'))
|
||||
service = GitHubService(user_id=None, token=SecretStr('test-token'))
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
|
||||
# Verify the request was made with correct headers
|
||||
@ -77,5 +77,5 @@ async def test_github_service_fetch_data():
|
||||
mock_client.get.reset_mock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GhAuthenticationError):
|
||||
with pytest.raises(AuthenticationError):
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
|
||||
@ -6,6 +6,7 @@ from openhands.core.config.app_config import AppConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.routes.settings import convert_to_settings
|
||||
from openhands.server.settings import POSTSettingsModel, Settings
|
||||
|
||||
@ -43,7 +44,7 @@ def test_settings_from_config():
|
||||
assert settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert settings.llm_base_url == 'https://test.example.com'
|
||||
assert settings.remote_runtime_resource_factor == 2
|
||||
assert settings.github_token is None
|
||||
assert not settings.secrets_store.provider_tokens
|
||||
|
||||
|
||||
def test_settings_from_config_no_api_key():
|
||||
@ -80,23 +81,41 @@ def test_settings_handles_sensitive_data():
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='https://test.example.com',
|
||||
remote_runtime_resource_factor=2,
|
||||
github_token='test-token',
|
||||
)
|
||||
settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
|
||||
token=SecretStr('test-token'),
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
assert str(settings.llm_api_key) == '**********'
|
||||
assert str(settings.github_token) == '**********'
|
||||
assert (
|
||||
str(settings.secrets_store.provider_tokens[ProviderType.GITHUB].token)
|
||||
== '**********'
|
||||
)
|
||||
|
||||
assert settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert settings.github_token.get_secret_value() == 'test-token'
|
||||
assert (
|
||||
settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test-token'
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_settings():
|
||||
settings_with_token_data = POSTSettingsModel(
|
||||
llm_api_key='test-key',
|
||||
github_token='test-token',
|
||||
provider_tokens={
|
||||
'github': 'test-token',
|
||||
},
|
||||
)
|
||||
|
||||
settings = convert_to_settings(settings_with_token_data)
|
||||
|
||||
assert settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert settings.github_token.get_secret_value() == 'test-token'
|
||||
assert (
|
||||
settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test-token'
|
||||
)
|
||||
|
||||
@ -5,6 +5,7 @@ from fastapi.testclient import TestClient
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.integrations.provider import ProviderType, SecretStore
|
||||
from openhands.server.app import app
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
@ -19,6 +20,24 @@ def mock_settings_store():
|
||||
yield store_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_user_id():
|
||||
with patch('openhands.server.routes.settings.get_user_id') as mock:
|
||||
mock.return_value = 'test-user'
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_validate_provider_token():
|
||||
with patch('openhands.server.routes.settings.validate_provider_token') as mock:
|
||||
|
||||
async def mock_determine(*args, **kwargs):
|
||||
return ProviderType.GITHUB
|
||||
|
||||
mock.side_effect = mock_determine
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(mock_settings_store):
|
||||
# Mock the middleware that adds github_token
|
||||
@ -28,9 +47,15 @@ def test_client(mock_settings_store):
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
settings = mock_settings_store.load.return_value
|
||||
token = settings.github_token if settings else None
|
||||
token = None
|
||||
if settings and settings.secrets_store.provider_tokens.get(
|
||||
ProviderType.GITHUB
|
||||
):
|
||||
token = settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token
|
||||
if scope['type'] == 'http':
|
||||
scope['state'] = {'github_token': token}
|
||||
scope['state'] = {'token': token}
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
# Replace the middleware
|
||||
@ -47,7 +72,9 @@ def mock_github_service():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_api_runtime_factor(test_client, mock_settings_store):
|
||||
async def test_settings_api_runtime_factor(
|
||||
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
|
||||
):
|
||||
# Mock the settings store to return None initially (no existing settings)
|
||||
mock_settings_store.load.return_value = None
|
||||
|
||||
@ -62,6 +89,7 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
|
||||
'llm_api_key': 'test-key',
|
||||
'llm_base_url': 'https://test.com',
|
||||
'remote_runtime_resource_factor': 2,
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# The test_client fixture already handles authentication
|
||||
@ -98,12 +126,17 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_llm_api_key(test_client, mock_settings_store):
|
||||
async def test_settings_llm_api_key(
|
||||
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
|
||||
):
|
||||
# Mock the settings store to return None initially (no existing settings)
|
||||
mock_settings_store.load.return_value = None
|
||||
|
||||
# Test data with remote_runtime_resource_factor
|
||||
settings_data = {'llm_api_key': 'test-key'}
|
||||
settings_data = {
|
||||
'llm_api_key': 'test-key',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# The test_client fixture already handles authentication
|
||||
|
||||
@ -132,9 +165,13 @@ async def test_settings_llm_api_key(test_client, mock_settings_store):
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_api_set_github_token(
|
||||
mock_github_service, test_client, mock_settings_store
|
||||
mock_github_service,
|
||||
test_client,
|
||||
mock_settings_store,
|
||||
mock_get_user_id,
|
||||
mock_validate_provider_token,
|
||||
):
|
||||
# Test data with github_token set
|
||||
# Test data with provider token set
|
||||
settings_data = {
|
||||
'language': 'en',
|
||||
'agent': 'test-agent',
|
||||
@ -144,16 +181,21 @@ async def test_settings_api_set_github_token(
|
||||
'llm_model': 'test-model',
|
||||
'llm_api_key': 'test-key',
|
||||
'llm_base_url': 'https://test.com',
|
||||
'github_token': 'test-token',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# Make the POST request to store settings
|
||||
response = test_client.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the settings were stored with the github_token
|
||||
# Verify the settings were stored with the provider token
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.github_token == 'test-token'
|
||||
assert (
|
||||
stored_settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test-token'
|
||||
)
|
||||
|
||||
# Mock settings store to return our settings for the GET request
|
||||
mock_settings_store.load.return_value = Settings(**settings_data)
|
||||
@ -163,17 +205,21 @@ async def test_settings_api_set_github_token(
|
||||
data = response.json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data.get('github_token') is None
|
||||
assert data['github_token_is_set'] is True
|
||||
assert data.get('token') is None
|
||||
assert data['token_is_set'] is True
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason='Mock middleware does not seem to properly set the github_token'
|
||||
)
|
||||
async def test_settings_unset_github_token(
|
||||
mock_github_service, test_client, mock_settings_store
|
||||
mock_github_service,
|
||||
test_client,
|
||||
mock_settings_store,
|
||||
mock_get_user_id,
|
||||
mock_validate_provider_token,
|
||||
):
|
||||
# Test data with unset_github_token set to True
|
||||
# Test data with unset_token set to True
|
||||
settings_data = {
|
||||
'language': 'en',
|
||||
'agent': 'test-agent',
|
||||
@ -183,7 +229,7 @@ async def test_settings_unset_github_token(
|
||||
'llm_model': 'test-model',
|
||||
'llm_api_key': 'test-key',
|
||||
'llm_base_url': 'https://test.com',
|
||||
'github_token': 'test-token',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# Mock settings store to return our settings for the GET request
|
||||
@ -191,23 +237,23 @@ async def test_settings_unset_github_token(
|
||||
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
assert response.json()['github_token_is_set'] is True
|
||||
assert response.json()['token_is_set'] is True
|
||||
|
||||
settings_data['unset_github_token'] = True
|
||||
settings_data['unset_token'] = True
|
||||
|
||||
# Make the POST request to store settings
|
||||
response = test_client.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the settings were stored with the github_token unset
|
||||
# Verify the settings were stored with the provider token unset
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.github_token is None
|
||||
assert not stored_settings.secrets_store.provider_tokens
|
||||
mock_settings_store.load.return_value = Settings(**stored_settings.dict())
|
||||
|
||||
# Make a GET request to retrieve settings
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
assert response.json()['github_token_is_set'] is False
|
||||
assert response.json()['token_is_set'] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -222,6 +268,7 @@ async def test_settings_preserve_llm_fields_when_none(test_client, mock_settings
|
||||
llm_model='existing-model',
|
||||
llm_api_key=SecretStr('existing-key'),
|
||||
llm_base_url='https://existing.com',
|
||||
secrets_store=SecretStore(),
|
||||
)
|
||||
|
||||
# Mock the settings store to return our initial settings
|
||||
|
||||
@ -3,13 +3,13 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.github.github_types import GitHubUser, TaskType
|
||||
from openhands.integrations.service_types import TaskType, User
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_suggested_tasks():
|
||||
# Mock responses
|
||||
mock_user = GitHubUser(
|
||||
mock_user = User(
|
||||
id=1,
|
||||
login='test-user',
|
||||
avatar_url='https://example.com/avatar.jpg',
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user