diff --git a/frontend/__tests__/routes/settings.test.tsx b/frontend/__tests__/routes/settings.test.tsx index dbbcadd8a1..f06c228390 100644 --- a/frontend/__tests__/routes/settings.test.tsx +++ b/frontend/__tests__/routes/settings.test.tsx @@ -721,7 +721,7 @@ describe("Settings Screen", () => { expect(saveSettingsSpy).toHaveBeenCalledWith( expect.objectContaining({ llm_api_key: "", // empty because it's not set previously - github_token: undefined, + provider_tokens: undefined, language: "no", }), ); @@ -758,7 +758,7 @@ describe("Settings Screen", () => { expect(saveSettingsSpy).toHaveBeenCalledWith( expect.objectContaining({ - github_token: undefined, + provider_tokens: undefined, llm_api_key: "", // empty because it's not set previously llm_model: "openai/gpt-4o", }), @@ -801,7 +801,7 @@ describe("Settings Screen", () => { expect(saveSettingsSpy).toHaveBeenCalledWith({ ...mockCopy, - github_token: undefined, // not set + provider_tokens: undefined, // not set llm_api_key: "", // reset as well }); expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument(); diff --git a/frontend/public/mockServiceWorker.js b/frontend/public/mockServiceWorker.js index ec47a9a50a..34057e898f 100644 --- a/frontend/public/mockServiceWorker.js +++ b/frontend/public/mockServiceWorker.js @@ -8,7 +8,7 @@ * - Please do NOT serve this file on production. */ -const PACKAGE_VERSION = '2.7.0' +const PACKAGE_VERSION = '2.7.3' const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f' const IS_MOCKED_RESPONSE = Symbol('isMockedResponse') const activeClientIds = new Set() diff --git a/frontend/src/hooks/mutation/use-save-settings.ts b/frontend/src/hooks/mutation/use-save-settings.ts index 8bf7c1da40..50791977aa 100644 --- a/frontend/src/hooks/mutation/use-save-settings.ts +++ b/frontend/src/hooks/mutation/use-save-settings.ts @@ -17,7 +17,7 @@ const saveSettingsMutationFn = async (settings: Partial) => { ? "" : settings.LLM_API_KEY?.trim() || undefined, remote_runtime_resource_factor: settings.REMOTE_RUNTIME_RESOURCE_FACTOR, - github_token: settings.github_token, + provider_tokens: settings.provider_tokens, unset_github_token: settings.unset_github_token, enable_default_condenser: settings.ENABLE_DEFAULT_CONDENSER, enable_sound_notifications: settings.ENABLE_SOUND_NOTIFICATIONS, diff --git a/frontend/src/hooks/query/use-settings.ts b/frontend/src/hooks/query/use-settings.ts index 348cf45514..055aa99559 100644 --- a/frontend/src/hooks/query/use-settings.ts +++ b/frontend/src/hooks/query/use-settings.ts @@ -21,6 +21,7 @@ const getSettingsQueryFn = async () => { ENABLE_DEFAULT_CONDENSER: apiSettings.enable_default_condenser, ENABLE_SOUND_NOTIFICATIONS: apiSettings.enable_sound_notifications, USER_CONSENTS_TO_ANALYTICS: apiSettings.user_consents_to_analytics, + PROVIDER_TOKENS: apiSettings.provider_tokens, IS_NEW_USER: false, }; }; diff --git a/frontend/src/mocks/handlers.ts b/frontend/src/mocks/handlers.ts index f1ef5026d1..b3caacd210 100644 --- a/frontend/src/mocks/handlers.ts +++ b/frontend/src/mocks/handlers.ts @@ -22,6 +22,7 @@ export const MOCK_DEFAULT_USER_SETTINGS: ApiSettings | PostApiSettings = { enable_default_condenser: DEFAULT_SETTINGS.ENABLE_DEFAULT_CONDENSER, enable_sound_notifications: DEFAULT_SETTINGS.ENABLE_SOUND_NOTIFICATIONS, user_consents_to_analytics: DEFAULT_SETTINGS.USER_CONSENTS_TO_ANALYTICS, + provider_tokens: DEFAULT_SETTINGS.PROVIDER_TOKENS, }; const MOCK_USER_PREFERENCES: { @@ -190,8 +191,8 @@ export const handlers = [ if (!settings) return HttpResponse.json(null, { status: 404 }); - // @ts-expect-error - mock types - if (settings.github_token) settings.github_token_is_set = true; + if (Object.keys(settings.provider_tokens).length > 0) + settings.github_token_is_set = true; return HttpResponse.json(settings); }), @@ -203,7 +204,7 @@ export const handlers = [ if (typeof body === "object") { newSettings = { ...body }; if (newSettings.unset_github_token) { - newSettings.github_token = undefined; + newSettings.provider_tokens = { github: "", gitlab: "" }; newSettings.github_token_is_set = false; delete newSettings.unset_github_token; } diff --git a/frontend/src/routes/account-settings.tsx b/frontend/src/routes/account-settings.tsx index e7c374fa38..1f814e7ab4 100644 --- a/frontend/src/routes/account-settings.tsx +++ b/frontend/src/routes/account-settings.tsx @@ -61,7 +61,10 @@ function AccountSettings() { if (isSuccess) { return ( isCustomModel(resources.models, settings.LLM_MODEL) || - hasAdvancedSettingsSet(settings) + hasAdvancedSettingsSet({ + ...settings, + PROVIDER_TOKENS: settings.PROVIDER_TOKENS || {}, + }) ); } @@ -128,37 +131,42 @@ function AccountSettings() { : llmBaseUrl; const finalLlmApiKey = shouldHandleSpecialSaasCase ? undefined : llmApiKey; - saveSettings( - { - github_token: - formData.get("github-token-input")?.toString() || undefined, - LANGUAGE: languageValue, - user_consents_to_analytics: userConsentsToAnalytics, - ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser, - ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications, - LLM_MODEL: finalLlmModel, - LLM_BASE_URL: finalLlmBaseUrl, - LLM_API_KEY: finalLlmApiKey, - AGENT: formData.get("agent-input")?.toString(), - SECURITY_ANALYZER: - formData.get("security-analyzer-input")?.toString() || "", - REMOTE_RUNTIME_RESOURCE_FACTOR: - remoteRuntimeResourceFactor || - DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR, - CONFIRMATION_MODE: confirmationModeIsEnabled, + const githubToken = formData.get("github-token-input")?.toString(); + const newSettings = { + github_token: githubToken, + provider_tokens: githubToken + ? { + github: githubToken, + gitlab: "", + } + : undefined, + LANGUAGE: languageValue, + user_consents_to_analytics: userConsentsToAnalytics, + ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser, + ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications, + LLM_MODEL: finalLlmModel, + LLM_BASE_URL: finalLlmBaseUrl, + LLM_API_KEY: finalLlmApiKey, + AGENT: formData.get("agent-input")?.toString(), + SECURITY_ANALYZER: + formData.get("security-analyzer-input")?.toString() || "", + REMOTE_RUNTIME_RESOURCE_FACTOR: + remoteRuntimeResourceFactor || + DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR, + CONFIRMATION_MODE: confirmationModeIsEnabled, + }; + + saveSettings(newSettings, { + onSuccess: () => { + handleCaptureConsent(userConsentsToAnalytics); + displaySuccessToast("Settings saved"); + setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic"); }, - { - onSuccess: () => { - handleCaptureConsent(userConsentsToAnalytics); - displaySuccessToast("Settings saved"); - setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic"); - }, - onError: (error) => { - const errorMessage = retrieveAxiosErrorMessage(error); - displayErrorToast(errorMessage); - }, + onError: (error) => { + const errorMessage = retrieveAxiosErrorMessage(error); + displayErrorToast(errorMessage); }, - ); + }); }; const handleReset = () => { diff --git a/frontend/src/services/settings.ts b/frontend/src/services/settings.ts index 7931298b4b..91b95d2021 100644 --- a/frontend/src/services/settings.ts +++ b/frontend/src/services/settings.ts @@ -15,6 +15,10 @@ export const DEFAULT_SETTINGS: Settings = { ENABLE_DEFAULT_CONDENSER: true, ENABLE_SOUND_NOTIFICATIONS: false, USER_CONSENTS_TO_ANALYTICS: false, + PROVIDER_TOKENS: { + github: "", + gitlab: "", + }, IS_NEW_USER: true, }; diff --git a/frontend/src/types/core/variances.ts b/frontend/src/types/core/variances.ts index 5aca6ccd30..0492588c5c 100644 --- a/frontend/src/types/core/variances.ts +++ b/frontend/src/types/core/variances.ts @@ -21,7 +21,6 @@ export interface InitConfig { LLM_MODEL: string; }; token?: string; - github_token?: string; latest_event_id?: unknown; // Not sure what this is } diff --git a/frontend/src/types/settings.ts b/frontend/src/types/settings.ts index 3c7dc4915e..6e13a2ec50 100644 --- a/frontend/src/types/settings.ts +++ b/frontend/src/types/settings.ts @@ -1,3 +1,5 @@ +export type Provider = "github" | "gitlab"; + export type Settings = { LLM_MODEL: string; LLM_BASE_URL: string; @@ -11,6 +13,7 @@ export type Settings = { ENABLE_DEFAULT_CONDENSER: boolean; ENABLE_SOUND_NOTIFICATIONS: boolean; USER_CONSENTS_TO_ANALYTICS: boolean | null; + PROVIDER_TOKENS: Record; IS_NEW_USER?: boolean; }; @@ -27,16 +30,17 @@ export type ApiSettings = { enable_default_condenser: boolean; enable_sound_notifications: boolean; user_consents_to_analytics: boolean | null; + provider_tokens: Record; }; export type PostSettings = Settings & { - github_token: string; + provider_tokens: Record; unset_github_token: boolean; user_consents_to_analytics: boolean | null; }; export type PostApiSettings = ApiSettings & { - github_token: string; + provider_tokens: Record; unset_github_token: boolean; user_consents_to_analytics: boolean | null; }; diff --git a/frontend/src/utils/settings-utils.ts b/frontend/src/utils/settings-utils.ts index f92959abae..bcf0ec2160 100644 --- a/frontend/src/utils/settings-utils.ts +++ b/frontend/src/utils/settings-utils.ts @@ -59,6 +59,18 @@ export const extractSettings = (formData: FormData): Partial => { ENABLE_DEFAULT_CONDENSER, } = extractAdvancedFormData(formData); + // Extract provider tokens + const githubToken = formData.get("github-token")?.toString(); + const gitlabToken = formData.get("gitlab-token")?.toString(); + const providerTokens: Record = {}; + + if (githubToken) { + providerTokens.github = githubToken; + } + if (gitlabToken) { + providerTokens.gitlab = gitlabToken; + } + return { LLM_MODEL: CUSTOM_LLM_MODEL || LLM_MODEL, LLM_API_KEY, @@ -68,5 +80,6 @@ export const extractSettings = (formData: FormData): Partial => { CONFIRMATION_MODE, SECURITY_ANALYZER, ENABLE_DEFAULT_CONDENSER, + PROVIDER_TOKENS: providerTokens, }; }; diff --git a/openhands/integrations/github/github_service.py b/openhands/integrations/github/github_service.py index bb85cfdee5..1ccdc641d6 100644 --- a/openhands/integrations/github/github_service.py +++ b/openhands/integrations/github/github_service.py @@ -5,43 +5,43 @@ from typing import Any import httpx from pydantic import SecretStr -from openhands.core.logger import openhands_logger as logger -from openhands.integrations.github.github_types import ( - GhAuthenticationError, - GHUnknownException, - GitHubRepository, - GitHubUser, +from openhands.integrations.service_types import ( + AuthenticationError, + GitService, + Repository, SuggestedTask, TaskType, + UnknownException, + User, ) from openhands.utils.import_utils import get_impl +from openhands.core.logger import openhands_logger as logger - -class GitHubService: +class GitHubService(GitService): BASE_URL = 'https://api.github.com' - github_token: SecretStr = SecretStr('') + token: SecretStr = SecretStr('') refresh = False def __init__( self, user_id: str | None = None, external_auth_token: SecretStr | None = None, - github_token: SecretStr | None = None, + token: SecretStr | None = None, external_token_manager: bool = False, ): self.user_id = user_id self.external_token_manager = external_token_manager - if github_token: - self.github_token = github_token + if token: + self.token = token async def _get_github_headers(self) -> dict: """Retrieve the GH Token from settings store to construct the headers.""" - if self.user_id and not self.github_token: - self.github_token = await self.get_latest_token() + if self.user_id and not self.token: + self.token = await self.get_latest_token() return { - 'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}', + 'Authorization': f'Bearer {self.token.get_secret_value() if self.token else ""}', 'Accept': 'application/vnd.github.v3+json', } @@ -49,7 +49,7 @@ class GitHubService: return status_code == 401 async def get_latest_token(self) -> SecretStr | None: - return self.github_token + return self.token async def _fetch_data( self, url: str, params: dict | None = None @@ -74,20 +74,20 @@ class GitHubService: except httpx.HTTPStatusError as e: if e.response.status_code == 401: - raise GhAuthenticationError('Invalid Github token') + raise AuthenticationError('Invalid Github token') logger.warning(f'Status error on GH API: {e}') - raise GHUnknownException('Unknown error') + raise UnknownException('Unknown error') except httpx.HTTPError as e: logger.warning(f'HTTP error on GH API: {e}') - raise GHUnknownException('Unknown error') + raise UnknownException('Unknown error') - async def get_user(self) -> GitHubUser: + async def get_user(self) -> User: url = f'{self.BASE_URL}/user' response, _ = await self._fetch_data(url) - return GitHubUser( + return User( id=response.get('id'), login=response.get('login'), avatar_url=response.get('avatar_url'), @@ -98,7 +98,7 @@ class GitHubService: async def get_repositories( self, page: int, per_page: int, sort: str, installation_id: int | None - ) -> list[GitHubRepository]: + ) -> list[Repository]: params = {'page': str(page), 'per_page': str(per_page)} if installation_id: url = f'{self.BASE_URL}/user/installations/{installation_id}/repositories' @@ -111,7 +111,7 @@ class GitHubService: next_link: str = headers.get('Link', '') repos = [ - GitHubRepository( + Repository( id=repo.get('id'), full_name=repo.get('full_name'), stargazers_count=repo.get('stargazers_count'), @@ -129,7 +129,7 @@ class GitHubService: async def search_repositories( self, query: str, per_page: int, sort: str, order: str - ) -> list[GitHubRepository]: + ) -> list[Repository]: url = f'{self.BASE_URL}/search/repositories' params = {'q': query, 'per_page': per_page, 'sort': sort, 'order': order} @@ -137,7 +137,7 @@ class GitHubService: repos = response.get('items', []) repos = [ - GitHubRepository( + Repository( id=repo.get('id'), full_name=repo.get('full_name'), stargazers_count=repo.get('stargazers_count'), @@ -163,7 +163,7 @@ class GitHubService: result = response.json() if 'errors' in result: - raise GHUnknownException( + raise UnknownException( f"GraphQL query error: {json.dumps(result['errors'])}" ) @@ -171,14 +171,14 @@ class GitHubService: except httpx.HTTPStatusError as e: if e.response.status_code == 401: - raise GhAuthenticationError('Invalid Github token') + raise AuthenticationError('Invalid Github token') logger.warning(f'Status error on GH API: {e}') - raise GHUnknownException('Unknown error') + raise UnknownException('Unknown error') except httpx.HTTPError as e: logger.warning(f'HTTP error on GH API: {e}') - raise GHUnknownException('Unknown error') + raise UnknownException('Unknown error') async def get_suggested_tasks(self) -> list[SuggestedTask]: """Get suggested tasks for the authenticated user across all repositories. diff --git a/openhands/integrations/github/github_types.py b/openhands/integrations/github/github_types.py deleted file mode 100644 index 1856b5e9ea..0000000000 --- a/openhands/integrations/github/github_types.py +++ /dev/null @@ -1,46 +0,0 @@ -from enum import Enum - -from pydantic import BaseModel - - -class TaskType(str, Enum): - MERGE_CONFLICTS = 'MERGE_CONFLICTS' - FAILING_CHECKS = 'FAILING_CHECKS' - UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS' - OPEN_ISSUE = 'OPEN_ISSUE' - OPEN_PR = 'OPEN_PR' - - -class SuggestedTask(BaseModel): - task_type: TaskType - repo: str - issue_number: int - title: str - - -class GitHubUser(BaseModel): - id: int - login: str - avatar_url: str - company: str | None = None - name: str | None = None - email: str | None = None - - -class GitHubRepository(BaseModel): - id: int - full_name: str - stargazers_count: int | None = None - link_header: str | None = None - - -class GhAuthenticationError(ValueError): - """Raised when there is an issue with GitHub authentication.""" - - pass - - -class GHUnknownException(ValueError): - """Raised when there is an issue with GitHub communcation.""" - - pass diff --git a/openhands/integrations/gitlab/gitlab_service.py b/openhands/integrations/gitlab/gitlab_service.py new file mode 100644 index 0000000000..9e3dae38e7 --- /dev/null +++ b/openhands/integrations/gitlab/gitlab_service.py @@ -0,0 +1,119 @@ +import os +from typing import Any + +import httpx +from pydantic import SecretStr + +from openhands.integrations.service_types import ( + AuthenticationError, + GitService, + Repository, + UnknownException, + User, +) +from openhands.utils.import_utils import get_impl + + +class GitLabService(GitService): + BASE_URL = 'https://gitlab.com/api/v4' + token: SecretStr = SecretStr('') + refresh = False + + def __init__( + self, + user_id: str | None = None, + external_auth_token: SecretStr | None = None, + token: SecretStr | None = None, + external_token_manager: bool = False, + ): + self.user_id = user_id + self.external_token_manager = external_token_manager + + if token: + self.token = token + + async def _get_gitlab_headers(self) -> dict: + """ + Retrieve the GitLab Token to construct the headers + """ + if self.user_id and not self.token: + self.token = await self.get_latest_token() + + return { + 'Authorization': f'Bearer {self.token.get_secret_value()}', + } + + def _has_token_expired(self, status_code: int) -> bool: + return status_code == 401 + + async def get_latest_token(self) -> SecretStr: + return self.token + + async def _fetch_data( + self, url: str, params: dict | None = None + ) -> tuple[Any, dict]: + try: + async with httpx.AsyncClient() as client: + gitlab_headers = await self._get_gitlab_headers() + response = await client.get(url, headers=gitlab_headers, params=params) + if self.refresh and self._has_token_expired(response.status_code): + await self.get_latest_token() + gitlab_headers = await self._get_gitlab_headers() + response = await client.get( + url, headers=gitlab_headers, params=params + ) + + response.raise_for_status() + headers = {} + if 'Link' in response.headers: + headers['Link'] = response.headers['Link'] + + return response.json(), headers + + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + raise AuthenticationError('Invalid GitLab token') + raise UnknownException('Unknown error') + + except httpx.HTTPError: + raise UnknownException('Unknown error') + + async def get_user(self) -> User: + url = f'{self.BASE_URL}/user' + response, _ = await self._fetch_data(url) + + return User( + id=response.get('id'), + username=response.get('username'), + avatar_url=response.get('avatar_url'), + name=response.get('name'), + email=response.get('email'), + company=response.get('organization'), + login=response.get('username'), + ) + + async def search_repositories( + self, query: str, per_page: int = 30, sort: str = 'updated', order: str = 'desc' + ): + url = f'{self.BASE_URL}/search' + params = { + 'scope': 'projects', + 'search': query, + 'per_page': per_page, + 'order_by': sort, + 'sort': order, + } + response, headers = await self._fetch_data(url, params) + return response, headers + + async def get_repositories( + self, page: int, per_page: int, sort: str, installation_id: int | None + ) -> list[Repository]: + return [] + + +gitlab_service_cls = os.environ.get( + 'OPENHANDS_GITLAB_SERVICE_CLS', + 'openhands.integrations.gitlab.gitlab_service.GitLabService', +) +GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls) diff --git a/openhands/integrations/provider.py b/openhands/integrations/provider.py new file mode 100644 index 0000000000..0076ea1c91 --- /dev/null +++ b/openhands/integrations/provider.py @@ -0,0 +1,143 @@ +from enum import Enum + +from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer +from pydantic.json import pydantic_encoder + +from openhands.integrations.github.github_service import GithubServiceImpl +from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl +from openhands.integrations.service_types import ( + AuthenticationError, + GitService, + Repository, + User, +) + + +class ProviderType(Enum): + GITHUB = 'github' + GITLAB = 'gitlab' + + +class ProviderToken(BaseModel): + token: SecretStr | None + user_id: str | None + + +PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken] +CUSTOM_SECRETS_TYPE = dict[str, SecretStr] + + +class SecretStore(BaseModel): + provider_tokens: PROVIDER_TOKEN_TYPE = {} + + @classmethod + def _convert_token( + cls, token_value: str | ProviderToken | SecretStr + ) -> ProviderToken: + if isinstance(token_value, ProviderToken): + return token_value + elif isinstance(token_value, str): + return ProviderToken(token=SecretStr(token_value), user_id=None) + elif isinstance(token_value, SecretStr): + return ProviderToken(token=token_value, user_id=None) + else: + raise ValueError(f'Invalid token type: {type(token_value)}') + + def model_post_init(self, __context) -> None: + # Convert any string tokens to ProviderToken objects + converted_tokens = {} + for token_type, token_value in self.provider_tokens.items(): + if token_value: # Only convert non-empty tokens + try: + if isinstance(token_type, str): + token_type = ProviderType(token_type) + converted_tokens[token_type] = self._convert_token(token_value) + except ValueError: + # Skip invalid provider types or tokens + continue + self.provider_tokens = converted_tokens + + @field_serializer('provider_tokens') + def provider_tokens_serializer( + self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo + ): + tokens = {} + expose_secrets = info.context and info.context.get('expose_secrets', False) + + for token_type, provider_token in provider_tokens.items(): + if not provider_token or not provider_token.token: + continue + + token_type_str = ( + token_type.value + if isinstance(token_type, ProviderType) + else str(token_type) + ) + tokens[token_type_str] = { + 'token': provider_token.token.get_secret_value() + if expose_secrets + else pydantic_encoder(provider_token.token), + 'user_id': provider_token.user_id, + } + + return tokens + + +class ProviderHandler: + def __init__( + self, + provider_tokens: PROVIDER_TOKEN_TYPE, + external_auth_token: SecretStr | None = None, + ): + self.service_class_map: dict[ProviderType, type[GitService]] = { + ProviderType.GITHUB: GithubServiceImpl, + ProviderType.GITLAB: GitLabServiceImpl, + } + + self.provider_tokens = provider_tokens + self.external_auth_token = external_auth_token + + def _get_service(self, provider: ProviderType) -> GitService: + """Helper method to instantiate a service for a given provider""" + token = self.provider_tokens[provider] + service_class = self.service_class_map[provider] + return service_class( + user_id=token.user_id, + external_auth_token=self.external_auth_token, + token=token.token, + ) + + async def get_user(self) -> User: + """Get user information from the first available provider""" + for provider in self.provider_tokens: + try: + service = self._get_service(provider) + return await service.get_user() + except Exception: + continue + raise AuthenticationError('Need valid provider token') + + async def get_latest_provider_tokens(self) -> dict[ProviderType, SecretStr]: + """Get latest token from services""" + tokens = {} + for provider in self.provider_tokens: + service = self._get_service(provider) + tokens[provider] = await service.get_latest_token() + + return tokens + + async def get_repositories( + self, page: int, per_page: int, sort: str, installation_id: int | None + ) -> list[Repository]: + """Get repositories from all available providers""" + all_repos = [] + for provider in self.provider_tokens: + try: + service = self._get_service(provider) + repos = await service.get_repositories( + page, per_page, sort, installation_id + ) + all_repos.extend(repos) + except Exception: + continue + return all_repos diff --git a/openhands/integrations/service_types.py b/openhands/integrations/service_types.py new file mode 100644 index 0000000000..916acbf35a --- /dev/null +++ b/openhands/integrations/service_types.py @@ -0,0 +1,89 @@ +from enum import Enum +from typing import Protocol + +from pydantic import BaseModel, SecretStr + + +class TaskType(str, Enum): + MERGE_CONFLICTS = 'MERGE_CONFLICTS' + FAILING_CHECKS = 'FAILING_CHECKS' + UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS' + OPEN_ISSUE = 'OPEN_ISSUE' + OPEN_PR = 'OPEN_PR' + + +class SuggestedTask(BaseModel): + task_type: TaskType + repo: str + issue_number: int + title: str + + +class User(BaseModel): + id: int + login: str + avatar_url: str + company: str | None = None + name: str | None = None + email: str | None = None + + +class Repository(BaseModel): + id: int + full_name: str + stargazers_count: int | None = None + link_header: str | None = None + + +class AuthenticationError(ValueError): + """Raised when there is an issue with GitHub authentication.""" + + pass + + +class UnknownException(ValueError): + """Raised when there is an issue with GitHub communcation.""" + + pass + + +class GitService(Protocol): + """Protocol defining the interface for Git service providers""" + + def __init__( + self, + user_id: str | None, + token: SecretStr | None, + external_auth_token: SecretStr | None, + external_token_manager: bool = False, + ) -> None: + """Initialize the service with authentication details""" + ... + + async def get_latest_token(self) -> SecretStr: + """Get latest working token of the users""" + ... + + async def get_user(self) -> User: + """Get the authenticated user's information""" + ... + + async def search_repositories( + self, + query: str, + per_page: int, + sort: str, + order: str, + ) -> list[Repository]: + """Search for repositories""" + ... + + async def get_repositories( + self, + page: int, + per_page: int, + sort: str, + installation_id: int | None, + ) -> list[Repository]: + """Get repositories for the authenticated user""" + ... diff --git a/openhands/integrations/utils.py b/openhands/integrations/utils.py new file mode 100644 index 0000000000..959ea2314f --- /dev/null +++ b/openhands/integrations/utils.py @@ -0,0 +1,37 @@ +from pydantic import SecretStr + +from openhands.integrations.github.github_service import GitHubService +from openhands.integrations.gitlab.gitlab_service import GitLabService +from openhands.integrations.provider import ProviderType + + +async def validate_provider_token(token: SecretStr) -> ProviderType | None: + """ + Determine whether a token is for GitHub or GitLab by attempting to get user info + from both services. + + Args: + token: The token to check + + Returns: + 'github' if it's a GitHub token + 'gitlab' if it's a GitLab token + None if the token is invalid for both services + """ + # Try GitHub first + try: + github_service = GitHubService(token=token) + await github_service.get_user() + return ProviderType.GITHUB + except Exception: + pass + + # Try GitLab next + try: + gitlab_service = GitLabService(token=token) + await gitlab_service.get_user() + return ProviderType.GITLAB + except Exception: + pass + + return None diff --git a/openhands/server/auth.py b/openhands/server/auth.py index 55ded747a0..eeaee04baf 100644 --- a/openhands/server/auth.py +++ b/openhands/server/auth.py @@ -1,6 +1,13 @@ from fastapi import Request from pydantic import SecretStr +from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType + + +def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None: + """Get GitHub token from request state. For backward compatibility.""" + return getattr(request.state, 'provider_tokens', None) + def get_access_token(request: Request) -> SecretStr | None: return getattr(request.state, 'access_token', None) @@ -11,8 +18,18 @@ def get_user_id(request: Request) -> str | None: def get_github_token(request: Request) -> SecretStr | None: - return getattr(request.state, 'github_token', None) + provider_tokens = get_provider_tokens(request) + + if provider_tokens and ProviderType.GITHUB in provider_tokens: + return provider_tokens[ProviderType.GITHUB].token + + return None def get_github_user_id(request: Request) -> str | None: - return getattr(request.state, 'github_user_id', None) + provider_tokens = get_provider_tokens(request) + + if provider_tokens and ProviderType.GITHUB in provider_tokens: + return provider_tokens[ProviderType.GITHUB].user_id + + return None diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 734d52004b..3d690306f7 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -194,10 +194,14 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface): settings = await settings_store.load() # TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS - if getattr(request.state, 'github_token', None) is None: - if settings and settings.github_token: - request.state.github_token = settings.github_token + if getattr(request.state, 'provider_tokens', None) is None: + if ( + settings + and settings.secrets_store + and settings.secrets_store.provider_tokens + ): + request.state.provider_tokens = settings.secrets_store.provider_tokens else: - request.state.github_token = None + request.state.provider_tokens = None return await call_next(request) diff --git a/openhands/server/routes/github.py b/openhands/server/routes/github.py index 1435987cbc..08705e4763 100644 --- a/openhands/server/routes/github.py +++ b/openhands/server/routes/github.py @@ -3,147 +3,168 @@ from fastapi.responses import JSONResponse from pydantic import SecretStr from openhands.integrations.github.github_service import GithubServiceImpl -from openhands.integrations.github.github_types import ( - GhAuthenticationError, - GHUnknownException, - GitHubRepository, - GitHubUser, - SuggestedTask, +from openhands.integrations.provider import ( + PROVIDER_TOKEN_TYPE, + ProviderHandler, + ProviderType, ) -from openhands.server.auth import get_access_token, get_github_token, get_github_user_id +from openhands.integrations.service_types import ( + AuthenticationError, + Repository, + SuggestedTask, + UnknownException, + User, +) +from openhands.server.auth import get_access_token, get_provider_tokens app = APIRouter(prefix='/api/github') -@app.get('/repositories', response_model=list[GitHubRepository]) +@app.get('/repositories', response_model=list[Repository]) async def get_github_repositories( page: int = 1, per_page: int = 10, sort: str = 'pushed', installation_id: int | None = None, - github_user_id: str | None = Depends(get_github_user_id), - github_user_token: SecretStr | None = Depends(get_github_token), + provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens), access_token: SecretStr | None = Depends(get_access_token), ): - client = GithubServiceImpl( - user_id=github_user_id, - external_auth_token=access_token, - github_token=github_user_token, + if provider_tokens and ProviderType.GITHUB in provider_tokens: + token = provider_tokens[ProviderType.GITHUB] + client = GithubServiceImpl( + user_id=token.user_id, external_auth_token=access_token, token=token.token + ) + + try: + repos: list[Repository] = await client.get_repositories( + page, per_page, sort, installation_id + ) + return repos + + except AuthenticationError as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + except UnknownException as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return JSONResponse( + content='GitHub token required.', + status_code=status.HTTP_401_UNAUTHORIZED, ) - try: - repos: list[GitHubRepository] = await client.get_repositories( - page, per_page, sort, installation_id - ) - return repos - - except GhAuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - - except GHUnknownException as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) -@app.get('/user', response_model=GitHubUser) +@app.get('/user', response_model=User) async def get_github_user( - github_user_id: str | None = Depends(get_github_user_id), - github_user_token: SecretStr | None = Depends(get_github_token), + provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens), access_token: SecretStr | None = Depends(get_access_token), ): - client = GithubServiceImpl( - user_id=github_user_id, - external_auth_token=access_token, - github_token=github_user_token, + if provider_tokens: + client = ProviderHandler(provider_tokens=provider_tokens, external_auth_token=access_token) + + try: + user: User = await client.get_user() + return user + + except AuthenticationError as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + except UnknownException as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return JSONResponse( + content='GitHub token required.', + status_code=status.HTTP_401_UNAUTHORIZED, ) - try: - user: GitHubUser = await client.get_user() - return user - - except GhAuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - - except GHUnknownException as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) @app.get('/installations', response_model=list[int]) async def get_github_installation_ids( - github_user_id: str | None = Depends(get_github_user_id), - github_user_token: SecretStr | None = Depends(get_github_token), + provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens), access_token: SecretStr | None = Depends(get_access_token), ): - client = GithubServiceImpl( - user_id=github_user_id, - external_auth_token=access_token, - github_token=github_user_token, + if provider_tokens and ProviderType.GITHUB in provider_tokens: + token = provider_tokens[ProviderType.GITHUB] + + client = GithubServiceImpl( + user_id=token.user_id, external_auth_token=access_token, token=token.token + ) + try: + installations_ids: list[int] = await client.get_installation_ids() + return installations_ids + + except AuthenticationError as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + except UnknownException as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return JSONResponse( + content='GitHub token required.', + status_code=status.HTTP_401_UNAUTHORIZED, ) - try: - installations_ids: list[int] = await client.get_installation_ids() - return installations_ids - - except GhAuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - - except GHUnknownException as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) -@app.get('/search/repositories', response_model=list[GitHubRepository]) +@app.get('/search/repositories', response_model=list[Repository]) async def search_github_repositories( query: str, per_page: int = 5, sort: str = 'stars', order: str = 'desc', - github_user_id: str | None = Depends(get_github_user_id), - github_user_token: SecretStr | None = Depends(get_github_token), + provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens), access_token: SecretStr | None = Depends(get_access_token), ): - client = GithubServiceImpl( - user_id=github_user_id, - external_auth_token=access_token, - github_token=github_user_token, + if provider_tokens and ProviderType.GITHUB in provider_tokens: + token = provider_tokens[ProviderType.GITHUB] + + client = GithubServiceImpl( + user_id=token.user_id, external_auth_token=access_token, token=token.token + ) + try: + repos: list[Repository] = await client.search_repositories( + query, per_page, sort, order + ) + return repos + + except AuthenticationError as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + except UnknownException as e: + return JSONResponse( + content=str(e), + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return JSONResponse( + content='GitHub token required.', + status_code=status.HTTP_401_UNAUTHORIZED, ) - try: - repos: list[GitHubRepository] = await client.search_repositories( - query, per_page, sort, order - ) - return repos - - except GhAuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - - except GHUnknownException as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) @app.get('/suggested-tasks', response_model=list[SuggestedTask]) async def get_suggested_tasks( - github_user_id: str | None = Depends(get_github_user_id), - github_user_token: SecretStr | None = Depends(get_github_token), - access_token: SecretStr | None = Depends(get_access_token), + provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens), + access_token: SecretStr | None = Depends(get_access_token) ): """Get suggested tasks for the authenticated user across their most recently pushed repositories. @@ -151,23 +172,30 @@ async def get_suggested_tasks( - PRs owned by the user - Issues assigned to the user. """ - client = GithubServiceImpl( - user_id=github_user_id, - external_auth_token=access_token, - github_token=github_user_token, + + if provider_tokens and ProviderType.GITHUB in provider_tokens: + token = provider_tokens[ProviderType.GITHUB] + + client = GithubServiceImpl( + user_id=token.user_id, external_auth_token=access_token, token=token.token + ) + try: + tasks: list[SuggestedTask] = await client.get_suggested_tasks() + return tasks + + except AuthenticationError as e: + return JSONResponse( + content=str(e), + status_code=401, + ) + + except UnknownException as e: + return JSONResponse( + content=str(e), + status_code=500, + ) + + return JSONResponse( + content='GitHub token required.', + status_code=status.HTTP_401_UNAUTHORIZED, ) - try: - tasks: list[SuggestedTask] = await client.get_suggested_tasks() - return tasks - - except GhAuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=401, - ) - - except GHUnknownException as e: - return JSONResponse( - content=str(e), - status_code=500, - ) diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 4d991c13a4..ae5c714bed 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -8,8 +8,9 @@ from pydantic import BaseModel, SecretStr from openhands.core.logger import openhands_logger as logger from openhands.events.action.message import MessageAction from openhands.integrations.github.github_service import GithubServiceImpl +from openhands.integrations.provider import ProviderType from openhands.runtime import get_runtime_cls -from openhands.server.auth import get_access_token, get_github_token, get_github_user_id +from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id from openhands.server.data_models.conversation_info import ConversationInfo from openhands.server.data_models.conversation_info_result_set import ( ConversationInfoResultSet, @@ -136,13 +137,18 @@ async def new_conversation(request: Request, data: InitSessionRequest): using the returned conversation ID. """ logger.info('Initializing new conversation') - user_id = get_github_user_id(request) - gh_client = GithubServiceImpl( - user_id=user_id, - external_auth_token=get_access_token(request), - github_token=get_github_token(request), - ) - github_token = await gh_client.get_latest_token() + user_id = None + github_token = None + provider_tokens = get_provider_tokens(request) + if provider_tokens and ProviderType.GITHUB in provider_tokens: + token = provider_tokens[ProviderType.GITHUB] + user_id = token.user_id + gh_client = GithubServiceImpl( + user_id=user_id, + external_auth_token=get_access_token(request), + token=token.token, + ) + github_token = await gh_client.get_latest_token() selected_repository = data.selected_repository selected_branch = data.selected_branch diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 6418f18c2e..63fc03f6bb 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -3,8 +3,9 @@ from fastapi.responses import JSONResponse from pydantic import SecretStr from openhands.core.logger import openhands_logger as logger -from openhands.integrations.github.github_service import GithubServiceImpl -from openhands.server.auth import get_github_token, get_user_id +from openhands.integrations.provider import ProviderToken, ProviderType +from openhands.integrations.utils import validate_provider_token +from openhands.server.auth import get_provider_tokens, get_user_id from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings from openhands.server.shared import SettingsStoreImpl, config @@ -23,14 +24,14 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse: content={'error': 'Settings not found'}, ) - token_is_set = bool(user_id) or bool(get_github_token(request)) + github_token_is_set = bool(user_id) or bool(get_provider_tokens(request)) settings_with_token_data = GETSettingsModel( **settings.model_dump(), - github_token_is_set=token_is_set, + github_token_is_set=github_token_is_set, ) settings_with_token_data.llm_api_key = settings.llm_api_key - del settings_with_token_data.github_token + del settings_with_token_data.secrets_store return settings_with_token_data except Exception as e: logger.warning(f'Invalid token: {e}') @@ -45,26 +46,27 @@ async def store_settings( request: Request, settings: POSTSettingsModel, ) -> JSONResponse: - # Check if token is valid - if settings.github_token: - try: - # We check if the token is valid by getting the user - # If the token is invalid, this will raise an exception - github = GithubServiceImpl( - user_id=None, - external_auth_token=None, - github_token=SecretStr(settings.github_token), - ) - await github.get_user() + # Check provider tokens are valid + if settings.provider_tokens: + # Remove extraneous token types + provider_types = [provider.value for provider in ProviderType] + settings.provider_tokens = { + k: v for k, v in settings.provider_tokens.items() if k in provider_types + } - except Exception as e: - logger.warning(f'Invalid GitHub token: {e}') - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={ - 'error': 'Invalid GitHub token. Please make sure it is valid.' - }, - ) + # Determine whether tokens are valid + for token_type, token_value in settings.provider_tokens.items(): + if token_value: + confirmed_token_type = await validate_provider_token( + SecretStr(token_value) + ) + if not confirmed_token_type or confirmed_token_type.value != token_type: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + 'error': f'Invalid token. Please make sure it is a valid {token_type} token.' + }, + ) try: settings_store = await SettingsStoreImpl.get_instance( @@ -72,32 +74,46 @@ async def store_settings( ) existing_settings = await settings_store.load() + # Convert to Settings model and merge with existing settings if existing_settings: - # LLM key isn't on the frontend, so we need to keep it if unset + # Keep existing LLM settings if not provided if settings.llm_api_key is None: settings.llm_api_key = existing_settings.llm_api_key + if settings.llm_model is None: + settings.llm_model = existing_settings.llm_model + if settings.llm_base_url is None: + settings.llm_base_url = existing_settings.llm_base_url - if settings.github_token is None: - settings.github_token = existing_settings.github_token - + # Keep existing analytics consent if not provided if settings.user_consents_to_analytics is None: settings.user_consents_to_analytics = ( existing_settings.user_consents_to_analytics ) - if settings.llm_model is None: - settings.llm_model = existing_settings.llm_model + if existing_settings.secrets_store: + existing_providers = [ + provider.value + for provider in existing_settings.secrets_store.provider_tokens + ] - if settings.llm_base_url is None: - settings.llm_base_url = existing_settings.llm_base_url + # Merge incoming settings store with the existing one + for provider, token_value in settings.provider_tokens.items(): + if provider in existing_providers and not token_value: + provider_type = ProviderType(provider) + existing_token = ( + existing_settings.secrets_store.provider_tokens.get( + provider_type + ) + ) + if existing_token and existing_token.token: + settings.provider_tokens[provider] = ( + existing_token.token.get_secret_value() + ) - response = JSONResponse( - status_code=status.HTTP_200_OK, - content={'message': 'Settings stored'}, - ) - - if settings.unset_github_token: - settings.github_token = None + # Merge provider tokens with existing ones + if settings.unset_github_token: # Only merge if not unsetting tokens + settings.secrets_store.provider_tokens = {} + settings.provider_tokens = {} # Update sandbox config with new settings if settings.remote_runtime_resource_factor is not None: @@ -106,9 +122,11 @@ async def store_settings( ) settings = convert_to_settings(settings) - await settings_store.store(settings) - return response + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Settings stored'}, + ) except Exception as e: logger.warning(f'Something went wrong storing settings: {e}') return JSONResponse( @@ -127,8 +145,19 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings if key in Settings.model_fields # Ensures only `Settings` fields are included } - # Convert the `llm_api_key` and `github_token` to a `SecretStr` instance + # Convert the `llm_api_key` to a `SecretStr` instance filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key - filtered_settings_data['github_token'] = settings_with_token_data.github_token - return Settings(**filtered_settings_data) + # Create a new Settings instance without provider tokens + settings = Settings(**filtered_settings_data) + + # Update provider tokens if any are provided + if settings_with_token_data.provider_tokens: + for token_type, token_value in settings_with_token_data.provider_tokens.items(): + if token_value: + provider = ProviderType(token_type) + settings.secrets_store.provider_tokens[provider] = ProviderToken( + token=SecretStr(token_value), user_id=None + ) + + return settings diff --git a/openhands/server/settings.py b/openhands/server/settings.py index 96e0b775f9..4059bf4886 100644 --- a/openhands/server/settings.py +++ b/openhands/server/settings.py @@ -1,10 +1,17 @@ from __future__ import annotations -from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer +from pydantic import ( + BaseModel, + SecretStr, + SerializationInfo, + field_serializer, + model_validator, +) from pydantic.json import pydantic_encoder from openhands.core.config.llm_config import LLMConfig from openhands.core.config.utils import load_app_config +from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore class Settings(BaseModel): @@ -21,7 +28,7 @@ class Settings(BaseModel): llm_api_key: SecretStr | None = None llm_base_url: str | None = None remote_runtime_resource_factor: int | None = None - github_token: SecretStr | None = None + secrets_store: SecretStore = SecretStore() enable_default_condenser: bool = False enable_sound_notifications: bool = False user_consents_to_analytics: bool | None = None @@ -38,22 +45,63 @@ class Settings(BaseModel): return pydantic_encoder(llm_api_key) - @field_serializer('github_token') - def github_token_serializer( - self, github_token: SecretStr | None, info: SerializationInfo - ): - """Custom serializer for the GitHub token. + @staticmethod + def _convert_token_value( + token_type: ProviderType, token_value: str | dict + ) -> ProviderToken | None: + """Convert a token value to a ProviderToken object.""" + if isinstance(token_value, dict): + token_str = token_value.get('token') + if not token_str: + return None + return ProviderToken( + token=SecretStr(token_str), + user_id=token_value.get('user_id'), + ) + if isinstance(token_value, str) and token_value: + return ProviderToken(token=SecretStr(token_value), user_id=None) + return None - To serialize the token instead of ********, set expose_secrets to True in the serialization context. - """ - if github_token is None: - return None + @model_validator(mode='before') + @classmethod + def convert_provider_tokens(cls, data: dict | object) -> dict | object: + """Convert provider tokens from JSON format to SecretStore format.""" + if not isinstance(data, dict): + return data - context = info.context - if context and context.get('expose_secrets', False): - return github_token.get_secret_value() + secrets_store = data.get('secrets_store') + if not isinstance(secrets_store, dict): + return data - return pydantic_encoder(github_token) + tokens = secrets_store.get('provider_tokens') + if not isinstance(tokens, dict): + return data + + converted_tokens = {} + for token_type_str, token_value in tokens.items(): + if not token_value: + continue + + try: + token_type = ProviderType(token_type_str) + except ValueError: + continue + + provider_token = cls._convert_token_value(token_type, token_value) + if provider_token: + converted_tokens[token_type] = provider_token + + data['secrets_store'] = SecretStore(provider_tokens=converted_tokens) + return data + + @field_serializer('secrets_store') + def secrets_store_serializer(self, secrets: SecretStore, info: SerializationInfo): + """Custom serializer for secrets store.""" + return { + 'provider_tokens': secrets.provider_tokens_serializer( + secrets.provider_tokens, info + ) + } @staticmethod def from_config() -> Settings | None: @@ -73,7 +121,7 @@ class Settings(BaseModel): llm_api_key=llm_config.api_key, llm_base_url=llm_config.base_url, remote_runtime_resource_factor=app_config.sandbox.remote_runtime_resource_factor, - github_token=None, + provider_tokens={}, ) return settings @@ -84,14 +132,12 @@ class POSTSettingsModel(Settings): """ unset_github_token: bool | None = None - github_token: str | None = ( - None # This is a string because it's coming from the frontend - ) + # Override provider_tokens to accept string tokens from frontend + provider_tokens: dict[str, str] = {} - # Override the serializer for the GitHub token to handle the string input - @field_serializer('github_token') - def github_token_serializer(self, github_token: str | None): - return github_token + @field_serializer('provider_tokens') + def provider_tokens_serializer(self, provider_tokens: dict[str, str]): + return provider_tokens class GETSettingsModel(Settings): diff --git a/tests/unit/test_github_service.py b/tests/unit/test_github_service.py index 627c60121f..5446d9a88a 100644 --- a/tests/unit/test_github_service.py +++ b/tests/unit/test_github_service.py @@ -5,16 +5,16 @@ import pytest from pydantic import SecretStr from openhands.integrations.github.github_service import GitHubService -from openhands.integrations.github.github_types import GhAuthenticationError +from openhands.integrations.service_types import AuthenticationError @pytest.mark.asyncio async def test_github_service_token_handling(): # Test initialization with SecretStr token token = SecretStr('test-token') - service = GitHubService(user_id=None, github_token=token) - assert service.github_token == token - assert service.github_token.get_secret_value() == 'test-token' + service = GitHubService(user_id=None, token=token) + assert service.token == token + assert service.token.get_secret_value() == 'test-token' # Test headers contain the token correctly headers = await service._get_github_headers() @@ -23,14 +23,14 @@ async def test_github_service_token_handling(): # Test initialization without token service = GitHubService(user_id='test-user') - assert service.github_token == SecretStr('') + assert service.token == SecretStr('') @pytest.mark.asyncio async def test_github_service_token_refresh(): # Test that token refresh is only attempted when refresh=True token = SecretStr('test-token') - service = GitHubService(user_id=None, github_token=token) + service = GitHubService(user_id=None, token=token) assert not service.refresh # Test token expiry detection @@ -58,7 +58,7 @@ async def test_github_service_fetch_data(): mock_client.__aexit__.return_value = None with patch('httpx.AsyncClient', return_value=mock_client): - service = GitHubService(user_id=None, github_token=SecretStr('test-token')) + service = GitHubService(user_id=None, token=SecretStr('test-token')) _ = await service._fetch_data('https://api.github.com/user') # Verify the request was made with correct headers @@ -77,5 +77,5 @@ async def test_github_service_fetch_data(): mock_client.get.reset_mock() mock_client.get.return_value = mock_response - with pytest.raises(GhAuthenticationError): + with pytest.raises(AuthenticationError): _ = await service._fetch_data('https://api.github.com/user') diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 4c37798682..3c5a01027b 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -6,6 +6,7 @@ from openhands.core.config.app_config import AppConfig from openhands.core.config.llm_config import LLMConfig from openhands.core.config.sandbox_config import SandboxConfig from openhands.core.config.security_config import SecurityConfig +from openhands.integrations.provider import ProviderToken, ProviderType from openhands.server.routes.settings import convert_to_settings from openhands.server.settings import POSTSettingsModel, Settings @@ -43,7 +44,7 @@ def test_settings_from_config(): assert settings.llm_api_key.get_secret_value() == 'test-key' assert settings.llm_base_url == 'https://test.example.com' assert settings.remote_runtime_resource_factor == 2 - assert settings.github_token is None + assert not settings.secrets_store.provider_tokens def test_settings_from_config_no_api_key(): @@ -80,23 +81,41 @@ def test_settings_handles_sensitive_data(): llm_api_key='test-key', llm_base_url='https://test.example.com', remote_runtime_resource_factor=2, - github_token='test-token', + ) + settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken( + token=SecretStr('test-token'), + user_id=None, ) assert str(settings.llm_api_key) == '**********' - assert str(settings.github_token) == '**********' + assert ( + str(settings.secrets_store.provider_tokens[ProviderType.GITHUB].token) + == '**********' + ) assert settings.llm_api_key.get_secret_value() == 'test-key' - assert settings.github_token.get_secret_value() == 'test-token' + assert ( + settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token.get_secret_value() + == 'test-token' + ) def test_convert_to_settings(): settings_with_token_data = POSTSettingsModel( llm_api_key='test-key', - github_token='test-token', + provider_tokens={ + 'github': 'test-token', + }, ) settings = convert_to_settings(settings_with_token_data) assert settings.llm_api_key.get_secret_value() == 'test-key' - assert settings.github_token.get_secret_value() == 'test-token' + assert ( + settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token.get_secret_value() + == 'test-token' + ) diff --git a/tests/unit/test_settings_api.py b/tests/unit/test_settings_api.py index bd06a4876e..cda91dcf48 100644 --- a/tests/unit/test_settings_api.py +++ b/tests/unit/test_settings_api.py @@ -5,6 +5,7 @@ from fastapi.testclient import TestClient from pydantic import SecretStr from openhands.core.config.sandbox_config import SandboxConfig +from openhands.integrations.provider import ProviderType, SecretStore from openhands.server.app import app from openhands.server.settings import Settings @@ -19,6 +20,24 @@ def mock_settings_store(): yield store_instance +@pytest.fixture +def mock_get_user_id(): + with patch('openhands.server.routes.settings.get_user_id') as mock: + mock.return_value = 'test-user' + yield mock + + +@pytest.fixture +def mock_validate_provider_token(): + with patch('openhands.server.routes.settings.validate_provider_token') as mock: + + async def mock_determine(*args, **kwargs): + return ProviderType.GITHUB + + mock.side_effect = mock_determine + yield mock + + @pytest.fixture def test_client(mock_settings_store): # Mock the middleware that adds github_token @@ -28,9 +47,15 @@ def test_client(mock_settings_store): async def __call__(self, scope, receive, send): settings = mock_settings_store.load.return_value - token = settings.github_token if settings else None + token = None + if settings and settings.secrets_store.provider_tokens.get( + ProviderType.GITHUB + ): + token = settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token if scope['type'] == 'http': - scope['state'] = {'github_token': token} + scope['state'] = {'token': token} await self.app(scope, receive, send) # Replace the middleware @@ -47,7 +72,9 @@ def mock_github_service(): @pytest.mark.asyncio -async def test_settings_api_runtime_factor(test_client, mock_settings_store): +async def test_settings_api_runtime_factor( + test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token +): # Mock the settings store to return None initially (no existing settings) mock_settings_store.load.return_value = None @@ -62,6 +89,7 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store): 'llm_api_key': 'test-key', 'llm_base_url': 'https://test.com', 'remote_runtime_resource_factor': 2, + 'provider_tokens': {'github': 'test-token'}, } # The test_client fixture already handles authentication @@ -98,12 +126,17 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store): @pytest.mark.asyncio -async def test_settings_llm_api_key(test_client, mock_settings_store): +async def test_settings_llm_api_key( + test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token +): # Mock the settings store to return None initially (no existing settings) mock_settings_store.load.return_value = None # Test data with remote_runtime_resource_factor - settings_data = {'llm_api_key': 'test-key'} + settings_data = { + 'llm_api_key': 'test-key', + 'provider_tokens': {'github': 'test-token'}, + } # The test_client fixture already handles authentication @@ -132,9 +165,13 @@ async def test_settings_llm_api_key(test_client, mock_settings_store): ) @pytest.mark.asyncio async def test_settings_api_set_github_token( - mock_github_service, test_client, mock_settings_store + mock_github_service, + test_client, + mock_settings_store, + mock_get_user_id, + mock_validate_provider_token, ): - # Test data with github_token set + # Test data with provider token set settings_data = { 'language': 'en', 'agent': 'test-agent', @@ -144,16 +181,21 @@ async def test_settings_api_set_github_token( 'llm_model': 'test-model', 'llm_api_key': 'test-key', 'llm_base_url': 'https://test.com', - 'github_token': 'test-token', + 'provider_tokens': {'github': 'test-token'}, } # Make the POST request to store settings response = test_client.post('/api/settings', json=settings_data) assert response.status_code == 200 - # Verify the settings were stored with the github_token + # Verify the settings were stored with the provider token stored_settings = mock_settings_store.store.call_args[0][0] - assert stored_settings.github_token == 'test-token' + assert ( + stored_settings.secrets_store.provider_tokens[ + ProviderType.GITHUB + ].token.get_secret_value() + == 'test-token' + ) # Mock settings store to return our settings for the GET request mock_settings_store.load.return_value = Settings(**settings_data) @@ -163,17 +205,21 @@ async def test_settings_api_set_github_token( data = response.json() assert response.status_code == 200 - assert data.get('github_token') is None - assert data['github_token_is_set'] is True + assert data.get('token') is None + assert data['token_is_set'] is True @pytest.mark.skip( reason='Mock middleware does not seem to properly set the github_token' ) async def test_settings_unset_github_token( - mock_github_service, test_client, mock_settings_store + mock_github_service, + test_client, + mock_settings_store, + mock_get_user_id, + mock_validate_provider_token, ): - # Test data with unset_github_token set to True + # Test data with unset_token set to True settings_data = { 'language': 'en', 'agent': 'test-agent', @@ -183,7 +229,7 @@ async def test_settings_unset_github_token( 'llm_model': 'test-model', 'llm_api_key': 'test-key', 'llm_base_url': 'https://test.com', - 'github_token': 'test-token', + 'provider_tokens': {'github': 'test-token'}, } # Mock settings store to return our settings for the GET request @@ -191,23 +237,23 @@ async def test_settings_unset_github_token( response = test_client.get('/api/settings') assert response.status_code == 200 - assert response.json()['github_token_is_set'] is True + assert response.json()['token_is_set'] is True - settings_data['unset_github_token'] = True + settings_data['unset_token'] = True # Make the POST request to store settings response = test_client.post('/api/settings', json=settings_data) assert response.status_code == 200 - # Verify the settings were stored with the github_token unset + # Verify the settings were stored with the provider token unset stored_settings = mock_settings_store.store.call_args[0][0] - assert stored_settings.github_token is None + assert not stored_settings.secrets_store.provider_tokens mock_settings_store.load.return_value = Settings(**stored_settings.dict()) # Make a GET request to retrieve settings response = test_client.get('/api/settings') assert response.status_code == 200 - assert response.json()['github_token_is_set'] is False + assert response.json()['token_is_set'] is False @pytest.mark.asyncio @@ -222,6 +268,7 @@ async def test_settings_preserve_llm_fields_when_none(test_client, mock_settings llm_model='existing-model', llm_api_key=SecretStr('existing-key'), llm_base_url='https://existing.com', + secrets_store=SecretStore(), ) # Mock the settings store to return our initial settings diff --git a/tests/unit/test_suggested_tasks.py b/tests/unit/test_suggested_tasks.py index 80c74c2da6..c28619ad34 100644 --- a/tests/unit/test_suggested_tasks.py +++ b/tests/unit/test_suggested_tasks.py @@ -3,13 +3,13 @@ from unittest.mock import AsyncMock import pytest from openhands.integrations.github.github_service import GitHubService -from openhands.integrations.github.github_types import GitHubUser, TaskType +from openhands.integrations.service_types import TaskType, User @pytest.mark.asyncio async def test_get_suggested_tasks(): # Mock responses - mock_user = GitHubUser( + mock_user = User( id=1, login='test-user', avatar_url='https://example.com/avatar.jpg',