diff --git a/frontend/__tests__/routes/git-settings.test.tsx b/frontend/__tests__/routes/git-settings.test.tsx index b288acc70c..68f00a5ac7 100644 --- a/frontend/__tests__/routes/git-settings.test.tsx +++ b/frontend/__tests__/routes/git-settings.test.tsx @@ -9,6 +9,7 @@ import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers"; import { AuthProvider } from "#/context/auth-context"; import { GetConfigResponse } from "#/api/open-hands.types"; import * as ToastHandlers from "#/utils/custom-toast-handlers"; +import { SecretsService } from "#/api/secrets-service"; const VALID_OSS_CONFIG: GetConfigResponse = { APP_MODE: "oss", @@ -230,7 +231,7 @@ describe("Content", () => { describe("Form submission", () => { it("should save the GitHub token", async () => { - const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings"); + const saveProvidersSpy = vi.spyOn(SecretsService, "addGitProvider"); const getConfigSpy = vi.spyOn(OpenHands, "getConfig"); getConfigSpy.mockResolvedValue(VALID_OSS_CONFIG); @@ -242,27 +243,19 @@ describe("Form submission", () => { await userEvent.type(githubInput, "test-token"); await userEvent.click(submit); - expect(saveSettingsSpy).toHaveBeenCalledWith( - expect.objectContaining({ - provider_tokens: { - github: { token: "test-token" }, - gitlab: { token: "" }, - }, - }), - ); + expect(saveProvidersSpy).toHaveBeenCalledWith({ + github: { token: "test-token" }, + gitlab: { token: "" }, + }); const gitlabInput = await screen.findByTestId("gitlab-token-input"); await userEvent.type(gitlabInput, "test-token"); await userEvent.click(submit); - expect(saveSettingsSpy).toHaveBeenCalledWith( - expect.objectContaining({ - provider_tokens: { - github: { token: "" }, - gitlab: { token: "test-token" }, - }, - }), - ); + expect(saveProvidersSpy).toHaveBeenCalledWith({ + github: { token: "test-token" }, + gitlab: { token: "" }, + }); }); it("should disable the button if there is no input", async () => { @@ -346,7 +339,7 @@ describe("Form submission", () => { // flaky test it.skip("should disable the button when submitting changes", async () => { - const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings"); + const saveSettingsSpy = vi.spyOn(SecretsService, "addGitProvider"); const getConfigSpy = vi.spyOn(OpenHands, "getConfig"); getConfigSpy.mockResolvedValue(VALID_OSS_CONFIG); @@ -370,7 +363,7 @@ describe("Form submission", () => { }); it("should disable the button after submitting changes", async () => { - const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings"); + const saveProvidersSpy = vi.spyOn(SecretsService, "addGitProvider"); const getConfigSpy = vi.spyOn(OpenHands, "getConfig"); getConfigSpy.mockResolvedValue(VALID_OSS_CONFIG); @@ -386,7 +379,7 @@ describe("Form submission", () => { // submit the form await userEvent.click(submit); - expect(saveSettingsSpy).toHaveBeenCalled(); + expect(saveProvidersSpy).toHaveBeenCalled(); expect(submit).toBeDisabled(); const gitlabInput = await screen.findByTestId("gitlab-token-input"); @@ -396,7 +389,7 @@ describe("Form submission", () => { // submit the form await userEvent.click(submit); - expect(saveSettingsSpy).toHaveBeenCalled(); + expect(saveProvidersSpy).toHaveBeenCalled(); await waitFor(() => expect(submit).toBeDisabled()); }); @@ -404,7 +397,7 @@ describe("Form submission", () => { describe("Status toasts", () => { it("should call displaySuccessToast when the settings are saved", async () => { - const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings"); + const saveProvidersSpy = vi.spyOn(SecretsService, "addGitProvider"); const getSettingsSpy = vi.spyOn(OpenHands, "getSettings"); getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS); @@ -422,18 +415,18 @@ describe("Status toasts", () => { const submit = await screen.findByTestId("submit-button"); await userEvent.click(submit); - expect(saveSettingsSpy).toHaveBeenCalled(); + expect(saveProvidersSpy).toHaveBeenCalled(); await waitFor(() => expect(displaySuccessToastSpy).toHaveBeenCalled()); }); it("should call displayErrorToast when the settings fail to save", async () => { - const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings"); + const saveProvidersSpy = vi.spyOn(SecretsService, "addGitProvider"); const getSettingsSpy = vi.spyOn(OpenHands, "getSettings"); getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS); const displayErrorToastSpy = vi.spyOn(ToastHandlers, "displayErrorToast"); - saveSettingsSpy.mockRejectedValue(new Error("Failed to save settings")); + saveProvidersSpy.mockRejectedValue(new Error("Failed to save settings")); renderGitSettingsScreen(); @@ -444,7 +437,7 @@ describe("Status toasts", () => { const submit = await screen.findByTestId("submit-button"); await userEvent.click(submit); - expect(saveSettingsSpy).toHaveBeenCalled(); + expect(saveProvidersSpy).toHaveBeenCalled(); expect(displayErrorToastSpy).toHaveBeenCalled(); }); }); diff --git a/frontend/src/api/open-hands.ts b/frontend/src/api/open-hands.ts index 1794587a5b..b534001f76 100644 --- a/frontend/src/api/open-hands.ts +++ b/frontend/src/api/open-hands.ts @@ -276,7 +276,7 @@ class OpenHands { static async logout(appMode: GetConfigResponse["APP_MODE"]): Promise { const endpoint = - appMode === "saas" ? "/api/logout" : "/api/unset-settings-tokens"; + appMode === "saas" ? "/api/logout" : "/api/unset-provider-tokens"; await openHands.post(endpoint); } diff --git a/frontend/src/api/secrets-service.ts b/frontend/src/api/secrets-service.ts new file mode 100644 index 0000000000..3116813710 --- /dev/null +++ b/frontend/src/api/secrets-service.ts @@ -0,0 +1,16 @@ +import { Provider, ProviderToken } from "#/types/settings"; +import { openHands } from "./open-hands-axios"; +import { POSTProviderTokens } from "./secrets-service.types"; + +export class SecretsService { + static async addGitProvider(providers: Record) { + const tokens: POSTProviderTokens = { + provider_tokens: providers, + }; + const { data } = await openHands.post( + "/api/add-git-providers", + tokens, + ); + return data; + } +} diff --git a/frontend/src/api/secrets-service.types.ts b/frontend/src/api/secrets-service.types.ts new file mode 100644 index 0000000000..f2426c80f0 --- /dev/null +++ b/frontend/src/api/secrets-service.types.ts @@ -0,0 +1,5 @@ +import { Provider, ProviderToken } from "#/types/settings"; + +export interface POSTProviderTokens { + provider_tokens: Record; +} diff --git a/frontend/src/hooks/mutation/use-add-git-providers.ts b/frontend/src/hooks/mutation/use-add-git-providers.ts new file mode 100644 index 0000000000..323a33b97f --- /dev/null +++ b/frontend/src/hooks/mutation/use-add-git-providers.ts @@ -0,0 +1,21 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { SecretsService } from "#/api/secrets-service"; +import { Provider, ProviderToken } from "#/types/settings"; + +export const useAddGitProviders = () => { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: ({ + providers, + }: { + providers: Record; + }) => SecretsService.addGitProvider(providers), + onSuccess: async () => { + await queryClient.invalidateQueries({ queryKey: ["settings"] }); + }, + meta: { + disableToast: true, + }, + }); +}; diff --git a/frontend/src/hooks/mutation/use-save-settings.ts b/frontend/src/hooks/mutation/use-save-settings.ts index 436e548afc..f060a5b462 100644 --- a/frontend/src/hooks/mutation/use-save-settings.ts +++ b/frontend/src/hooks/mutation/use-save-settings.ts @@ -20,7 +20,6 @@ const saveSettingsMutationFn = async (settings: Partial) => { enable_default_condenser: settings.ENABLE_DEFAULT_CONDENSER, enable_sound_notifications: settings.ENABLE_SOUND_NOTIFICATIONS, user_consents_to_analytics: settings.user_consents_to_analytics, - provider_tokens: settings.provider_tokens, }; await OpenHands.saveSettings(apiSettings); diff --git a/frontend/src/hooks/query/use-settings.ts b/frontend/src/hooks/query/use-settings.ts index 1e6130de2a..86f2ea626b 100644 --- a/frontend/src/hooks/query/use-settings.ts +++ b/frontend/src/hooks/query/use-settings.ts @@ -23,7 +23,6 @@ const getSettingsQueryFn = async (): Promise => { ENABLE_DEFAULT_CONDENSER: apiSettings.enable_default_condenser, ENABLE_SOUND_NOTIFICATIONS: apiSettings.enable_sound_notifications, USER_CONSENTS_TO_ANALYTICS: apiSettings.user_consents_to_analytics, - PROVIDER_TOKENS: apiSettings.provider_tokens, IS_NEW_USER: false, }; }; diff --git a/frontend/src/mocks/handlers.ts b/frontend/src/mocks/handlers.ts index 24e99f6e56..ee764f6517 100644 --- a/frontend/src/mocks/handlers.ts +++ b/frontend/src/mocks/handlers.ts @@ -6,7 +6,7 @@ import { } from "#/api/open-hands.types"; import { DEFAULT_SETTINGS } from "#/services/settings"; import { STRIPE_BILLING_HANDLERS } from "./billing-handlers"; -import { ApiSettings, PostApiSettings } from "#/types/settings"; +import { ApiSettings, PostApiSettings, Provider } from "#/types/settings"; import { FILE_SERVICE_HANDLERS } from "./file-service-handlers"; import { GitRepository, GitUser } from "#/types/git"; import { TASK_SUGGESTIONS_HANDLERS } from "./task-suggestions-handlers"; @@ -26,7 +26,6 @@ export const MOCK_DEFAULT_USER_SETTINGS: ApiSettings | PostApiSettings = { enable_default_condenser: DEFAULT_SETTINGS.ENABLE_DEFAULT_CONDENSER, enable_sound_notifications: DEFAULT_SETTINGS.ENABLE_SOUND_NOTIFICATIONS, user_consents_to_analytics: DEFAULT_SETTINGS.USER_CONSENTS_TO_ANALYTICS, - provider_tokens: DEFAULT_SETTINGS.PROVIDER_TOKENS, }; const MOCK_USER_PREFERENCES: { @@ -293,4 +292,32 @@ export const handlers = [ MOCK_USER_PREFERENCES.settings = { ...MOCK_DEFAULT_USER_SETTINGS }; return HttpResponse.json(null, { status: 200 }); }), + + http.post("/api/add-git-providers", async ({ request }) => { + const body = await request.json(); + + if (typeof body === "object" && body?.provider_tokens) { + const rawTokens = body.provider_tokens as Record< + string, + { token?: string } + >; + + const providerTokensSet: Partial> = + Object.fromEntries( + Object.entries(rawTokens) + .filter(([, val]) => val && val.token) + .map(([provider]) => [provider as Provider, ""]), + ); + + const newSettings = { + ...(MOCK_USER_PREFERENCES.settings ?? MOCK_DEFAULT_USER_SETTINGS), + provider_tokens_set: providerTokensSet, + }; + MOCK_USER_PREFERENCES.settings = newSettings; + + return HttpResponse.json(true, { status: 200 }); + } + + return HttpResponse.json(null, { status: 400 }); + }), ]; diff --git a/frontend/src/routes/git-settings.tsx b/frontend/src/routes/git-settings.tsx index 8c83ce7071..66c54e12ef 100644 --- a/frontend/src/routes/git-settings.tsx +++ b/frontend/src/routes/git-settings.tsx @@ -1,6 +1,5 @@ import React from "react"; import { useTranslation } from "react-i18next"; -import { useSaveSettings } from "#/hooks/mutation/use-save-settings"; import { useConfig } from "#/hooks/query/use-config"; import { useSettings } from "#/hooks/query/use-settings"; import { BrandButton } from "#/components/features/settings/brand-button"; @@ -16,11 +15,12 @@ import { import { retrieveAxiosErrorMessage } from "#/utils/retrieve-axios-error-message"; import { GitSettingInputsSkeleton } from "#/components/features/settings/git-settings/github-settings-inputs-skeleton"; import { useAuth } from "#/context/auth-context"; +import { useAddGitProviders } from "#/hooks/mutation/use-add-git-providers"; function GitSettingsScreen() { const { t } = useTranslation(); - const { mutate: saveSettings, isPending } = useSaveSettings(); + const { mutate: saveGitProviders, isPending } = useAddGitProviders(); const { mutate: disconnectGitTokens } = useLogout(); const { providerTokensSet } = useAuth(); @@ -48,9 +48,9 @@ function GitSettingsScreen() { const githubToken = formData.get("github-token-input")?.toString() || ""; const gitlabToken = formData.get("gitlab-token-input")?.toString() || ""; - saveSettings( + saveGitProviders( { - provider_tokens: { + providers: { github: { token: githubToken }, gitlab: { token: gitlabToken }, }, diff --git a/frontend/src/services/settings.ts b/frontend/src/services/settings.ts index 7c9592d6cb..1c24b127d1 100644 --- a/frontend/src/services/settings.ts +++ b/frontend/src/services/settings.ts @@ -15,10 +15,6 @@ export const DEFAULT_SETTINGS: Settings = { ENABLE_DEFAULT_CONDENSER: true, ENABLE_SOUND_NOTIFICATIONS: false, USER_CONSENTS_TO_ANALYTICS: false, - PROVIDER_TOKENS: { - github: { token: "" }, - gitlab: { token: "" }, - }, IS_NEW_USER: true, }; diff --git a/frontend/src/types/settings.ts b/frontend/src/types/settings.ts index 0c49c56c4a..96695edb7c 100644 --- a/frontend/src/types/settings.ts +++ b/frontend/src/types/settings.ts @@ -22,7 +22,6 @@ export type Settings = { ENABLE_DEFAULT_CONDENSER: boolean; ENABLE_SOUND_NOTIFICATIONS: boolean; USER_CONSENTS_TO_ANALYTICS: boolean | null; - PROVIDER_TOKENS: Record; IS_NEW_USER?: boolean; }; @@ -39,17 +38,14 @@ export type ApiSettings = { enable_default_condenser: boolean; enable_sound_notifications: boolean; user_consents_to_analytics: boolean | null; - provider_tokens: Record; provider_tokens_set: Partial>; }; export type PostSettings = Settings & { - provider_tokens: Record; user_consents_to_analytics: boolean | null; llm_api_key?: string | null; }; export type PostApiSettings = ApiSettings & { - provider_tokens: Record; user_consents_to_analytics: boolean | null; }; diff --git a/frontend/src/utils/settings-utils.ts b/frontend/src/utils/settings-utils.ts index b536051810..979e4bb1b4 100644 --- a/frontend/src/utils/settings-utils.ts +++ b/frontend/src/utils/settings-utils.ts @@ -1,4 +1,4 @@ -import { Provider, ProviderToken, Settings } from "#/types/settings"; +import { Settings } from "#/types/settings"; const extractBasicFormData = (formData: FormData) => { const provider = formData.get("llm-provider-input")?.toString(); @@ -61,18 +61,6 @@ export const extractSettings = ( ENABLE_DEFAULT_CONDENSER, } = extractAdvancedFormData(formData); - // Extract provider tokens - const githubToken = formData.get("github-token")?.toString(); - const gitlabToken = formData.get("gitlab-token")?.toString(); - const providerTokens: Record = { - github: { - token: githubToken || "", - }, - gitlab: { - token: gitlabToken || "", - }, - }; - return { LLM_MODEL: CUSTOM_LLM_MODEL || LLM_MODEL, LLM_API_KEY_SET: !!LLM_API_KEY, @@ -82,7 +70,6 @@ export const extractSettings = ( CONFIRMATION_MODE, SECURITY_ANALYZER, ENABLE_DEFAULT_CONDENSER, - PROVIDER_TOKENS: providerTokens, llm_api_key: LLM_API_KEY, }; }; diff --git a/openhands/core/setup.py b/openhands/core/setup.py index 3d932905f6..7c9cc4e3e8 100644 --- a/openhands/core/setup.py +++ b/openhands/core/setup.py @@ -15,7 +15,7 @@ from openhands.core.config import ( from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream from openhands.events.event import Event -from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore +from openhands.integrations.provider import ProviderToken, ProviderType from openhands.llm.llm import LLM from openhands.memory.memory import Memory from openhands.microagent.microagent import BaseMicroagent @@ -23,6 +23,7 @@ from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime from openhands.security import SecurityAnalyzer, options from openhands.storage import get_file_store +from openhands.storage.data_models.user_secrets import UserSecrets from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync @@ -111,7 +112,7 @@ def initialize_repository_for_runtime( ) secret_store = ( - SecretStore(provider_tokens=provider_tokens) if provider_tokens else None + UserSecrets(provider_tokens=provider_tokens) if provider_tokens else None ) immutable_provider_tokens = secret_store.provider_tokens if secret_store else None diff --git a/openhands/integrations/provider.py b/openhands/integrations/provider.py index d651337895..23b84e0f18 100644 --- a/openhands/integrations/provider.py +++ b/openhands/integrations/provider.py @@ -7,12 +7,8 @@ from pydantic import ( BaseModel, Field, SecretStr, - SerializationInfo, WithJsonSchema, - field_serializer, - model_validator, ) -from pydantic.json import pydantic_encoder from openhands.core.logger import openhands_logger as logger from openhands.events.action.action import Action @@ -66,113 +62,6 @@ CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA = Annotated[ ] -class SecretStore(BaseModel): - provider_tokens: PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA = Field( - default_factory=lambda: MappingProxyType({}) - ) - - custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA = Field( - default_factory=lambda: MappingProxyType({}), - ) - - model_config = { - 'frozen': True, - 'validate_assignment': True, - 'arbitrary_types_allowed': True, - } - - @field_serializer('provider_tokens') - def provider_tokens_serializer( - self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo - ) -> dict[str, dict[str, str | Any]]: - tokens = {} - expose_secrets = info.context and info.context.get('expose_secrets', False) - - for token_type, provider_token in provider_tokens.items(): - if not provider_token or not provider_token.token: - continue - - token_type_str = ( - token_type.value - if isinstance(token_type, ProviderType) - else str(token_type) - ) - tokens[token_type_str] = { - 'token': provider_token.token.get_secret_value() - if expose_secrets - else pydantic_encoder(provider_token.token), - 'user_id': provider_token.user_id, - } - - return tokens - - @field_serializer('custom_secrets') - def custom_secrets_serializer( - self, custom_secrets: CUSTOM_SECRETS_TYPE, info: SerializationInfo - ): - secrets = {} - expose_secrets = info.context and info.context.get('expose_secrets', False) - - if custom_secrets: - for secret_name, secret_key in custom_secrets.items(): - secrets[secret_name] = ( - secret_key.get_secret_value() - if expose_secrets - else pydantic_encoder(secret_key) - ) - return secrets - - @model_validator(mode='before') - @classmethod - def convert_dict_to_mappingproxy( - cls, data: dict[str, dict[str, Any] | MappingProxyType] | PROVIDER_TOKEN_TYPE - ) -> dict[str, MappingProxyType | None]: - """Custom deserializer to convert dictionary into MappingProxyType""" - if not isinstance(data, dict): - raise ValueError('SecretStore must be initialized with a dictionary') - - new_data: dict[str, MappingProxyType | None] = {} - - if 'provider_tokens' in data: - tokens = data['provider_tokens'] - if isinstance( - tokens, dict - ): # Ensure conversion happens only for dict inputs - converted_tokens = {} - for key, value in tokens.items(): - try: - provider_type = ( - ProviderType(key) if isinstance(key, str) else key - ) - converted_tokens[provider_type] = ProviderToken.from_value( - value - ) - except ValueError: - # Skip invalid provider types or tokens - continue - - # Convert to MappingProxyType - new_data['provider_tokens'] = MappingProxyType(converted_tokens) - elif isinstance(tokens, MappingProxyType): - new_data['provider_tokens'] = tokens - - if 'custom_secrets' in data: - secrets = data['custom_secrets'] - if isinstance(secrets, dict): - converted_secrets = {} - for key, value in secrets.items(): - if isinstance(value, str): - converted_secrets[key] = SecretStr(value) - elif isinstance(value, SecretStr): - converted_secrets[key] = value - - new_data['custom_secrets'] = MappingProxyType(converted_secrets) - elif isinstance(secrets, MappingProxyType): - new_data['custom_secrets'] = secrets - - return new_data - - class ProviderHandler: def __init__( self, diff --git a/openhands/server/app.py b/openhands/server/app.py index 02ec3fcf22..c9eea3afb8 100644 --- a/openhands/server/app.py +++ b/openhands/server/app.py @@ -18,6 +18,7 @@ from openhands.server.routes.manage_conversations import ( app as manage_conversation_api_router, ) from openhands.server.routes.public import app as public_api_router +from openhands.server.routes.secrets import app as secrets_router from openhands.server.routes.security import app as security_api_router from openhands.server.routes.settings import app as settings_router from openhands.server.routes.trajectory import app as trajectory_router @@ -50,5 +51,6 @@ app.include_router(feedback_api_router) app.include_router(conversation_api_router) app.include_router(manage_conversation_api_router) app.include_router(settings_router) +app.include_router(secrets_router) app.include_router(git_api_router) app.include_router(trajectory_router) diff --git a/openhands/server/config/server_config.py b/openhands/server/config/server_config.py index 46427dcf35..7dbcaf8235 100644 --- a/openhands/server/config/server_config.py +++ b/openhands/server/config/server_config.py @@ -15,6 +15,9 @@ class ServerConfig(ServerConfigInterface): settings_store_class: str = ( 'openhands.storage.settings.file_settings_store.FileSettingsStore' ) + secret_store_class: str = ( + 'openhands.storage.settings.file_secrets_store.FileSecretsStore' + ) conversation_store_class: str = ( 'openhands.storage.conversation.file_conversation_store.FileConversationStore' ) diff --git a/openhands/server/routes/secrets.py b/openhands/server/routes/secrets.py new file mode 100644 index 0000000000..093b993417 --- /dev/null +++ b/openhands/server/routes/secrets.py @@ -0,0 +1,292 @@ +from fastapi import APIRouter, Depends, status +from fastapi.responses import JSONResponse +from pydantic import SecretStr + +from openhands.integrations.service_types import ProviderType +from openhands.integrations.utils import validate_provider_token +from openhands.server.settings import GETCustomSecrets, POSTCustomSecrets, POSTProviderModel +from openhands.server.user_auth import get_secrets_store, get_user_secrets, get_user_settings_store +from openhands.storage.data_models.settings import Settings +from openhands.storage.data_models.user_secrets import UserSecrets +from openhands.storage.settings.secret_store import SecretsStore +from openhands.storage.settings.settings_store import SettingsStore +from openhands.core.logger import openhands_logger as logger + +app = APIRouter(prefix='/api') + + + + +# ================================================= +# SECTION: Handle git provider tokens +# ================================================= + + +async def invalidate_legacy_secrets_store( + settings: Settings, + settings_store: SettingsStore, + secrets_store: SecretsStore) -> UserSecrets | None: + + """ + We are moving `secrets_store` (a field from `Settings` object) to its own dedicated store + This function moves the values from Settings to UserSecrets, and deletes the values in Settings + While this function in called multiple times, the migration only ever happens once + """ + + if len(settings.secrets_store.provider_tokens.items()) > 0: + user_secrets = UserSecrets(provider_tokens=settings.secrets_store.provider_tokens) + await secrets_store.store(user_secrets) + + # Invalidate old tokens via settings store serializer + invalidated_secrets_settings = settings.model_copy( + update={'secrets_store': UserSecrets()} + ) + await settings_store.store(invalidated_secrets_settings) + + return user_secrets + + return None + + + +async def check_provider_tokens(provider_info: POSTProviderModel) -> str: + print(provider_info) + if provider_info.provider_tokens: + # Determine whether tokens are valid + for token_type, token_value in provider_info.provider_tokens.items(): + if token_value.token: + confirmed_token_type = await validate_provider_token( + token_value.token + ) + if not confirmed_token_type or confirmed_token_type != token_type: + return f'Invalid token. Please make sure it is a valid {token_type.value} token.' + + return '' + + +@app.post('/add-git-providers') +async def store_provider_tokens( + provider_info: POSTProviderModel, + secrets_store: SecretsStore = Depends(get_secrets_store) +) -> JSONResponse: + provider_err_msg = await check_provider_tokens(provider_info) + if provider_err_msg: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': provider_err_msg}, + ) + + try: + user_secrets = await secrets_store.load() + + + if user_secrets: + if provider_info.provider_tokens: + existing_providers = [ + provider + for provider in user_secrets.provider_tokens + ] + + # Merge incoming settings store with the existing one + for provider, token_value in list(provider_info.provider_tokens.items()): + if provider in existing_providers and not token_value.token: + existing_token = ( + user_secrets.provider_tokens.get(provider) + ) + if existing_token and existing_token.token: + provider_info.provider_tokens[provider] = existing_token + + else: # nothing passed in means keep current settings + provider_info.provider_tokens = dict(user_secrets.provider_tokens) + + + updated_secrets = user_secrets.model_copy(update={"provider_tokens":provider_info.provider_tokens}) + await secrets_store.store(updated_secrets) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Git providers stored'}, + ) + except Exception as e: + logger.warning(f'Something went wrong storing git providers: {e}') + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': 'Something went wrong storing git providers'}, + ) + + +@app.post('/unset-provider-tokens', response_model=dict[str, str]) +async def unset_provider_tokens( + secrets_store: SecretsStore = Depends(get_secrets_store) +) -> JSONResponse: + try: + user_secrets = await secrets_store.load() + if user_secrets: + updated_secrets = user_secrets.model_copy( + update={'provider_tokens': {}} + ) + await secrets_store.store(updated_secrets) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Unset Git provider tokens'}, + ) + + except Exception as e: + logger.warning(f'Something went wrong unsetting tokens: {e}') + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': 'Something went wrong unsetting tokens'}, + ) + + + + +# ================================================= +# SECTION: Handle custom secrets +# ================================================= + + + +@app.get('/secrets', response_model=GETCustomSecrets) +async def load_custom_secrets_names( + user_secrets: UserSecrets | None = Depends(get_user_secrets), +) -> GETCustomSecrets | JSONResponse: + try: + if not user_secrets: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': 'User secrets not found'}, + ) + + custom_secrets = list(user_secrets.custom_secrets.keys()) + return GETCustomSecrets(custom_secrets=custom_secrets) + + except Exception as e: + logger.warning(f'Invalid token: {e}') + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Invalid token'}, + ) + + +@app.post('/secrets', response_model=dict[str, str]) +async def create_custom_secret( + incoming_secret: POSTCustomSecrets, + secrets_store: SecretsStore = Depends(get_secrets_store), +) -> JSONResponse: + try: + existing_secrets = await secrets_store.load() + if existing_secrets: + custom_secrets = dict(existing_secrets.custom_secrets) + + for secret_name, secret_value in incoming_secret.custom_secrets.items(): + if secret_name in custom_secrets: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={'message': f'Secret {secret_name} already exists'}, + ) + + custom_secrets[secret_name] = secret_value + + # Create a new UserSecrets that preserves provider tokens + updated_user_secrets = UserSecrets( + custom_secrets=custom_secrets, + provider_tokens=existing_secrets.provider_tokens, + ) + + await secrets_store.store(updated_user_secrets) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Secret created successfully'}, + ) + except Exception as e: + logger.warning(f'Something went wrong creating secret: {e}') + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': 'Something went wrong creating secret'}, + ) + +@app.put('/secrets/{secret_id}', response_model=dict[str, str]) +async def update_custom_secret( + secret_id: str, + incoming_secret: POSTCustomSecrets, + secrets_store: SecretsStore = Depends(get_secrets_store), +) -> JSONResponse: + try: + existing_secrets = await secrets_store.load() + if existing_secrets: + # Check if the secret to update exists + if secret_id not in existing_secrets.custom_secrets: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': f'Secret with ID {secret_id} not found'}, + ) + + custom_secrets = dict(existing_secrets.custom_secrets) + custom_secrets.pop(secret_id) + + for secret_name, secret_value in incoming_secret.custom_secrets.items(): + custom_secrets[secret_name] = secret_value + + # Create a new UserSecrets that preserves provider tokens + updated_secrets = UserSecrets( + custom_secrets=custom_secrets, + provider_tokens=existing_secrets.provider_tokens, + ) + + await secrets_store.store(updated_secrets) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Secret updated successfully'}, + ) + except Exception as e: + logger.warning(f'Something went wrong updating secret: {e}') + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': 'Something went wrong updating secret'}, + ) + + +@app.delete('/secrets/{secret_id}') +async def delete_custom_secret( + secret_id: str, + secrets_store: SecretsStore = Depends(get_secrets_store), +) -> JSONResponse: + try: + existing_secrets = await secrets_store.load() + if existing_secrets: + # Get existing custom secrets + custom_secrets = dict(existing_secrets.custom_secrets) + + # Check if the secret to delete exists + if secret_id not in custom_secrets: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': f'Secret with ID {secret_id} not found'}, + ) + + # Remove the secret + custom_secrets.pop(secret_id) + + # Create a new UserSecrets that preserves provider tokens and remaining secrets + updated_secrets = UserSecrets( + custom_secrets=custom_secrets, + provider_tokens=existing_secrets.provider_tokens, + ) + + await secrets_store.store(updated_secrets) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={'message': 'Secret deleted successfully'}, + ) + except Exception as e: + logger.warning(f'Something went wrong deleting secret: {e}') + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': 'Something went wrong deleting secret'}, + ) + diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 4428739d71..21855acfe1 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -5,21 +5,20 @@ from openhands.core.logger import openhands_logger as logger from openhands.integrations.provider import ( PROVIDER_TOKEN_TYPE, ProviderType, - SecretStore, ) -from openhands.integrations.utils import validate_provider_token + + +from openhands.server.routes.secrets import invalidate_legacy_secrets_store from openhands.server.settings import ( - GETSettingsCustomSecrets, GETSettingsModel, - POSTSettingsCustomSecrets, - POSTSettingsModel, ) from openhands.server.shared import config from openhands.server.user_auth import ( get_provider_tokens, - get_user_settings, + get_secrets_store, get_user_settings_store, ) +from openhands.storage.settings.secret_store import SecretsStore from openhands.storage.data_models.settings import Settings from openhands.storage.settings.settings_store import SettingsStore @@ -29,18 +28,27 @@ app = APIRouter(prefix='/api') @app.get('/settings', response_model=GETSettingsModel) async def load_settings( provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens), - settings: Settings | None = Depends(get_user_settings), + settings_store: SettingsStore = Depends(get_user_settings_store), + secrets_store: SecretsStore = Depends(get_secrets_store) ) -> GETSettingsModel | JSONResponse: + + settings = await settings_store.load() + try: if not settings: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, content={'error': 'Settings not found'}, ) + + # On initial load, user secrets may not be populated with values migrated from settings store + user_secrets = await invalidate_legacy_secrets_store(settings, settings_store, secrets_store) + # If invalidation is successful, then the returned user secrets holds the most recent values + git_providers = user_secrets.provider_tokens if user_secrets else provider_tokens - provider_tokens_set: dict[ProviderType, str | None] = {} - if provider_tokens: - for provider_type, provider_token in provider_tokens.items(): + provider_tokens_set: dict[ProviderType, str | None] = {} + if git_providers: + for provider_type, provider_token in git_providers.items(): if provider_token.token or provider_token.user_id: provider_tokens_set[provider_type] = None @@ -60,140 +68,6 @@ async def load_settings( ) -@app.get('/secrets', response_model=GETSettingsCustomSecrets) -async def load_custom_secrets_names( - settings: Settings | None = Depends(get_user_settings), -) -> GETSettingsCustomSecrets | JSONResponse: - try: - if not settings: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': 'Settings not found'}, - ) - - custom_secrets = [] - if settings.secrets_store.custom_secrets: - for secret_name, _ in settings.secrets_store.custom_secrets.items(): - custom_secrets.append(secret_name) - - secret_names = GETSettingsCustomSecrets(custom_secrets=custom_secrets) - return secret_names - - except Exception as e: - logger.warning(f'Invalid token: {e}') - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={'error': 'Invalid token'}, - ) - - -@app.post('/secrets', response_model=dict[str, str]) -async def add_custom_secret( - incoming_secrets: POSTSettingsCustomSecrets, - settings_store: SettingsStore = Depends(get_user_settings_store), -) -> JSONResponse: - try: - existing_settings = await settings_store.load() - if existing_settings: - for ( - secret_name, - secret_value, - ) in existing_settings.secrets_store.custom_secrets.items(): - if ( - secret_name not in incoming_secrets.custom_secrets - ): # Allow incoming values to override existing ones - incoming_secrets.custom_secrets[secret_name] = secret_value - - # Create a new SecretStore that preserves provider tokens - updated_secret_store = SecretStore( - custom_secrets=incoming_secrets.custom_secrets, - provider_tokens=existing_settings.secrets_store.provider_tokens, - ) - - # Only update SecretStore in Settings - updated_settings = existing_settings.model_copy( - update={'secrets_store': updated_secret_store} - ) - - await settings_store.store(updated_settings) - - return JSONResponse( - status_code=status.HTTP_200_OK, - content={'message': 'Settings stored'}, - ) - except Exception as e: - logger.warning(f'Something went wrong storing settings: {e}') - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': 'Something went wrong storing settings'}, - ) - - -@app.delete('/secrets/{secret_id}') -async def delete_custom_secret( - secret_id: str, - settings_store: SettingsStore = Depends(get_user_settings_store), -) -> JSONResponse: - try: - existing_settings: Settings | None = await settings_store.load() - custom_secrets = {} - if existing_settings: - for ( - secret_name, - secret_value, - ) in existing_settings.secrets_store.custom_secrets.items(): - if secret_name != secret_id: - custom_secrets[secret_name] = secret_value - - # Create a new SecretStore that preserves provider tokens - updated_secret_store = SecretStore( - custom_secrets=custom_secrets, - provider_tokens=existing_settings.secrets_store.provider_tokens, - ) - - updated_settings = existing_settings.model_copy( - update={'secrets_store': updated_secret_store} - ) - - await settings_store.store(updated_settings) - - return JSONResponse( - status_code=status.HTTP_200_OK, - content={'message': 'Settings stored'}, - ) - except Exception as e: - logger.warning(f'Something went wrong storing settings: {e}') - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': 'Something went wrong storing settings'}, - ) - - -@app.post('/unset-settings-tokens', response_model=dict[str, str]) -async def unset_settings_tokens( - settings_store: SettingsStore = Depends(get_user_settings_store), -) -> JSONResponse: - try: - existing_settings = await settings_store.load() - if existing_settings: - settings = existing_settings.model_copy( - update={'secrets_store': SecretStore()} - ) - await settings_store.store(settings) - - return JSONResponse( - status_code=status.HTTP_200_OK, - content={'message': 'Settings stored'}, - ) - - except Exception as e: - logger.warning(f'Something went wrong unsetting tokens: {e}') - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={'error': 'Something went wrong unsetting tokens'}, - ) - - @app.post('/reset-settings', response_model=dict[str, str]) async def reset_settings() -> JSONResponse: """ @@ -206,51 +80,9 @@ async def reset_settings() -> JSONResponse: ) -async def check_provider_tokens(settings: POSTSettingsModel) -> str: - if settings.provider_tokens: - # Determine whether tokens are valid - for provider_type, provider_token in settings.provider_tokens.items(): - if provider_token.token: - confirmed_token_type = await validate_provider_token( - provider_token.token - ) - if not confirmed_token_type or confirmed_token_type != provider_type: - return f'Invalid token. Please make sure it is a valid {provider_type.value} token.' - - return '' - - -async def store_provider_tokens( - settings: POSTSettingsModel, settings_store: SettingsStore -): - existing_settings = await settings_store.load() - if existing_settings: - if existing_settings.secrets_store: - existing_providers = [ - provider for provider in existing_settings.secrets_store.provider_tokens - ] - - # Merge incoming settings store with the existing one - for provider_type, provider_value in list(settings.provider_tokens.items()): - if provider_type in existing_providers and not provider_value.token: - existing_token = ( - existing_settings.secrets_store.provider_tokens.get( - provider_type - ) - ) - if existing_token and existing_token.token: - settings.provider_tokens[provider_type] = existing_token - - else: # nothing passed in means keep current settings - provider_tokens = dict(existing_settings.secrets_store.provider_tokens) - settings.provider_tokens = provider_tokens - - return settings - - async def store_llm_settings( - settings: POSTSettingsModel, settings_store: SettingsStore -) -> POSTSettingsModel: + settings: Settings, settings_store: SettingsStore +) -> Settings: existing_settings = await settings_store.load() # Convert to Settings model and merge with existing settings @@ -268,17 +100,10 @@ async def store_llm_settings( @app.post('/settings', response_model=dict[str, str]) async def store_settings( - settings: POSTSettingsModel, + settings: Settings, settings_store: SettingsStore = Depends(get_user_settings_store), ) -> JSONResponse: # Check provider tokens are valid - provider_err_msg = await check_provider_tokens(settings) - if provider_err_msg: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={'error': provider_err_msg}, - ) - try: existing_settings = await settings_store.load() @@ -292,8 +117,6 @@ async def store_settings( existing_settings.user_consents_to_analytics ) - settings = await store_provider_tokens(settings, settings_store) - # Update sandbox config with new settings if settings.remote_runtime_resource_factor is not None: config.sandbox.remote_runtime_resource_factor = ( @@ -314,7 +137,7 @@ async def store_settings( ) -def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings: +def convert_to_settings(settings_with_token_data: Settings) -> Settings: settings_data = settings_with_token_data.model_dump() # Filter out additional fields from `SettingsWithTokenData` @@ -327,17 +150,6 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings # Convert the `llm_api_key` to a `SecretStr` instance filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key - # Create a new Settings instance with empty SecretStore + # Create a new Settings instance settings = Settings(**filtered_settings_data) - - # Create new provider tokens immutably - if settings_with_token_data.provider_tokens: - settings = settings.model_copy( - update={ - 'secrets_store': SecretStore( - provider_tokens=settings_with_token_data.provider_tokens - ) - } - ) - return settings diff --git a/openhands/server/settings.py b/openhands/server/settings.py index ef86b415e8..8cdd453609 100644 --- a/openhands/server/settings.py +++ b/openhands/server/settings.py @@ -10,7 +10,7 @@ from openhands.integrations.service_types import ProviderType from openhands.storage.data_models.settings import Settings -class POSTSettingsModel(Settings): +class POSTProviderModel(BaseModel): """ Settings for POST requests """ @@ -18,7 +18,7 @@ class POSTSettingsModel(Settings): provider_tokens: dict[ProviderType, ProviderToken] = {} -class POSTSettingsCustomSecrets(BaseModel): +class POSTCustomSecrets(BaseModel): """ Adding new custom secret """ @@ -37,7 +37,7 @@ class GETSettingsModel(Settings): llm_api_key_set: bool -class GETSettingsCustomSecrets(BaseModel): +class GETCustomSecrets(BaseModel): """ Custom secrets names """ diff --git a/openhands/server/shared.py b/openhands/server/shared.py index c53e73fb45..9e7b4be326 100644 --- a/openhands/server/shared.py +++ b/openhands/server/shared.py @@ -11,6 +11,7 @@ from openhands.server.conversation_manager.conversation_manager import ( from openhands.server.monitoring import MonitoringListener from openhands.storage import get_file_store from openhands.storage.conversation.conversation_store import ConversationStore +from openhands.storage.settings.secret_store import SecretsStore from openhands.storage.settings.settings_store import SettingsStore from openhands.utils.import_utils import get_impl @@ -51,6 +52,8 @@ conversation_manager = ConversationManagerImpl.get_instance( # type: ignore SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore +SecretsStoreImpl = get_impl(SecretsStore, server_config.secret_store_class) + ConversationStoreImpl = get_impl( ConversationStore, # type: ignore server_config.conversation_store_class, diff --git a/openhands/server/user_auth/__init__.py b/openhands/server/user_auth/__init__.py index e574552b66..0ecea4adb9 100644 --- a/openhands/server/user_auth/__init__.py +++ b/openhands/server/user_auth/__init__.py @@ -4,6 +4,8 @@ from pydantic import SecretStr from openhands.integrations.provider import PROVIDER_TOKEN_TYPE from openhands.integrations.service_types import ProviderType from openhands.server.settings import Settings +from openhands.storage.data_models.user_secrets import UserSecrets +from openhands.storage.settings.secret_store import SecretsStore from openhands.server.user_auth.user_auth import AuthType, get_user_auth from openhands.storage.settings.settings_store import SettingsStore @@ -42,6 +44,18 @@ async def get_user_settings(request: Request) -> Settings | None: return user_settings +async def get_secrets_store(request: Request) -> SecretsStore: + user_auth = await get_user_auth(request) + secrets_store = await user_auth.get_secrets_store() + return secrets_store + + +async def get_user_secrets(request: Request) -> UserSecrets | None: + user_auth = await get_user_auth(request) + user_secrets = await user_auth.get_user_secrets() + return user_secrets + + async def get_user_settings_store(request: Request) -> SettingsStore | None: user_auth = await get_user_auth(request) user_settings_store = await user_auth.get_user_settings_store() diff --git a/openhands/server/user_auth/default_user_auth.py b/openhands/server/user_auth/default_user_auth.py index e46880cb34..00be69a287 100644 --- a/openhands/server/user_auth/default_user_auth.py +++ b/openhands/server/user_auth/default_user_auth.py @@ -7,6 +7,8 @@ from openhands.integrations.provider import PROVIDER_TOKEN_TYPE from openhands.server import shared from openhands.server.settings import Settings from openhands.server.user_auth.user_auth import UserAuth +from openhands.storage.data_models.user_secrets import UserSecrets +from openhands.storage.settings.secret_store import SecretsStore from openhands.storage.settings.settings_store import SettingsStore @@ -16,6 +18,8 @@ class DefaultUserAuth(UserAuth): _settings: Settings | None = None _settings_store: SettingsStore | None = None + _secrets_store: SecretsStore | None = None + _user_secrets: UserSecrets | None = None async def get_user_id(self) -> str | None: """The default implementation does not support multi tenancy, so user_id is always None""" @@ -45,9 +49,29 @@ class DefaultUserAuth(UserAuth): self._settings = settings return settings + async def get_secrets_store(self): + secrets_store = self._secrets_store + if secrets_store: + return secrets_store + user_id = await self.get_user_id() + secret_store = await shared.SecretsStoreImpl.get_instance( + shared.config, user_id + ) + self._secrets_store = secret_store + return secret_store + + async def get_user_secrets(self) -> UserSecrets | None: + user_secrets = self._user_secrets + if user_secrets: + return user_secrets + secrets_store = await self.get_secrets_store() + user_secrets = await secrets_store.load() + self._user_secrets = user_secrets + return user_secrets + + async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None: - settings = await self.get_user_settings() - secrets_store = getattr(settings, 'secrets_store', None) + secrets_store = await self.get_user_secrets() provider_tokens = getattr(secrets_store, 'provider_tokens', None) return provider_tokens diff --git a/openhands/server/user_auth/user_auth.py b/openhands/server/user_auth/user_auth.py index f955a73b15..09feb16fc5 100644 --- a/openhands/server/user_auth/user_auth.py +++ b/openhands/server/user_auth/user_auth.py @@ -9,6 +9,8 @@ from pydantic import SecretStr from openhands.integrations.provider import PROVIDER_TOKEN_TYPE from openhands.server.settings import Settings from openhands.server.shared import server_config +from openhands.storage.data_models.user_secrets import UserSecrets +from openhands.storage.settings.secret_store import SecretsStore from openhands.storage.settings.settings_store import SettingsStore from openhands.utils.import_utils import get_impl @@ -51,6 +53,14 @@ class UserAuth(ABC): self._settings = settings return settings + @abstractmethod + async def get_secrets_store(self) -> SecretsStore: + """Get secrets store""" + + @abstractmethod + async def get_user_secrets(self) -> UserSecrets | None: + """Get the user's secrets""" + def get_auth_type(self) -> AuthType | None: return None diff --git a/openhands/storage/data_models/settings.py b/openhands/storage/data_models/settings.py index d8e8ddbd93..711e6bcd54 100644 --- a/openhands/storage/data_models/settings.py +++ b/openhands/storage/data_models/settings.py @@ -12,7 +12,7 @@ from pydantic.json import pydantic_encoder from openhands.core.config.llm_config import LLMConfig from openhands.core.config.utils import load_app_config -from openhands.integrations.provider import SecretStore +from openhands.storage.data_models.user_secrets import UserSecrets class Settings(BaseModel): @@ -29,7 +29,8 @@ class Settings(BaseModel): llm_api_key: SecretStr | None = None llm_base_url: str | None = None remote_runtime_resource_factor: int | None = None - secrets_store: SecretStore = Field(default_factory=SecretStore, frozen=True) + # Planned to be removed from settings + secrets_store: UserSecrets = Field(default_factory=UserSecrets, frozen=True) enable_default_condenser: bool = True enable_sound_notifications: bool = False user_consents_to_analytics: bool | None = None @@ -55,7 +56,7 @@ class Settings(BaseModel): @model_validator(mode='before') @classmethod def convert_provider_tokens(cls, data: dict | object) -> dict | object: - """Convert provider tokens from JSON format to SecretStore format.""" + """Convert provider tokens from JSON format to UserSecrets format.""" if not isinstance(data, dict): return data @@ -66,10 +67,10 @@ class Settings(BaseModel): custom_secrets = secrets_store.get('custom_secrets') tokens = secrets_store.get('provider_tokens') - secret_store = SecretStore(provider_tokens={}, custom_secrets={}) + secret_store = UserSecrets(provider_tokens={}, custom_secrets={}) if isinstance(tokens, dict): - converted_store = SecretStore(provider_tokens=tokens) + converted_store = UserSecrets(provider_tokens=tokens) secret_store = secret_store.model_copy( update={'provider_tokens': converted_store.provider_tokens} ) @@ -77,7 +78,7 @@ class Settings(BaseModel): secret_store.model_copy(update={'provider_tokens': tokens}) if isinstance(custom_secrets, dict): - converted_store = SecretStore(custom_secrets=custom_secrets) + converted_store = UserSecrets(custom_secrets=custom_secrets) secret_store = secret_store.model_copy( update={'custom_secrets': converted_store.custom_secrets} ) @@ -89,15 +90,12 @@ class Settings(BaseModel): return data @field_serializer('secrets_store') - def secrets_store_serializer(self, secrets: SecretStore, info: SerializationInfo): + def secrets_store_serializer(self, secrets: UserSecrets, info: SerializationInfo): """Custom serializer for secrets store.""" + + """Force invalidate secret store""" return { - 'provider_tokens': secrets.provider_tokens_serializer( - secrets.provider_tokens, info - ), - 'custom_secrets': secrets.custom_secrets_serializer( - secrets.custom_secrets, info - ), + 'provider_tokens': {} } @staticmethod diff --git a/openhands/storage/data_models/user_secrets.py b/openhands/storage/data_models/user_secrets.py new file mode 100644 index 0000000000..964076db95 --- /dev/null +++ b/openhands/storage/data_models/user_secrets.py @@ -0,0 +1,122 @@ +from types import MappingProxyType +from typing import Any +from pydantic import ( + BaseModel, + ConfigDict, + Field, + SecretStr, + SerializationInfo, + field_serializer, + model_validator, +) +from pydantic.json import pydantic_encoder +from openhands.integrations.provider import CUSTOM_SECRETS_TYPE, PROVIDER_TOKEN_TYPE, PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA, ProviderToken +from openhands.integrations.service_types import ProviderType + + +class UserSecrets(BaseModel): + provider_tokens: PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA = Field( + default_factory=lambda: MappingProxyType({}) + ) + + custom_secrets: CUSTOM_SECRETS_TYPE = Field( + default_factory=lambda: MappingProxyType({}) + ) + + model_config = ConfigDict( + frozen=True, + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + + @field_serializer('provider_tokens') + def provider_tokens_serializer( + self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo + ) -> dict[str, dict[str, str | Any]]: + tokens = {} + expose_secrets = info.context and info.context.get('expose_secrets', False) + + for token_type, provider_token in provider_tokens.items(): + if not provider_token or not provider_token.token: + continue + + token_type_str = ( + token_type.value + if isinstance(token_type, ProviderType) + else str(token_type) + ) + tokens[token_type_str] = { + 'token': provider_token.token.get_secret_value() + if expose_secrets + else pydantic_encoder(provider_token.token), + 'user_id': provider_token.user_id, + } + + return tokens + + @field_serializer('custom_secrets') + def custom_secrets_serializer( + self, custom_secrets: CUSTOM_SECRETS_TYPE, info: SerializationInfo + ): + secrets = {} + expose_secrets = info.context and info.context.get('expose_secrets', False) + + if custom_secrets: + for secret_name, secret_key in custom_secrets.items(): + secrets[secret_name] = ( + secret_key.get_secret_value() + if expose_secrets + else pydantic_encoder(secret_key) + ) + return secrets + + @model_validator(mode='before') + @classmethod + def convert_dict_to_mappingproxy( + cls, data: dict[str, dict[str, Any] | MappingProxyType] | PROVIDER_TOKEN_TYPE + ) -> dict[str, MappingProxyType | None]: + """Custom deserializer to convert dictionary into MappingProxyType""" + if not isinstance(data, dict): + raise ValueError('UserSecrets must be initialized with a dictionary') + + new_data: dict[str, MappingProxyType | None] = {} + + if 'provider_tokens' in data: + tokens = data['provider_tokens'] + if isinstance( + tokens, dict + ): # Ensure conversion happens only for dict inputs + converted_tokens = {} + for key, value in tokens.items(): + try: + provider_type = ( + ProviderType(key) if isinstance(key, str) else key + ) + converted_tokens[provider_type] = ProviderToken.from_value( + value + ) + except ValueError: + # Skip invalid provider types or tokens + continue + + # Convert to MappingProxyType + new_data['provider_tokens'] = MappingProxyType(converted_tokens) + elif isinstance(tokens, MappingProxyType): + new_data['provider_tokens'] = tokens + + if 'custom_secrets' in data: + secrets = data['custom_secrets'] + if isinstance(secrets, dict): + converted_secrets = {} + for key, value in secrets.items(): + if isinstance(value, str): + converted_secrets[key] = SecretStr(value) + elif isinstance(value, SecretStr): + converted_secrets[key] = value + + new_data['custom_secrets'] = MappingProxyType(converted_secrets) + elif isinstance(secrets, MappingProxyType): + new_data['custom_secrets'] = secrets + + return new_data diff --git a/openhands/storage/settings/file_secrets_store.py b/openhands/storage/settings/file_secrets_store.py new file mode 100644 index 0000000000..8710ee3296 --- /dev/null +++ b/openhands/storage/settings/file_secrets_store.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass + +from openhands.core.config.app_config import AppConfig +from openhands.storage import get_file_store +from openhands.storage.data_models.user_secrets import UserSecrets +from openhands.storage.files import FileStore +from openhands.storage.settings.secret_store import SecretsStore +from openhands.utils.async_utils import call_sync_from_async + + +@dataclass +class FileSecretsStore(SecretsStore): + file_store: FileStore + path: str = 'secrets.json' + + async def load(self) -> UserSecrets | None: + try: + json_str = await call_sync_from_async(self.file_store.read, self.path) + kwargs = json.loads(json_str) + secrets = UserSecrets(**kwargs) + return secrets + except FileNotFoundError: + return None + + async def store(self, secrets: UserSecrets) -> None: + json_str = secrets.model_dump_json(context={'expose_secrets': True}) + await call_sync_from_async(self.file_store.write, self.path, json_str) + + @classmethod + async def get_instance( + cls, config: AppConfig, user_id: str | None + ) -> FileSecretsStore: + file_store = get_file_store(config.file_store, config.file_store_path) + return FileSecretsStore(file_store) \ No newline at end of file diff --git a/openhands/storage/settings/secret_store.py b/openhands/storage/settings/secret_store.py new file mode 100644 index 0000000000..6d6777628e --- /dev/null +++ b/openhands/storage/settings/secret_store.py @@ -0,0 +1,26 @@ + +from __future__ import annotations + +from abc import ABC, abstractmethod +from openhands.core.config.app_config import AppConfig +from openhands.storage.data_models.user_secrets import UserSecrets + + + +class SecretsStore(ABC): + """Storage for secrets. May or may not support multiple users depending on the environment.""" + + @abstractmethod + async def load(self) -> UserSecrets | None: + """Load secrets.""" + + @abstractmethod + async def store(self, secrets: UserSecrets) -> None: + """Store secrets.""" + + @classmethod + @abstractmethod + async def get_instance( + cls, config: AppConfig, user_id: str | None + ) -> SecretsStore: + """Get a store for the user represented by the token given.""" \ No newline at end of file diff --git a/tests/unit/test_provider_immutability.py b/tests/unit/test_provider_immutability.py index 0b0e626e3e..5e7302d074 100644 --- a/tests/unit/test_provider_immutability.py +++ b/tests/unit/test_provider_immutability.py @@ -8,11 +8,9 @@ from openhands.integrations.provider import ( ProviderHandler, ProviderToken, ProviderType, - SecretStore, ) -from openhands.server.routes.settings import convert_to_settings -from openhands.server.settings import POSTSettingsModel from openhands.storage.data_models.settings import Settings +from openhands.storage.data_models.user_secrets import UserSecrets def test_provider_token_immutability(): @@ -36,8 +34,8 @@ def test_provider_token_immutability(): def test_secret_store_immutability(): - """Test that SecretStore is immutable""" - store = SecretStore( + """Test that UserSecrets is immutable""" + store = UserSecrets( provider_tokens={ProviderType.GITHUB: ProviderToken(token=SecretStr('test'))} ) @@ -71,7 +69,7 @@ def test_secret_store_immutability(): def test_settings_immutability(): """Test that Settings secrets_store is immutable""" settings = Settings( - secrets_store=SecretStore( + secrets_store=UserSecrets( provider_tokens={ ProviderType.GITHUB: ProviderToken(token=SecretStr('test')) } @@ -80,7 +78,7 @@ def test_settings_immutability(): # Test direct modification of secrets_store with pytest.raises(ValidationError): - settings.secrets_store = SecretStore() + settings.secrets_store = UserSecrets() # Test nested modification attempts with pytest.raises((TypeError, AttributeError)): @@ -89,7 +87,7 @@ def test_settings_immutability(): ) # Test model_copy creates new instance - new_store = SecretStore( + new_store = UserSecrets( provider_tokens={ ProviderType.GITHUB: ProviderToken(token=SecretStr('new_token')) } @@ -116,41 +114,6 @@ def test_settings_immutability(): ].token = SecretStr('') -def test_post_settings_conversion(): - """Test that POSTSettingsModel correctly converts to Settings""" - # Create POST model with token data - github_token = ProviderToken(token=SecretStr('test_token')) - gitlab_token = ProviderToken(token=SecretStr('gitlab_token')) - post_data = POSTSettingsModel( - provider_tokens={ - ProviderType.GITHUB: github_token, - ProviderType.GITLAB: gitlab_token, - } - ) - - # Convert to settings using convert_to_settings function - settings = convert_to_settings(post_data) - - # Verify tokens were converted correctly - assert ( - settings.secrets_store.provider_tokens[ - ProviderType.GITHUB - ].token.get_secret_value() - == 'test_token' - ) - assert ( - settings.secrets_store.provider_tokens[ - ProviderType.GITLAB - ].token.get_secret_value() - == 'gitlab_token' - ) - assert settings.secrets_store.provider_tokens[ProviderType.GITLAB].user_id is None - - # Verify immutability of converted settings - with pytest.raises(ValidationError): - settings.secrets_store = SecretStore() - - def test_provider_handler_immutability(): """Test that ProviderHandler maintains token immutability""" @@ -178,10 +141,10 @@ def test_provider_handler_immutability(): def test_token_conversion(): - """Test token conversion in SecretStore.create""" + """Test token conversion in UserSecrets.create""" # Test with string token store1 = Settings( - secrets_store=SecretStore( + secrets_store=UserSecrets( provider_tokens={ ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token')) } @@ -197,7 +160,7 @@ def test_token_conversion(): assert store1.secrets_store.provider_tokens[ProviderType.GITHUB].user_id is None # Test with dict token - store2 = SecretStore( + store2 = UserSecrets( provider_tokens={'github': {'token': 'test_token', 'user_id': 'user1'}} ) assert ( @@ -208,14 +171,14 @@ def test_token_conversion(): # Test with ProviderToken token = ProviderToken(token=SecretStr('test_token'), user_id='user2') - store3 = SecretStore(provider_tokens={ProviderType.GITHUB: token}) + store3 = UserSecrets(provider_tokens={ProviderType.GITHUB: token}) assert ( store3.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'test_token' ) assert store3.provider_tokens[ProviderType.GITHUB].user_id == 'user2' - store4 = SecretStore( + store4 = UserSecrets( provider_tokens={ ProviderType.GITHUB: 123 # Invalid type } @@ -224,10 +187,10 @@ def test_token_conversion(): assert ProviderType.GITHUB not in store4.provider_tokens # Test with empty/None token - store5 = SecretStore(provider_tokens={ProviderType.GITHUB: None}) + store5 = UserSecrets(provider_tokens={ProviderType.GITHUB: None}) assert ProviderType.GITHUB not in store5.provider_tokens - store6 = SecretStore( + store6 = UserSecrets( provider_tokens={ 'invalid_provider': 'test_token' # Invalid provider type } diff --git a/tests/unit/test_secret_store.py b/tests/unit/test_secret_store.py index 9b4fa0fe44..a75a47fa8f 100644 --- a/tests/unit/test_secret_store.py +++ b/tests/unit/test_secret_store.py @@ -5,12 +5,13 @@ from typing import Any from pydantic import SecretStr -from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore +from openhands.integrations.provider import ProviderToken, ProviderType +from openhands.storage.data_models.user_secrets import UserSecrets -class TestSecretStore: +class TestUserSecrets: def test_adding_only_provider_tokens(self): - """Test adding only provider tokens to the SecretStore.""" + """Test adding only provider tokens to the UserSecrets.""" # Create provider tokens github_token = ProviderToken( token=SecretStr('github-token-123'), user_id='user1' @@ -26,7 +27,7 @@ class TestSecretStore: } # Initialize the store with a dict that will be converted to MappingProxyType - store = SecretStore(provider_tokens=provider_tokens) + store = UserSecrets(provider_tokens=provider_tokens) # Verify the tokens were added correctly assert isinstance(store.provider_tokens, MappingProxyType) @@ -47,7 +48,7 @@ class TestSecretStore: assert len(store.custom_secrets) == 0 def test_adding_only_custom_secrets(self): - """Test adding only custom secrets to the SecretStore.""" + """Test adding only custom secrets to the UserSecrets.""" # Create custom secrets custom_secrets = { 'API_KEY': 'api-key-123', @@ -55,7 +56,7 @@ class TestSecretStore: } # Initialize the store with custom secrets - store = SecretStore(custom_secrets=custom_secrets) + store = UserSecrets(custom_secrets=custom_secrets) # Verify the custom secrets were added correctly assert isinstance(store.custom_secrets, MappingProxyType) @@ -84,7 +85,7 @@ class TestSecretStore: ) # Test with dict for provider_tokens and MappingProxyType for custom_secrets - store1 = SecretStore( + store1 = UserSecrets( provider_tokens=provider_tokens_dict, custom_secrets=custom_secrets_proxy ) @@ -102,7 +103,7 @@ class TestSecretStore: ) provider_tokens_proxy = MappingProxyType({ProviderType.GITLAB: provider_token}) - store2 = SecretStore( + store2 = UserSecrets( provider_tokens=provider_tokens_proxy, custom_secrets=custom_secrets_dict ) @@ -122,7 +123,7 @@ class TestSecretStore: ) custom_secret = {'API_KEY': SecretStr('api-key-123')} - initial_store = SecretStore( + initial_store = UserSecrets( provider_tokens=MappingProxyType({ProviderType.GITHUB: github_token}), custom_secrets=MappingProxyType(custom_secret), ) @@ -182,14 +183,14 @@ class TestSecretStore: ) def test_serialization_with_expose_secrets(self): - """Test serializing the SecretStore with expose_secrets=True.""" + """Test serializing the UserSecrets with expose_secrets=True.""" # Create a store with both provider tokens and custom secrets github_token = ProviderToken( token=SecretStr('github-token-123'), user_id='user1' ) custom_secrets = {'API_KEY': SecretStr('api-key-123')} - store = SecretStore( + store = UserSecrets( provider_tokens=MappingProxyType({ProviderType.GITHUB: github_token}), custom_secrets=MappingProxyType(custom_secrets), ) @@ -255,7 +256,7 @@ class TestSecretStore: } # Initialize the store - store = SecretStore(provider_tokens=mixed_provider_tokens) + store = UserSecrets(provider_tokens=mixed_provider_tokens) # Verify all tokens are converted to SecretStr assert isinstance(store.provider_tokens, MappingProxyType) @@ -282,7 +283,7 @@ class TestSecretStore: } # Initialize the store - store = SecretStore(custom_secrets=custom_secrets_dict) + store = UserSecrets(custom_secrets=custom_secrets_dict) # Verify all secrets are converted to SecretStr assert isinstance(store.custom_secrets, MappingProxyType) diff --git a/tests/unit/test_secrets_api.py b/tests/unit/test_secrets_api.py index 4f5130536c..cd49a82b3b 100644 --- a/tests/unit/test_secrets_api.py +++ b/tests/unit/test_secrets_api.py @@ -1,7 +1,6 @@ """Tests for the custom secrets API endpoints.""" # flake8: noqa: E501 -from contextlib import contextmanager from unittest.mock import AsyncMock, patch import pytest @@ -9,457 +8,278 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from pydantic import SecretStr -from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore -from openhands.server.routes.settings import app as settings_app -from openhands.server.settings import Settings -from openhands.storage.memory import InMemoryFileStore -from openhands.storage.settings.file_settings_store import FileSettingsStore +from openhands.integrations.provider import ProviderToken, ProviderType +from openhands.server.routes.secrets import app as secrets_app +from openhands.storage import get_file_store +from openhands.storage.data_models.user_secrets import UserSecrets +from openhands.storage.settings.file_secrets_store import FileSecretsStore @pytest.fixture def test_client(): """Create a test client for the settings API.""" app = FastAPI() - app.include_router(settings_app) + app.include_router(secrets_app) return TestClient(app) -@contextmanager -def patch_file_settings_store(): - store = FileSettingsStore(InMemoryFileStore()) +@pytest.fixture +def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str: + return str(tmp_path_factory.mktemp('secrets_store')) + + +@pytest.fixture +def file_secrets_store(temp_dir): + file_store = get_file_store('local', temp_dir) + store = FileSecretsStore(file_store) with patch( - 'openhands.storage.settings.file_settings_store.FileSettingsStore.get_instance', + 'openhands.storage.settings.file_secrets_store.FileSecretsStore.get_instance', AsyncMock(return_value=store), ): yield store @pytest.mark.asyncio -async def test_load_custom_secrets_names(test_client): +async def test_load_custom_secrets_names(test_client, file_secrets_store): """Test loading custom secrets names.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with custom secrets - custom_secrets = { - 'API_KEY': SecretStr('api-key-value'), - 'DB_PASSWORD': SecretStr('db-password-value'), - } - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) - } - secret_store = SecretStore( - custom_secrets=custom_secrets, provider_tokens=provider_tokens - ) - initial_settings = Settings( - language='en', - agent='test-agent', - llm_api_key=SecretStr('test-llm-key'), - secrets_store=secret_store, - ) - # Store the initial settings - await file_settings_store.store(initial_settings) + # Create initial settings with custom secrets + custom_secrets = { + 'API_KEY': SecretStr('api-key-value'), + 'DB_PASSWORD': SecretStr('db-password-value'), + } + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + } + user_secrets = UserSecrets( + custom_secrets=custom_secrets, provider_tokens=provider_tokens + ) - # Make the GET request - response = test_client.get('/api/secrets') - assert response.status_code == 200 + # Store the initial settings + await file_secrets_store.store(user_secrets) - # Check the response - data = response.json() - assert 'custom_secrets' in data - assert sorted(data['custom_secrets']) == ['API_KEY', 'DB_PASSWORD'] + # Make the GET request + response = test_client.get('/api/secrets') + assert response.status_code == 200 - # Verify that the original settings were not modified - stored_settings = await file_settings_store.load() - assert ( - stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value() - == 'api-key-value' - ) - assert ( - stored_settings.secrets_store.custom_secrets[ - 'DB_PASSWORD' - ].get_secret_value() - == 'db-password-value' - ) - assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens + # Check the response + data = response.json() + assert 'custom_secrets' in data + assert sorted(data['custom_secrets']) == ['API_KEY', 'DB_PASSWORD'] + + # Verify that the original settings were not modified + stored_settings = await file_secrets_store.load() + assert ( + stored_settings.custom_secrets['API_KEY'].get_secret_value() == 'api-key-value' + ) + assert ( + stored_settings.custom_secrets['DB_PASSWORD'].get_secret_value() + == 'db-password-value' + ) + assert ProviderType.GITHUB in stored_settings.provider_tokens @pytest.mark.asyncio -async def test_load_custom_secrets_names_empty(test_client): +async def test_load_custom_secrets_names_empty(test_client, file_secrets_store): """Test loading custom secrets names when there are no custom secrets.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with no custom secrets - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) - } - secret_store = SecretStore(provider_tokens=provider_tokens) - initial_settings = Settings( - language='en', - agent='test-agent', - llm_api_key=SecretStr('test-llm-key'), - secrets_store=secret_store, - ) + # Create initial settings with no custom secrets + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + } + user_secrets = UserSecrets(provider_tokens=provider_tokens) - # Store the initial settings - await file_settings_store.store(initial_settings) + # Store the initial settings + await file_secrets_store.store(user_secrets) - # Make the GET request - response = test_client.get('/api/secrets') - assert response.status_code == 200 + # Make the GET request + response = test_client.get('/api/secrets') + assert response.status_code == 200 - # Check the response - data = response.json() - assert 'custom_secrets' in data - assert data['custom_secrets'] == [] + # Check the response + data = response.json() + assert 'custom_secrets' in data + assert data['custom_secrets'] == [] @pytest.mark.asyncio -async def test_add_custom_secret(test_client): +async def test_add_custom_secret(test_client, file_secrets_store): """Test adding a new custom secret.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with provider tokens but no custom secrets - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) - } - secret_store = SecretStore(provider_tokens=provider_tokens) - initial_settings = Settings( - language='en', - agent='test-agent', - llm_api_key=SecretStr('test-llm-key'), - secrets_store=secret_store, - ) + # Create initial settings with provider tokens but no custom secrets + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + } + user_secrets = UserSecrets(provider_tokens=provider_tokens) - # Store the initial settings - await file_settings_store.store(initial_settings) + # Store the initial settings + await file_secrets_store.store(user_secrets) - # Make the POST request to add a custom secret - add_secret_data = {'custom_secrets': {'API_KEY': 'api-key-value'}} - response = test_client.post('/api/secrets', json=add_secret_data) - assert response.status_code == 200 + # Make the POST request to add a custom secret + add_secret_data = {'custom_secrets': {'API_KEY': 'api-key-value'}} + response = test_client.post('/api/secrets', json=add_secret_data) + assert response.status_code == 200 - # Verify that the settings were stored with the new secret - stored_settings = await file_settings_store.load() + # Verify that the settings were stored with the new secret + stored_settings = await file_secrets_store.load() - # Check that the secret was added - assert 'API_KEY' in stored_settings.secrets_store.custom_secrets - assert ( - stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value() - == 'api-key-value' - ) + # Check that the secret was added + assert 'API_KEY' in stored_settings.custom_secrets + assert ( + stored_settings.custom_secrets['API_KEY'].get_secret_value() == 'api-key-value' + ) + + +@pytest.mark.asyncio +async def test_update_existing_custom_secret(test_client, file_secrets_store): + """Test updating an existing custom secret.""" + + # Create initial settings with a custom secret + custom_secrets = {'API_KEY': SecretStr('old-api-key')} + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + } + user_secrets = UserSecrets( + custom_secrets=custom_secrets, provider_tokens=provider_tokens + ) + + # Store the initial settings + await file_secrets_store.store(user_secrets) + + # Make the POST request to update the custom secret + update_secret_data = {'custom_secrets': {'API_KEY': 'new-api-key'}} + response = test_client.put('/api/secrets/API_KEY', json=update_secret_data) + assert response.status_code == 200 + + # Verify that the settings were stored with the updated secret + stored_settings = await file_secrets_store.load() + + # Check that the secret was updated + assert 'API_KEY' in stored_settings.custom_secrets + assert stored_settings.custom_secrets['API_KEY'].get_secret_value() == 'new-api-key' # Check that other settings were preserved - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' + assert ProviderType.GITHUB in stored_settings.provider_tokens @pytest.mark.asyncio -async def test_update_existing_custom_secret(test_client): - """Test updating an existing custom secret.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with a custom secret - custom_secrets = {'API_KEY': SecretStr('old-api-key')} - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) - } - secret_store = SecretStore( - custom_secrets=custom_secrets, provider_tokens=provider_tokens - ) - initial_settings = Settings( - language='en', - agent='test-agent', - llm_api_key=SecretStr('test-llm-key'), - secrets_store=secret_store, - ) - - # Store the initial settings - await file_settings_store.store(initial_settings) - - # Make the POST request to update the custom secret - update_secret_data = {'custom_secrets': {'API_KEY': 'new-api-key'}} - response = test_client.post('/api/secrets', json=update_secret_data) - assert response.status_code == 200 - - # Verify that the settings were stored with the updated secret - stored_settings = await file_settings_store.load() - - # Check that the secret was updated - assert 'API_KEY' in stored_settings.secrets_store.custom_secrets - assert ( - stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value() - == 'new-api-key' - ) - - # Check that other settings were preserved - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' - assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens - - -@pytest.mark.asyncio -async def test_add_multiple_custom_secrets(test_client): +async def test_add_multiple_custom_secrets(test_client, file_secrets_store): """Test adding multiple custom secrets at once.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with one custom secret - custom_secrets = {'EXISTING_SECRET': SecretStr('existing-value')} - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + + # Create initial settings with one custom secret + custom_secrets = {'EXISTING_SECRET': SecretStr('existing-value')} + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + } + user_secrets = UserSecrets( + custom_secrets=custom_secrets, provider_tokens=provider_tokens + ) + + # Store the initial settings + await file_secrets_store.store(user_secrets) + + # Make the POST request to add multiple custom secrets + add_secrets_data = { + 'custom_secrets': { + 'API_KEY': 'api-key-value', + 'DB_PASSWORD': 'db-password-value', } - secret_store = SecretStore( - custom_secrets=custom_secrets, provider_tokens=provider_tokens - ) - initial_settings = Settings( - language='en', - agent='test-agent', - llm_api_key=SecretStr('test-llm-key'), - secrets_store=secret_store, - ) + } + response = test_client.post('/api/secrets', json=add_secrets_data) + assert response.status_code == 200 - # Store the initial settings - await file_settings_store.store(initial_settings) + # Verify that the settings were stored with the new secrets + stored_settings = await file_secrets_store.load() - # Make the POST request to add multiple custom secrets - add_secrets_data = { - 'custom_secrets': { - 'API_KEY': 'api-key-value', - 'DB_PASSWORD': 'db-password-value', - } - } - response = test_client.post('/api/secrets', json=add_secrets_data) - assert response.status_code == 200 + # Check that the new secrets were added + assert 'API_KEY' in stored_settings.custom_secrets + assert ( + stored_settings.custom_secrets['API_KEY'].get_secret_value() == 'api-key-value' + ) + assert 'DB_PASSWORD' in stored_settings.custom_secrets + assert ( + stored_settings.custom_secrets['DB_PASSWORD'].get_secret_value() + == 'db-password-value' + ) - # Verify that the settings were stored with the new secrets - stored_settings = await file_settings_store.load() + # Check that existing secrets were preserved + assert 'EXISTING_SECRET' in stored_settings.custom_secrets + assert ( + stored_settings.custom_secrets['EXISTING_SECRET'].get_secret_value() + == 'existing-value' + ) - # Check that the new secrets were added - assert 'API_KEY' in stored_settings.secrets_store.custom_secrets - assert ( - stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value() - == 'api-key-value' - ) - assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets - assert ( - stored_settings.secrets_store.custom_secrets[ - 'DB_PASSWORD' - ].get_secret_value() - == 'db-password-value' - ) - - # Check that existing secrets were preserved - assert 'EXISTING_SECRET' in stored_settings.secrets_store.custom_secrets - assert ( - stored_settings.secrets_store.custom_secrets[ - 'EXISTING_SECRET' - ].get_secret_value() - == 'existing-value' - ) - - # Check that other settings were preserved - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' - assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens + # Check that other settings were preserved + assert ProviderType.GITHUB in stored_settings.provider_tokens @pytest.mark.asyncio -async def test_delete_custom_secret(test_client): +async def test_delete_custom_secret(test_client, file_secrets_store): """Test deleting a custom secret.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with multiple custom secrets - custom_secrets = { - 'API_KEY': SecretStr('api-key-value'), - 'DB_PASSWORD': SecretStr('db-password-value'), - } - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) - } - secret_store = SecretStore( - custom_secrets=custom_secrets, provider_tokens=provider_tokens - ) - initial_settings = Settings( - language='en', - agent='test-agent', - llm_api_key=SecretStr('test-llm-key'), - secrets_store=secret_store, - ) - # Store the initial settings - await file_settings_store.store(initial_settings) + # Create initial settings with multiple custom secrets + custom_secrets = { + 'API_KEY': SecretStr('api-key-value'), + 'DB_PASSWORD': SecretStr('db-password-value'), + } + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + } + user_secrets = UserSecrets( + custom_secrets=custom_secrets, provider_tokens=provider_tokens + ) - # Make the DELETE request to delete a custom secret - response = test_client.delete('/api/secrets/API_KEY') - assert response.status_code == 200 + # Store the initial settings + await file_secrets_store.store(user_secrets) - # Verify that the settings were stored without the deleted secret - stored_settings = await file_settings_store.load() + # Make the DELETE request to delete a custom secret + response = test_client.delete('/api/secrets/API_KEY') + assert response.status_code == 200 - # Check that the specified secret was deleted - assert 'API_KEY' not in stored_settings.secrets_store.custom_secrets + # Verify that the settings were stored without the deleted secret + stored_settings = await file_secrets_store.load() - # Check that other secrets were preserved - assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets - assert ( - stored_settings.secrets_store.custom_secrets[ - 'DB_PASSWORD' - ].get_secret_value() - == 'db-password-value' - ) + # Check that the specified secret was deleted + assert 'API_KEY' not in stored_settings.custom_secrets - # Check that other settings were preserved - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' - assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens + # Check that other secrets were preserved + assert 'DB_PASSWORD' in stored_settings.custom_secrets + assert ( + stored_settings.custom_secrets['DB_PASSWORD'].get_secret_value() + == 'db-password-value' + ) + + # Check that other settings were preserved + assert ProviderType.GITHUB in stored_settings.provider_tokens @pytest.mark.asyncio -async def test_delete_nonexistent_custom_secret(test_client): +async def test_delete_nonexistent_custom_secret(test_client, file_secrets_store): """Test deleting a custom secret that doesn't exist.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with a custom secret - custom_secrets = {'API_KEY': SecretStr('api-key-value')} - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) - } - secret_store = SecretStore( - custom_secrets=custom_secrets, provider_tokens=provider_tokens - ) - initial_settings = Settings( - language='en', - agent='test-agent', - llm_api_key=SecretStr('test-llm-key'), - secrets_store=secret_store, - ) - # Store the initial settings - await file_settings_store.store(initial_settings) + # Create initial settings with a custom secret + custom_secrets = {'API_KEY': SecretStr('api-key-value')} + provider_tokens = { + ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')) + } + user_secrets = UserSecrets( + custom_secrets=custom_secrets, provider_tokens=provider_tokens + ) - # Make the DELETE request to delete a nonexistent custom secret - response = test_client.delete('/api/secrets/NONEXISTENT_KEY') - assert response.status_code == 200 + # Store the initial settings + await file_secrets_store.store(user_secrets) - # Verify that the settings were stored without changes to existing secrets - stored_settings = await file_settings_store.load() + # Make the DELETE request to delete a nonexistent custom secret + response = test_client.delete('/api/secrets/NONEXISTENT_KEY') + assert response.status_code == 404 - # Check that the existing secret was preserved - assert 'API_KEY' in stored_settings.secrets_store.custom_secrets - assert ( - stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value() - == 'api-key-value' - ) + # Verify that the settings were stored without changes to existing secrets + stored_settings = await file_secrets_store.load() - # Check that other settings were preserved - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' - assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens + # Check that the existing secret was preserved + assert 'API_KEY' in stored_settings.custom_secrets + assert ( + stored_settings.custom_secrets['API_KEY'].get_secret_value() == 'api-key-value' + ) - -@pytest.mark.asyncio -async def test_custom_secrets_operations_preserve_settings(test_client): - """Test that operations on custom secrets preserve all other settings.""" - with patch_file_settings_store() as file_settings_store: - # Create initial settings with comprehensive data - custom_secrets = {'INITIAL_SECRET': SecretStr('initial-value')} - provider_tokens = { - ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')), - ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab-token')), - } - secret_store = SecretStore( - custom_secrets=custom_secrets, provider_tokens=provider_tokens - ) - initial_settings = Settings( - language='en', - agent='test-agent', - max_iterations=100, - security_analyzer='default', - confirmation_mode=True, - llm_model='test-model', - llm_api_key=SecretStr('test-llm-key'), - llm_base_url='https://test.com', - remote_runtime_resource_factor=2, - enable_default_condenser=True, - enable_sound_notifications=False, - user_consents_to_analytics=True, - secrets_store=secret_store, - ) - - # Store the initial settings - await file_settings_store.store(initial_settings) - - # 1. Test adding a new custom secret - add_secret_data = {'custom_secrets': {'NEW_SECRET': 'new-value'}} - response = test_client.post('/api/secrets', json=add_secret_data) - assert response.status_code == 200 - - # Verify all settings are preserved - stored_settings = await file_settings_store.load() - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.max_iterations == 100 - assert stored_settings.security_analyzer == 'default' - assert stored_settings.confirmation_mode is True - assert stored_settings.llm_model == 'test-model' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' - assert stored_settings.llm_base_url == 'https://test.com' - assert stored_settings.remote_runtime_resource_factor == 2 - assert stored_settings.enable_default_condenser is True - assert stored_settings.enable_sound_notifications is False - assert stored_settings.user_consents_to_analytics is True - assert len(stored_settings.secrets_store.provider_tokens) == 2 - assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens - assert ProviderType.GITLAB in stored_settings.secrets_store.provider_tokens - assert ( - stored_settings.secrets_store.custom_secrets[ - 'INITIAL_SECRET' - ].get_secret_value() - == 'initial-value' - ) - assert ( - stored_settings.secrets_store.custom_secrets[ - 'NEW_SECRET' - ].get_secret_value() - == 'new-value' - ) - - # 2. Test updating an existing custom secret - update_secret_data = {'custom_secrets': {'INITIAL_SECRET': 'updated-value'}} - response = test_client.post('/api/secrets', json=update_secret_data) - assert response.status_code == 200 - - # Verify all settings are still preserved - stored_settings = await file_settings_store.load() - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.max_iterations == 100 - assert stored_settings.security_analyzer == 'default' - assert stored_settings.confirmation_mode is True - assert stored_settings.llm_model == 'test-model' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' - assert stored_settings.llm_base_url == 'https://test.com' - assert stored_settings.remote_runtime_resource_factor == 2 - assert stored_settings.enable_default_condenser is True - assert stored_settings.enable_sound_notifications is False - assert stored_settings.user_consents_to_analytics is True - assert len(stored_settings.secrets_store.provider_tokens) == 2 - - # 3. Test deleting a custom secret - response = test_client.delete('/api/secrets/NEW_SECRET') - assert response.status_code == 200 - - # Verify all settings are still preserved - stored_settings = await file_settings_store.load() - assert stored_settings.language == 'en' - assert stored_settings.agent == 'test-agent' - assert stored_settings.max_iterations == 100 - assert stored_settings.security_analyzer == 'default' - assert stored_settings.confirmation_mode is True - assert stored_settings.llm_model == 'test-model' - assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key' - assert stored_settings.llm_base_url == 'https://test.com' - assert stored_settings.remote_runtime_resource_factor == 2 - assert stored_settings.enable_default_condenser is True - assert stored_settings.enable_sound_notifications is False - assert stored_settings.user_consents_to_analytics is True - assert len(stored_settings.secrets_store.provider_tokens) == 2 + # Check that other settings were preserved + assert ProviderType.GITHUB in stored_settings.provider_tokens diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 84b0e3d08f..c74e3982cb 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -6,9 +6,7 @@ from openhands.core.config.app_config import AppConfig from openhands.core.config.llm_config import LLMConfig from openhands.core.config.sandbox_config import SandboxConfig from openhands.core.config.security_config import SecurityConfig -from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore from openhands.server.routes.settings import convert_to_settings -from openhands.server.settings import POSTSettingsModel from openhands.storage.data_models.settings import Settings @@ -84,46 +82,17 @@ def test_settings_handles_sensitive_data(): llm_api_key='test-key', llm_base_url='https://test.example.com', remote_runtime_resource_factor=2, - secrets_store=SecretStore( - provider_tokens={ - ProviderType.GITHUB: ProviderToken( - token=SecretStr('test-token'), - user_id=None, - ) - } - ), ) assert str(settings.llm_api_key) == '**********' - assert ( - str(settings.secrets_store.provider_tokens[ProviderType.GITHUB].token) - == '**********' - ) - assert settings.llm_api_key.get_secret_value() == 'test-key' - assert ( - settings.secrets_store.provider_tokens[ - ProviderType.GITHUB - ].token.get_secret_value() - == 'test-token' - ) def test_convert_to_settings(): - github_token = ProviderToken(token=SecretStr('test-token')) - settings_with_token_data = POSTSettingsModel( + settings_with_token_data = Settings( llm_api_key='test-key', - provider_tokens={ - ProviderType.GITHUB: github_token, - }, ) settings = convert_to_settings(settings_with_token_data) assert settings.llm_api_key.get_secret_value() == 'test-key' - assert ( - settings.secrets_store.provider_tokens[ - ProviderType.GITHUB - ].token.get_secret_value() - == 'test-token' - ) diff --git a/tests/unit/test_settings_api.py b/tests/unit/test_settings_api.py index c9d1ec5c3f..399cf9ea37 100644 --- a/tests/unit/test_settings_api.py +++ b/tests/unit/test_settings_api.py @@ -8,8 +8,10 @@ from pydantic import SecretStr from openhands.integrations.provider import ProviderToken, ProviderType from openhands.server.app import app from openhands.server.user_auth.user_auth import UserAuth +from openhands.storage.data_models.user_secrets import UserSecrets from openhands.storage.memory import InMemoryFileStore from openhands.storage.settings.file_settings_store import FileSettingsStore +from openhands.storage.settings.secret_store import SecretsStore from openhands.storage.settings.settings_store import SettingsStore @@ -34,6 +36,12 @@ class MockUserAuth(UserAuth): async def get_user_settings_store(self) -> SettingsStore | None: return self._settings_store + async def get_secrets_store(self) -> SecretsStore | None: + return None + + async def get_user_secrets(self) -> UserSecrets | None: + return None + @classmethod async def get_instance(cls, request: Request) -> UserAuth: return MockUserAuth() @@ -47,10 +55,6 @@ def test_client(): 'openhands.server.user_auth.user_auth.UserAuth.get_instance', return_value=MockUserAuth(), ), - patch( - 'openhands.server.routes.settings.validate_provider_token', - return_value=ProviderType.GITHUB, - ), patch( 'openhands.storage.settings.file_settings_store.FileSettingsStore.get_instance', AsyncMock(return_value=FileSettingsStore(InMemoryFileStore())), @@ -75,7 +79,6 @@ async def test_settings_api_endpoints(test_client): 'llm_api_key': 'test-key', 'llm_base_url': 'https://test.com', 'remote_runtime_resource_factor': 2, - 'provider_tokens': {'github': {'token': 'test-token'}}, } # Make the POST request to store settings @@ -98,9 +101,6 @@ async def test_settings_api_endpoints(test_client): response = test_client.post('/api/settings', json=partial_settings) assert response.status_code == 200 - # Test the unset-settings-tokens endpoint - response = test_client.post('/api/unset-settings-tokens') + # Test the unset-provider-tokens endpoint + response = test_client.post('/api/unset-provider-tokens') assert response.status_code == 200 - - # We'll skip the secrets endpoints for now as they require more complex mocking # noqa: E501 - # and they're not directly related to the authentication refactoring diff --git a/tests/unit/test_settings_store_functions.py b/tests/unit/test_settings_store_functions.py index 3e2b866045..256339b124 100644 --- a/tests/unit/test_settings_store_functions.py +++ b/tests/unit/test_settings_store_functions.py @@ -1,17 +1,21 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from fastapi.testclient import TestClient from pydantic import SecretStr -from openhands.integrations.provider import ProviderToken, SecretStore +from openhands.integrations.provider import ProviderToken from openhands.integrations.service_types import ProviderType -from openhands.server.routes.settings import ( +from openhands.server.routes.secrets import ( + app, check_provider_tokens, - store_llm_settings, - store_provider_tokens, ) -from openhands.server.settings import POSTSettingsModel +from openhands.server.routes.settings import store_llm_settings +from openhands.server.settings import POSTProviderModel +from openhands.storage import get_file_store from openhands.storage.data_models.settings import Settings +from openhands.storage.data_models.user_secrets import UserSecrets +from openhands.storage.settings.file_secrets_store import FileSecretsStore # Mock functions to simulate the actual functions in settings.py @@ -20,20 +24,49 @@ async def get_settings_store(request): return MagicMock() +@pytest.fixture +def test_client(): + # Create a test client + with ( + patch( + 'openhands.server.routes.secrets.check_provider_tokens', + AsyncMock(return_value=''), + ), + ): + client = TestClient(app) + yield client + + +@pytest.fixture +def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str: + return str(tmp_path_factory.mktemp('secrets_store')) + + +@pytest.fixture +def file_secrets_store(temp_dir): + file_store = get_file_store('local', temp_dir) + store = FileSecretsStore(file_store) + with patch( + 'openhands.storage.settings.file_secrets_store.FileSecretsStore.get_instance', + AsyncMock(return_value=store), + ): + yield store + + # Tests for check_provider_tokens @pytest.mark.asyncio async def test_check_provider_tokens_valid(): """Test check_provider_tokens with valid tokens.""" provider_token = ProviderToken(token=SecretStr('valid-token')) - settings = POSTSettingsModel(provider_tokens={ProviderType.GITHUB: provider_token}) + providers = POSTProviderModel(provider_tokens={ProviderType.GITHUB: provider_token}) # Mock the validate_provider_token function to return GITHUB for valid tokens with patch( - 'openhands.server.routes.settings.validate_provider_token' + 'openhands.server.routes.secrets.validate_provider_token' ) as mock_validate: mock_validate.return_value = ProviderType.GITHUB - result = await check_provider_tokens(settings) + result = await check_provider_tokens(providers) # Should return empty string for valid token assert result == '' @@ -44,15 +77,15 @@ async def test_check_provider_tokens_valid(): async def test_check_provider_tokens_invalid(): """Test check_provider_tokens with invalid tokens.""" provider_token = ProviderToken(token=SecretStr('invalid-token')) - settings = POSTSettingsModel(provider_tokens={ProviderType.GITHUB: provider_token}) + providers = POSTProviderModel(provider_tokens={ProviderType.GITHUB: provider_token}) # Mock the validate_provider_token function to return None for invalid tokens with patch( - 'openhands.server.routes.settings.validate_provider_token' + 'openhands.server.routes.secrets.validate_provider_token' ) as mock_validate: mock_validate.return_value = None - result = await check_provider_tokens(settings) + result = await check_provider_tokens(providers) # Should return error message for invalid token assert 'Invalid token' in result @@ -64,9 +97,8 @@ async def test_check_provider_tokens_wrong_type(): """Test check_provider_tokens with unsupported provider type.""" # We can't test with an unsupported provider type directly since the model enforces valid types # Instead, we'll test with an empty provider_tokens dictionary - settings = POSTSettingsModel(provider_tokens={}) - - result = await check_provider_tokens(settings) + providers = POSTProviderModel(provider_tokens={}) + result = await check_provider_tokens(providers) # Should return empty string for no providers assert result == '' @@ -75,9 +107,9 @@ async def test_check_provider_tokens_wrong_type(): @pytest.mark.asyncio async def test_check_provider_tokens_no_tokens(): """Test check_provider_tokens with no tokens.""" - settings = POSTSettingsModel(provider_tokens={}) + providers = POSTProviderModel(provider_tokens={}) - result = await check_provider_tokens(settings) + result = await check_provider_tokens(providers) # Should return empty string when no tokens provided assert result == '' @@ -87,7 +119,7 @@ async def test_check_provider_tokens_no_tokens(): @pytest.mark.asyncio async def test_store_llm_settings_new_settings(): """Test store_llm_settings with new settings.""" - settings = POSTSettingsModel( + settings = Settings( llm_model='gpt-4', llm_api_key='test-api-key', llm_base_url='https://api.example.com', @@ -108,7 +140,7 @@ async def test_store_llm_settings_new_settings(): @pytest.mark.asyncio async def test_store_llm_settings_update_existing(): """Test store_llm_settings updates existing settings.""" - settings = POSTSettingsModel( + settings = Settings( llm_model='gpt-4', llm_api_key='new-api-key', llm_base_url='https://new.example.com', @@ -137,7 +169,7 @@ async def test_store_llm_settings_update_existing(): @pytest.mark.asyncio async def test_store_llm_settings_partial_update(): """Test store_llm_settings with partial update.""" - settings = POSTSettingsModel( + settings = Settings( llm_model='gpt-4' # Only updating model ) @@ -164,82 +196,77 @@ async def test_store_llm_settings_partial_update(): # Tests for store_provider_tokens @pytest.mark.asyncio -async def test_store_provider_tokens_new_tokens(): +async def test_store_provider_tokens_new_tokens(test_client, file_secrets_store): """Test store_provider_tokens with new tokens.""" - provider_token = ProviderToken(token=SecretStr('new-token')) - settings = POSTSettingsModel(provider_tokens={ProviderType.GITHUB: provider_token}) + provider_tokens = {'provider_tokens': {'github': {'token': 'new-token'}}} # Mock the settings store mock_store = MagicMock() mock_store.load = AsyncMock(return_value=None) # No existing settings - result = await store_provider_tokens(settings, mock_store) + UserSecrets() + + user_secrets = await file_secrets_store.store(UserSecrets()) + + response = test_client.post('/api/add-git-providers', json=provider_tokens) + assert response.status_code == 200 + + user_secrets = await file_secrets_store.load() - # Should return settings with the provided tokens assert ( - result.provider_tokens[ProviderType.GITHUB].token.get_secret_value() + user_secrets.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'new-token' ) @pytest.mark.asyncio -async def test_store_provider_tokens_update_existing(): +async def test_store_provider_tokens_update_existing(test_client, file_secrets_store): """Test store_provider_tokens updates existing tokens.""" - provider_token = ProviderToken(token=SecretStr('updated-token')) - settings = POSTSettingsModel(provider_tokens={ProviderType.GITHUB: provider_token}) - - # Mock the settings store - mock_store = MagicMock() # Create existing settings with a GitHub token github_token = ProviderToken(token=SecretStr('old-token')) provider_tokens = {ProviderType.GITHUB: github_token} - # Create a SecretStore with the provider tokens - secrets_store = SecretStore(provider_tokens=provider_tokens) + # Create a UserSecrets with the provider tokens + user_secrets = UserSecrets(provider_tokens=provider_tokens) - # Create existing settings with the secrets store - existing_settings = Settings(secrets_store=secrets_store) + await file_secrets_store.store(user_secrets) - mock_store.load = AsyncMock(return_value=existing_settings) + response = test_client.post( + '/api/add-git-providers', + json={'provider_tokens': {'github': {'token': 'updated-token'}}}, + ) - result = await store_provider_tokens(settings, mock_store) + assert response.status_code == 200 + + user_secrets = await file_secrets_store.load() - # Should return settings with the updated tokens assert ( - result.provider_tokens[ProviderType.GITHUB].token.get_secret_value() + user_secrets.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'updated-token' ) @pytest.mark.asyncio -async def test_store_provider_tokens_keep_existing(): +async def test_store_provider_tokens_keep_existing(test_client, file_secrets_store): """Test store_provider_tokens keeps existing tokens when empty string provided.""" - settings = POSTSettingsModel( - provider_tokens={ - 'github': {'token': ''} - } # Empty string should keep existing token - ) - # Mock the settings store - mock_store = MagicMock() - - # Create existing settings with a GitHub token + # Create existing secrets with a GitHub token github_token = ProviderToken(token=SecretStr('existing-token')) provider_tokens = {ProviderType.GITHUB: github_token} + user_secrets = UserSecrets(provider_tokens=provider_tokens) - # Create a SecretStore with the provider tokens - secrets_store = SecretStore(provider_tokens=provider_tokens) + await file_secrets_store.store(user_secrets) - # Create existing settings with the secrets store - existing_settings = Settings(secrets_store=secrets_store) + response = test_client.post( + '/api/add-git-providers', + json={'provider_tokens': {'github': {'token': ''}}}, + ) + assert response.status_code == 200 - mock_store.load = AsyncMock(return_value=existing_settings) + user_secrets = await file_secrets_store.load() - result = await store_provider_tokens(settings, mock_store) - - # Should return settings with the existing token preserved assert ( - result.provider_tokens[ProviderType.GITHUB].token.get_secret_value() + user_secrets.provider_tokens[ProviderType.GITHUB].token.get_secret_value() == 'existing-token' )