From 767d092f8ffb09451f722875ec4d15c5484d1058 Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Fri, 2 May 2025 11:17:04 -0400 Subject: [PATCH] [Fix]: Use `str` in place of `Repository` for repository param when creating new conversation (#8159) Co-authored-by: openhands Co-authored-by: Engel Nyst --- docs/static/openapi.json | 15 +- .../features/home/home-header.test.tsx | 1 + .../features/home/repo-connector.test.tsx | 8 +- .../features/home/task-card.test.tsx | 3 +- frontend/src/api/open-hands.ts | 8 +- .../hooks/mutation/use-create-conversation.ts | 8 +- openhands/core/setup.py | 29 ++- .../integrations/github/github_service.py | 12 + .../integrations/gitlab/gitlab_service.py | 17 ++ openhands/integrations/provider.py | 19 ++ openhands/integrations/service_types.py | 5 + openhands/runtime/base.py | 53 ++-- .../server/routes/manage_conversations.py | 40 ++- openhands/server/session/agent_session.py | 12 +- .../server/session/conversation_init_data.py | 3 +- openhands/server/user_auth/__init__.py | 7 +- .../server/user_auth/default_user_auth.py | 1 + openhands/server/user_auth/user_auth.py | 9 + .../data_models/conversation_metadata.py | 1 + tests/runtime/test_setup.py | 19 +- tests/unit/test_conversation.py | 205 +++++++++++---- tests/unit/test_runtime_git_tokens.py | 239 +++++++++++++++++- 22 files changed, 579 insertions(+), 135 deletions(-) diff --git a/docs/static/openapi.json b/docs/static/openapi.json index 3ecf5e25e0..a3e6c0ccda 100644 --- a/docs/static/openapi.json +++ b/docs/static/openapi.json @@ -858,14 +858,15 @@ "schema": { "type": "object", "properties": { - "selected_repository": { - "type": "object", + "repository": { + "type": "string", "nullable": true, - "properties": { - "full_name": { - "type": "string" - } - } + "description": "Full name of the repository (e.g., owner/repo)" + }, + "git_provider": { + "type": "string", + "nullable": true, + "description": "The Git provider (e.g., github or gitlab). If omitted, all configured providers are checked for the repository." }, "selected_branch": { "type": "string", diff --git a/frontend/__tests__/components/features/home/home-header.test.tsx b/frontend/__tests__/components/features/home/home-header.test.tsx index bedb28b94a..9ed0872767 100644 --- a/frontend/__tests__/components/features/home/home-header.test.tsx +++ b/frontend/__tests__/components/features/home/home-header.test.tsx @@ -49,6 +49,7 @@ describe("HomeHeader", () => { "gui", undefined, undefined, + undefined, [], undefined, undefined, diff --git a/frontend/__tests__/components/features/home/repo-connector.test.tsx b/frontend/__tests__/components/features/home/repo-connector.test.tsx index 59bb0d19ac..cc3d39b5a7 100644 --- a/frontend/__tests__/components/features/home/repo-connector.test.tsx +++ b/frontend/__tests__/components/features/home/repo-connector.test.tsx @@ -165,12 +165,8 @@ describe("RepoConnector", () => { expect(createConversationSpy).toHaveBeenCalledExactlyOnceWith( "gui", - { - full_name: "rbren/polaris", - git_provider: "github", - id: 1, - is_public: true, - }, + "rbren/polaris", + "github", undefined, [], undefined, diff --git a/frontend/__tests__/components/features/home/task-card.test.tsx b/frontend/__tests__/components/features/home/task-card.test.tsx index d0cf98fa72..2523e2f3c9 100644 --- a/frontend/__tests__/components/features/home/task-card.test.tsx +++ b/frontend/__tests__/components/features/home/task-card.test.tsx @@ -89,7 +89,8 @@ describe("TaskCard", () => { expect(createConversationSpy).toHaveBeenCalledWith( "suggested_task", - MOCK_RESPOSITORIES[0], + MOCK_RESPOSITORIES[0].full_name, + MOCK_RESPOSITORIES[0].git_provider, undefined, [], undefined, diff --git a/frontend/src/api/open-hands.ts b/frontend/src/api/open-hands.ts index 827ac70621..1794587a5b 100644 --- a/frontend/src/api/open-hands.ts +++ b/frontend/src/api/open-hands.ts @@ -13,7 +13,7 @@ import { ConversationTrigger, } from "./open-hands.types"; import { openHands } from "./open-hands-axios"; -import { ApiSettings, PostApiSettings } from "#/types/settings"; +import { ApiSettings, PostApiSettings, Provider } from "#/types/settings"; import { GitUser, GitRepository } from "#/types/git"; import { SuggestedTask } from "#/components/features/home/tasks/task.types"; @@ -152,7 +152,8 @@ class OpenHands { static async createConversation( conversation_trigger: ConversationTrigger = "gui", - selectedRepository?: GitRepository, + selectedRepository?: string, + git_provider?: Provider, initialUserMsg?: string, imageUrls?: string[], replayJson?: string, @@ -160,7 +161,8 @@ class OpenHands { ): Promise { const body = { conversation_trigger, - selected_repository: selectedRepository, + repository: selectedRepository, + git_provider, selected_branch: undefined, initial_user_msg: initialUserMsg, image_urls: imageUrls, diff --git a/frontend/src/hooks/mutation/use-create-conversation.ts b/frontend/src/hooks/mutation/use-create-conversation.ts index 072a16e020..9225282573 100644 --- a/frontend/src/hooks/mutation/use-create-conversation.ts +++ b/frontend/src/hooks/mutation/use-create-conversation.ts @@ -24,13 +24,19 @@ export const useCreateConversation = () => { conversation_trigger: ConversationTrigger; q?: string; selectedRepository?: GitRepository | null; + suggested_task?: SuggestedTask; }) => { if (variables.q) dispatch(setInitialPrompt(variables.q)); return OpenHands.createConversation( variables.conversation_trigger, - variables.selectedRepository || undefined, + variables.selectedRepository + ? variables.selectedRepository.full_name + : undefined, + variables.selectedRepository + ? variables.selectedRepository.git_provider + : undefined, variables.q, files, replayJson || undefined, diff --git a/openhands/core/setup.py b/openhands/core/setup.py index 2ebd2ba524..3d932905f6 100644 --- a/openhands/core/setup.py +++ b/openhands/core/setup.py @@ -85,40 +85,41 @@ def create_runtime( def initialize_repository_for_runtime( - runtime: Runtime, - selected_repository: str | None = None, - github_token: SecretStr | None = None, + runtime: Runtime, selected_repository: str | None = None ) -> str | None: """Initialize the repository for the runtime. Args: runtime: The runtime to initialize the repository for. selected_repository: (optional) The GitHub repository to use. - github_token: (optional) The GitHub token to use. Returns: The repository directory path if a repository was cloned, None otherwise. """ # clone selected repository if provided - if github_token is None and 'GITHUB_TOKEN' in os.environ: + provider_tokens = {} + if 'GITHUB_TOKEN' in os.environ: github_token = SecretStr(os.environ['GITHUB_TOKEN']) + provider_tokens[ProviderType.GITHUB] = ProviderToken( + token=SecretStr(github_token) + ) + + if 'GITLAB_TOKEN' in os.environ: + gitlab_token = SecretStr(os.environ['GITLAB_TOKEN']) + provider_tokens[ProviderType.GITLAB] = ProviderToken( + token=SecretStr(gitlab_token) + ) secret_store = ( - SecretStore( - provider_tokens={ - ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token)) - } - ) - if github_token - else None + SecretStore(provider_tokens=provider_tokens) if provider_tokens else None ) - provider_tokens = secret_store.provider_tokens if secret_store else None + immutable_provider_tokens = secret_store.provider_tokens if secret_store else None logger.debug(f'Selected repository {selected_repository}.') repo_directory = call_async_from_sync( runtime.clone_or_init_repo, GENERAL_TIMEOUT, - provider_tokens, + immutable_provider_tokens, selected_repository, None, ) diff --git a/openhands/integrations/github/github_service.py b/openhands/integrations/github/github_service.py index fd09d86b4e..cc60094e27 100644 --- a/openhands/integrations/github/github_service.py +++ b/openhands/integrations/github/github_service.py @@ -390,6 +390,18 @@ class GitHubService(BaseGitService, GitService): except Exception: return [] + async def get_repository_details_from_repo_name(self, repository: str) -> Repository: + url = f'{self.BASE_URL}/repos/{repository}' + repo, _ = await self._make_request(url) + + return Repository( + id=repo.get('id'), + full_name=repo.get('full_name'), + stargazers_count=repo.get('stargazers_count'), + git_provider=ProviderType.GITHUB, + is_public=not repo.get('private', True), + ) + github_service_cls = os.environ.get( 'OPENHANDS_GITHUB_SERVICE_CLS', diff --git a/openhands/integrations/gitlab/gitlab_service.py b/openhands/integrations/gitlab/gitlab_service.py index 875c3006d6..5d28c970c6 100644 --- a/openhands/integrations/gitlab/gitlab_service.py +++ b/openhands/integrations/gitlab/gitlab_service.py @@ -383,6 +383,23 @@ class GitLabService(BaseGitService, GitService): return [] + async def get_repository_details_from_repo_name(self, repository: str) -> Repository: + encoded_name = repository.replace("/", "%2F") + + url = f'{self.BASE_URL}/projects/{encoded_name}' + repo, _ = await self._make_request(url) + + return Repository( + id=repo.get('id'), + full_name=repo.get('path_with_namespace'), + stargazers_count=repo.get('star_count'), + git_provider=ProviderType.GITLAB, + is_public=repo.get('visibility') == 'public', + ) + + + + gitlab_service_cls = os.environ.get( 'OPENHANDS_GITLAB_SERVICE_CLS', 'openhands.integrations.gitlab.gitlab_service.GitLabService', diff --git a/openhands/integrations/provider.py b/openhands/integrations/provider.py index 7639919c66..d651337895 100644 --- a/openhands/integrations/provider.py +++ b/openhands/integrations/provider.py @@ -397,3 +397,22 @@ class ProviderHandler: Map ProviderType value to the environment variable name in the runtime """ return f'{provider.value}_token'.lower() + + async def verify_repo_provider( + self, repository: str, specified_provider: ProviderType | None = None + ): + if specified_provider: + try: + service = self._get_service(specified_provider) + return await service.get_repository_details_from_repo_name(repository) + except Exception: + pass + + for provider in self.provider_tokens: + try: + service = self._get_service(provider) + return await service.get_repository_details_from_repo_name(repository) + except Exception: + pass + + raise AuthenticationError(f'Unable to access repo {repository}') diff --git a/openhands/integrations/service_types.py b/openhands/integrations/service_types.py index 39e7bb4352..6e5cc8dfa4 100644 --- a/openhands/integrations/service_types.py +++ b/openhands/integrations/service_types.py @@ -206,3 +206,8 @@ class GitService(Protocol): async def get_suggested_tasks(self) -> list[SuggestedTask]: """Get suggested tasks for the authenticated user across all repositories""" ... + + async def get_repository_details_from_repo_name( + self, repository: str + ) -> Repository: + """Gets all repository details from repository name""" diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 00ca450aa5..2d8f03b4b9 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -47,7 +47,7 @@ from openhands.integrations.provider import ( ProviderHandler, ProviderType, ) -from openhands.integrations.service_types import Repository +from openhands.integrations.service_types import AuthenticationError from openhands.microagent import ( BaseMicroagent, load_microagents_from_dir, @@ -311,10 +311,23 @@ class Runtime(FileEditRuntimeMixin): async def clone_or_init_repo( self, git_provider_tokens: PROVIDER_TOKEN_TYPE | None, - selected_repository: str | Repository | None, + selected_repository: str | None, selected_branch: str | None, - repository_provider: ProviderType = ProviderType.GITHUB, ) -> str: + repository = None + if selected_repository: # Determine provider from repo name + try: + provider_handler = ProviderHandler( + git_provider_tokens or MappingProxyType({}) + ) + repository = await provider_handler.verify_repo_provider( + selected_repository + ) + except AuthenticationError: + raise RuntimeError( + 'Git provider authentication issue when cloning repo' + ) + if not selected_repository: # In SaaS mode (indicated by user_id being set), always run git init # In OSS mode, only run git init if workspace_base is not set @@ -332,36 +345,30 @@ class Runtime(FileEditRuntimeMixin): ) return '' + # This satisfies mypy because param is optional, but `verify_repo_provider` guarentees this gets populated + if not repository: + return '' + + provider = repository.git_provider provider_domains = { ProviderType.GITHUB: 'github.com', ProviderType.GITLAB: 'gitlab.com', } - chosen_provider = ( - repository_provider - if isinstance(selected_repository, str) - else selected_repository.git_provider - ) - - domain = provider_domains[chosen_provider] - repository = ( - selected_repository - if isinstance(selected_repository, str) - else selected_repository.full_name - ) + domain = provider_domains[provider] # Try to use token if available, otherwise use public URL - if git_provider_tokens and chosen_provider in git_provider_tokens: - git_token = git_provider_tokens[chosen_provider].token + if git_provider_tokens and provider in git_provider_tokens: + git_token = git_provider_tokens[provider].token if git_token: - if chosen_provider == ProviderType.GITLAB: - remote_repo_url = f'https://oauth2:{git_token.get_secret_value()}@{domain}/{repository}.git' + if provider == ProviderType.GITLAB: + remote_repo_url = f'https://oauth2:{git_token.get_secret_value()}@{domain}/{selected_repository}.git' else: - remote_repo_url = f'https://{git_token.get_secret_value()}@{domain}/{repository}.git' + remote_repo_url = f'https://{git_token.get_secret_value()}@{domain}/{selected_repository}.git' else: - remote_repo_url = f'https://{domain}/{repository}.git' + remote_repo_url = f'https://{domain}/{selected_repository}.git' else: - remote_repo_url = f'https://{domain}/{repository}.git' + remote_repo_url = f'https://{domain}/{selected_repository}.git' if not remote_repo_url: raise ValueError('Missing either Git token or valid repository') @@ -371,7 +378,7 @@ class Runtime(FileEditRuntimeMixin): 'info', 'STATUS$SETTING_UP_WORKSPACE', 'Setting up workspace...' ) - dir_name = repository.split('/')[-1] + dir_name = selected_repository.split('/')[-1] # Generate a random branch name to avoid conflicts random_str = ''.join( diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index a678c75a02..baa5fb6dbd 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -12,8 +12,9 @@ from openhands.events.event import EventSource from openhands.events.stream import EventStream from openhands.integrations.provider import ( PROVIDER_TOKEN_TYPE, + ProviderHandler, ) -from openhands.integrations.service_types import Repository, SuggestedTask +from openhands.integrations.service_types import AuthenticationError, ProviderType, Repository, SuggestedTask from openhands.runtime import get_runtime_cls from openhands.server.data_models.conversation_info import ConversationInfo from openhands.server.data_models.conversation_info_result_set import ( @@ -29,9 +30,11 @@ from openhands.server.shared import ( ) from openhands.server.types import LLMAuthenticationError, MissingSettingsError from openhands.server.user_auth import ( + get_auth_type, get_provider_tokens, get_user_id, ) +from openhands.server.user_auth.user_auth import AuthType from openhands.server.utils import get_conversation_store from openhands.storage.conversation.conversation_store import ConversationStore from openhands.storage.data_models.conversation_metadata import ( @@ -48,7 +51,8 @@ app = APIRouter(prefix='/api') class InitSessionRequest(BaseModel): conversation_trigger: ConversationTrigger = ConversationTrigger.GUI - selected_repository: Repository | None = None + repository: str | None = None + git_provider: ProviderType | None = None selected_branch: str | None = None initial_user_msg: str | None = None image_urls: list[str] | None = None @@ -59,7 +63,7 @@ class InitSessionRequest(BaseModel): async def _create_new_conversation( user_id: str | None, git_provider_tokens: PROVIDER_TOKEN_TYPE | None, - selected_repository: Repository | None, + selected_repository: str | None, selected_branch: str | None, initial_user_msg: str | None, image_urls: list[str] | None, @@ -67,7 +71,7 @@ async def _create_new_conversation( conversation_trigger: ConversationTrigger = ConversationTrigger.GUI, attach_convo_id: bool = False, ): - print("trigger", conversation_trigger) + logger.info( 'Creating conversation', extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value}, @@ -122,9 +126,7 @@ async def _create_new_conversation( title=conversation_title, user_id=user_id, github_user_id=None, - selected_repository=selected_repository.full_name - if selected_repository - else selected_repository, + selected_repository=selected_repository, selected_branch=selected_branch, ) ) @@ -161,6 +163,7 @@ async def new_conversation( data: InitSessionRequest, user_id: str = Depends(get_user_id), provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens), + auth_type: AuthType | None = Depends(get_auth_type) ): """Initialize a new session or join an existing one. @@ -168,24 +171,33 @@ async def new_conversation( using the returned conversation ID. """ logger.info('Initializing new conversation') - selected_repository = data.selected_repository + repository = data.repository selected_branch = data.selected_branch initial_user_msg = data.initial_user_msg image_urls = data.image_urls or [] replay_json = data.replay_json suggested_task = data.suggested_task conversation_trigger = data.conversation_trigger + git_provider = data.git_provider if suggested_task: initial_user_msg = suggested_task.get_prompt_for_task() conversation_trigger = ConversationTrigger.SUGGESTED_TASK + if auth_type == AuthType.BEARER: + conversation_trigger = ConversationTrigger.REMOTE_API_KEY + try: + if repository: + provider_handler = ProviderHandler(provider_tokens) + # Check against git_provider, otherwise check all provider apis + await provider_handler.verify_repo_provider(repository, git_provider) + # Create conversation with initial message conversation_id = await _create_new_conversation( user_id=user_id, git_provider_tokens=provider_tokens, - selected_repository=selected_repository, + selected_repository=repository, selected_branch=selected_branch, initial_user_msg=initial_user_msg, image_urls=image_urls, @@ -215,6 +227,16 @@ async def new_conversation( }, status_code=status.HTTP_400_BAD_REQUEST, ) + + except AuthenticationError as e: + return JSONResponse( + content={ + 'status': 'error', + 'message': str(e), + 'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR' + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) @app.get('/conversations') diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index bd58a9b7c4..0cdd7d7d5e 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -86,7 +86,7 @@ class AgentSession: max_budget_per_task: float | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None, agent_configs: dict[str, AgentConfig] | None = None, - selected_repository: Repository | None = None, + selected_repository: str | None = None, selected_branch: str | None = None, initial_message: MessageAction | None = None, replay_json: str | None = None, @@ -153,7 +153,7 @@ class AgentSession: repo_directory = None if self.runtime and runtime_connected and selected_repository: - repo_directory = selected_repository.full_name.split('/')[-1] + repo_directory = selected_repository.split('/')[-1] self.memory = await self._create_memory( selected_repository=selected_repository, @@ -265,7 +265,7 @@ class AgentSession: config: AppConfig, agent: Agent, git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None, - selected_repository: Repository | None = None, + selected_repository: str | None = None, selected_branch: str | None = None, ) -> bool: """Creates a runtime instance @@ -400,7 +400,7 @@ class AgentSession: return controller async def _create_memory( - self, selected_repository: Repository | None, repo_directory: str | None + self, selected_repository: str | None, repo_directory: str | None ) -> Memory: memory = Memory( event_stream=self.event_stream, @@ -415,13 +415,13 @@ class AgentSession: # loads microagents from repo/.openhands/microagents microagents: list[BaseMicroagent] = await call_sync_from_async( self.runtime.get_microagents_from_selected_repo, - selected_repository.full_name if selected_repository else None, + selected_repository or None, ) memory.load_user_workspace_microagents(microagents) if selected_repository and repo_directory: memory.set_repository_info( - selected_repository.full_name, repo_directory + selected_repository, repo_directory ) return memory diff --git a/openhands/server/session/conversation_init_data.py b/openhands/server/session/conversation_init_data.py index 12bde488fc..f8237f7ef0 100644 --- a/openhands/server/session/conversation_init_data.py +++ b/openhands/server/session/conversation_init_data.py @@ -1,7 +1,6 @@ from pydantic import Field from openhands.integrations.provider import PROVIDER_TOKEN_TYPE -from openhands.integrations.service_types import Repository from openhands.storage.data_models.settings import Settings @@ -11,7 +10,7 @@ class ConversationInitData(Settings): """ git_provider_tokens: PROVIDER_TOKEN_TYPE | None = Field(default=None, frozen=True) - selected_repository: Repository | None = Field(default=None) + selected_repository: str | None = Field(default=None) replay_json: str | None = Field(default=None) selected_branch: str | None = Field(default=None) diff --git a/openhands/server/user_auth/__init__.py b/openhands/server/user_auth/__init__.py index 2b02c51af7..091f8ad26b 100644 --- a/openhands/server/user_auth/__init__.py +++ b/openhands/server/user_auth/__init__.py @@ -4,7 +4,7 @@ from pydantic import SecretStr from openhands.integrations.provider import PROVIDER_TOKEN_TYPE from openhands.integrations.service_types import ProviderType from openhands.server.settings import Settings -from openhands.server.user_auth.user_auth import get_user_auth +from openhands.server.user_auth.user_auth import AuthType, get_user_auth from openhands.storage.settings.settings_store import SettingsStore @@ -46,3 +46,8 @@ async def get_user_settings_store(request: Request) -> SettingsStore | None: user_auth = await get_user_auth(request) user_settings_store = await user_auth.get_user_settings_store() return user_settings_store + + +async def get_auth_type(request: Request) -> AuthType | None: + user_auth = await get_user_auth(request) + return user_auth.get_auth_type() \ No newline at end of file diff --git a/openhands/server/user_auth/default_user_auth.py b/openhands/server/user_auth/default_user_auth.py index e46880cb34..9ac2acebf9 100644 --- a/openhands/server/user_auth/default_user_auth.py +++ b/openhands/server/user_auth/default_user_auth.py @@ -51,6 +51,7 @@ class DefaultUserAuth(UserAuth): provider_tokens = getattr(secrets_store, 'provider_tokens', None) return provider_tokens + @classmethod async def get_instance(cls, request: Request) -> UserAuth: user_auth = DefaultUserAuth() diff --git a/openhands/server/user_auth/user_auth.py b/openhands/server/user_auth/user_auth.py index a565d65d84..2902cfc4ae 100644 --- a/openhands/server/user_auth/user_auth.py +++ b/openhands/server/user_auth/user_auth.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from enum import Enum from fastapi import Request from pydantic import SecretStr @@ -12,6 +13,11 @@ from openhands.storage.settings.settings_store import SettingsStore from openhands.utils.import_utils import get_impl +class AuthType(Enum): + COOKIE = "cookie" + BEARER = "bearer" + + class UserAuth(ABC): """Extensible class encapsulating user Authentication""" @@ -45,6 +51,9 @@ class UserAuth(ABC): self._settings = settings return settings + def get_auth_type(self) -> AuthType | None: + return None + @classmethod @abstractmethod async def get_instance(cls, request: Request) -> UserAuth: diff --git a/openhands/storage/data_models/conversation_metadata.py b/openhands/storage/data_models/conversation_metadata.py index bfe476369e..041e25c570 100644 --- a/openhands/storage/data_models/conversation_metadata.py +++ b/openhands/storage/data_models/conversation_metadata.py @@ -7,6 +7,7 @@ class ConversationTrigger(Enum): RESOLVER = 'resolver' GUI = 'gui' SUGGESTED_TASK = 'suggested_task' + REMOTE_API_KEY = 'openhands_api' @dataclass diff --git a/tests/runtime/test_setup.py b/tests/runtime/test_setup.py index 6639ceeb09..52dfd9ff36 100644 --- a/tests/runtime/test_setup.py +++ b/tests/runtime/test_setup.py @@ -1,5 +1,7 @@ """Tests for the setup script.""" +from unittest.mock import patch + from conftest import ( _load_runtime, ) @@ -7,14 +9,27 @@ from conftest import ( from openhands.core.setup import initialize_repository_for_runtime from openhands.events.action import FileReadAction, FileWriteAction from openhands.events.observation import FileReadObservation, FileWriteObservation +from openhands.integrations.service_types import ProviderType, Repository def test_initialize_repository_for_runtime(temp_dir, runtime_cls, run_as_openhands): """Test that the initialize_repository_for_runtime function works.""" runtime, config = _load_runtime(temp_dir, runtime_cls, run_as_openhands) - repository_dir = initialize_repository_for_runtime( - runtime, 'https://github.com/All-Hands-AI/OpenHands' + mock_repo = Repository( + id=1232, + full_name='All-Hands-AI/OpenHands', + git_provider=ProviderType.GITHUB, + is_public=True, ) + + with patch( + 'openhands.runtime.base.ProviderHandler.verify_repo_provider', + return_value=mock_repo, + ): + repository_dir = initialize_repository_for_runtime( + runtime, selected_repository='All-Hands-AI/OpenHands' + ) + assert repository_dir is not None assert repository_dir == 'OpenHands' diff --git a/tests/unit/test_conversation.py b/tests/unit/test_conversation.py index c6d101ff78..e09113d3d9 100644 --- a/tests/unit/test_conversation.py +++ b/tests/unit/test_conversation.py @@ -1,14 +1,15 @@ import json from contextlib import contextmanager from datetime import datetime, timezone +from types import MappingProxyType from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.responses import JSONResponse from openhands.integrations.service_types import ( + AuthenticationError, ProviderType, - Repository, SuggestedTask, TaskType, ) @@ -25,6 +26,7 @@ from openhands.server.routes.manage_conversations import ( update_conversation, ) from openhands.server.types import LLMAuthenticationError, MissingSettingsError +from openhands.server.user_auth.user_auth import AuthType from openhands.storage.data_models.conversation_metadata import ( ConversationMetadata, ConversationTrigger, @@ -62,6 +64,28 @@ def _patch_store(): yield +def create_new_test_conversation( + test_request: InitSessionRequest, auth_type: AuthType | None = None +): + return new_conversation( + data=test_request, + user_id='test_user', + provider_tokens=MappingProxyType({'github': 'token123'}), + auth_type=auth_type, + ) + + +@pytest.fixture +def provider_handler_mock(): + with patch( + 'openhands.server.routes.manage_conversations.ProviderHandler' + ) as mock_cls: + mock_instance = MagicMock() + mock_instance.verify_repo_provider = AsyncMock(return_value=ProviderType.GITHUB) + mock_cls.return_value = mock_instance + yield mock_instance + + @pytest.mark.asyncio async def test_search_conversations(): with _patch_store(): @@ -232,7 +256,7 @@ async def test_update_conversation(): @pytest.mark.asyncio -async def test_new_conversation_success(): +async def test_new_conversation_success(provider_handler_mock): """Test successful creation of a new conversation.""" with _patch_store(): # Mock the _create_new_conversation function directly @@ -242,26 +266,16 @@ async def test_new_conversation_success(): # Set up the mock to return a conversation ID mock_create_conversation.return_value = 'test_conversation_id' - # Create test data - test_repo = Repository( - id=12345, - full_name='test/repo', - git_provider=ProviderType.GITHUB, - is_public=True, - ) - test_request = InitSessionRequest( conversation_trigger=ConversationTrigger.GUI, - selected_repository=test_repo, + repository='test/repo', selected_branch='main', initial_user_msg='Hello, agent!', image_urls=['https://example.com/image.jpg'], ) # Call new_conversation - response = await new_conversation( - data=test_request, user_id='test_user', provider_tokens={} - ) + response = await create_new_test_conversation(test_request) # Verify the response assert isinstance(response, JSONResponse) @@ -275,7 +289,7 @@ async def test_new_conversation_success(): mock_create_conversation.assert_called_once() call_args = mock_create_conversation.call_args[1] assert call_args['user_id'] == 'test_user' - assert call_args['selected_repository'] == test_repo + assert call_args['selected_repository'] == 'test/repo' assert call_args['selected_branch'] == 'main' assert call_args['initial_user_msg'] == 'Hello, agent!' assert call_args['image_urls'] == ['https://example.com/image.jpg'] @@ -283,7 +297,7 @@ async def test_new_conversation_success(): @pytest.mark.asyncio -async def test_new_conversation_with_suggested_task(): +async def test_new_conversation_with_suggested_task(provider_handler_mock): """Test creating a new conversation with a suggested task.""" with _patch_store(): # Mock the _create_new_conversation function directly @@ -301,14 +315,6 @@ async def test_new_conversation_with_suggested_task(): 'Please fix the failing checks in PR #123' ) - # Create test data - test_repo = Repository( - id=12345, - full_name='test/repo', - git_provider=ProviderType.GITHUB, - is_public=True, - ) - test_task = SuggestedTask( git_provider=ProviderType.GITHUB, task_type=TaskType.FAILING_CHECKS, @@ -319,15 +325,13 @@ async def test_new_conversation_with_suggested_task(): test_request = InitSessionRequest( conversation_trigger=ConversationTrigger.SUGGESTED_TASK, - selected_repository=test_repo, + repository='test/repo', selected_branch='main', suggested_task=test_task, ) # Call new_conversation - response = await new_conversation( - data=test_request, user_id='test_user', provider_tokens={} - ) + response = await create_new_test_conversation(test_request) # Verify the response assert isinstance(response, JSONResponse) @@ -341,7 +345,7 @@ async def test_new_conversation_with_suggested_task(): mock_create_conversation.assert_called_once() call_args = mock_create_conversation.call_args[1] assert call_args['user_id'] == 'test_user' - assert call_args['selected_repository'] == test_repo + assert call_args['selected_repository'] == 'test/repo' assert call_args['selected_branch'] == 'main' assert ( call_args['initial_user_msg'] @@ -357,7 +361,7 @@ async def test_new_conversation_with_suggested_task(): @pytest.mark.asyncio -async def test_new_conversation_missing_settings(): +async def test_new_conversation_missing_settings(provider_handler_mock): """Test creating a new conversation when settings are missing.""" with _patch_store(): # Mock the _create_new_conversation function to raise MissingSettingsError @@ -369,25 +373,15 @@ async def test_new_conversation_missing_settings(): 'Settings not found' ) - # Create test data - test_repo = Repository( - id=12345, - full_name='test/repo', - git_provider=ProviderType.GITHUB, - is_public=True, - ) - test_request = InitSessionRequest( conversation_trigger=ConversationTrigger.GUI, - selected_repository=test_repo, + repository='test/repo', selected_branch='main', initial_user_msg='Hello, agent!', ) # Call new_conversation - response = await new_conversation( - data=test_request, user_id='test_user', provider_tokens={} - ) + response = await create_new_test_conversation(test_request) # Verify the response assert isinstance(response, JSONResponse) @@ -409,17 +403,9 @@ async def test_new_conversation_invalid_api_key(): 'Error authenticating with the LLM provider. Please check your API key' ) - # Create test data - test_repo = Repository( - id=12345, - full_name='test/repo', - git_provider=ProviderType.GITHUB, - is_public=True, - ) - test_request = InitSessionRequest( conversation_trigger=ConversationTrigger.GUI, - selected_repository=test_repo, + repo='test/repo', selected_branch='main', initial_user_msg='Hello, agent!', ) @@ -496,3 +482,120 @@ async def test_delete_conversation(): mock_runtime_cls.delete.assert_called_once_with( 'some_conversation_id' ) + + +@pytest.mark.asyncio +async def test_new_conversation_with_bearer_auth(provider_handler_mock): + """Test creating a new conversation with bearer authentication.""" + with _patch_store(): + # Mock the _create_new_conversation function + with patch( + 'openhands.server.routes.manage_conversations._create_new_conversation' + ) as mock_create_conversation: + # Set up the mock to return a conversation ID + mock_create_conversation.return_value = 'test_conversation_id' + + # Create the request object + test_request = InitSessionRequest( + conversation_trigger=ConversationTrigger.GUI, # This should be overridden + repository='test/repo', + selected_branch='main', + initial_user_msg='Hello, agent!', + ) + + # Call new_conversation with auth_type=BEARER + response = await create_new_test_conversation(test_request, AuthType.BEARER) + + # Verify the response + assert isinstance(response, JSONResponse) + assert response.status_code == 200 + + # Verify that _create_new_conversation was called with REMOTE_API_KEY trigger + mock_create_conversation.assert_called_once() + call_args = mock_create_conversation.call_args[1] + assert ( + call_args['conversation_trigger'] == ConversationTrigger.REMOTE_API_KEY + ) + + +@pytest.mark.asyncio +async def test_new_conversation_with_null_repository(): + """Test creating a new conversation with null repository.""" + with _patch_store(): + # Mock the _create_new_conversation function + with patch( + 'openhands.server.routes.manage_conversations._create_new_conversation' + ) as mock_create_conversation: + # Set up the mock to return a conversation ID + mock_create_conversation.return_value = 'test_conversation_id' + + # Create the request object with null repository + test_request = InitSessionRequest( + conversation_trigger=ConversationTrigger.GUI, + repository=None, # Explicitly set to None + selected_branch=None, + initial_user_msg='Hello, agent!', + ) + + # Call new_conversation + response = await create_new_test_conversation(test_request) + + # Verify the response + assert isinstance(response, JSONResponse) + assert response.status_code == 200 + + # Verify that _create_new_conversation was called with None repository + mock_create_conversation.assert_called_once() + call_args = mock_create_conversation.call_args[1] + assert call_args['selected_repository'] is None + + +@pytest.mark.asyncio +async def test_new_conversation_with_provider_authentication_error( + provider_handler_mock, +): + provider_handler_mock.verify_repo_provider = AsyncMock( + side_effect=AuthenticationError('auth error') + ) + + """Test creating a new conversation when provider authentication fails.""" + with _patch_store(): + # Mock the _create_new_conversation function + with patch( + 'openhands.server.routes.manage_conversations._create_new_conversation' + ) as mock_create_conversation: + # Set up the mock to return a conversation ID + mock_create_conversation.return_value = 'test_conversation_id' + + # Create the request object + test_request = InitSessionRequest( + conversation_trigger=ConversationTrigger.GUI, + repository='test/repo', + selected_branch='main', + initial_user_msg='Hello, agent!', + ) + + # Call new_conversation + response = await new_conversation( + data=test_request, + user_id='test_user', + provider_tokens={'github': 'token123'}, + auth_type=None, + ) + + # Verify the response + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + assert json.loads(response.body.decode('utf-8')) == { + 'status': 'error', + 'message': 'auth error', + 'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR', + } + + # Verify that verify_repo_provider was called with the repository + provider_handler_mock.verify_repo_provider.assert_called_once_with( + 'test/repo', None + ) + + # Verify that _create_new_conversation was not called + mock_create_conversation.assert_not_called() diff --git a/tests/unit/test_runtime_git_tokens.py b/tests/unit/test_runtime_git_tokens.py index 9d07ace929..690d05b699 100644 --- a/tests/unit/test_runtime_git_tokens.py +++ b/tests/unit/test_runtime_git_tokens.py @@ -1,4 +1,5 @@ from types import MappingProxyType +from unittest.mock import MagicMock, patch import pytest from pydantic import SecretStr @@ -8,7 +9,8 @@ from openhands.events.action import Action from openhands.events.action.commands import CmdRunAction from openhands.events.observation import NullObservation, Observation from openhands.events.stream import EventStream -from openhands.integrations.provider import ProviderToken, ProviderType +from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType +from openhands.integrations.service_types import AuthenticationError, Repository from openhands.runtime.base import Runtime from openhands.storage import get_file_store @@ -16,6 +18,13 @@ from openhands.storage import get_file_store class TestRuntime(Runtime): """A concrete implementation of Runtime for testing""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.run_action_calls = [] + self._execute_shell_fn_git_handler = MagicMock( + return_value=MagicMock(exit_code=0, stdout='', stderr='') + ) + async def connect(self): pass @@ -23,22 +32,22 @@ class TestRuntime(Runtime): pass def browse(self, action): - return NullObservation() + return NullObservation(content='') def browse_interactive(self, action): - return NullObservation() + return NullObservation(content='') def run(self, action): - return NullObservation() + return NullObservation(content='') def run_ipython(self, action): - return NullObservation() + return NullObservation(content='') def read(self, action): - return NullObservation() + return NullObservation(content='') def write(self, action): - return NullObservation() + return NullObservation(content='') def copy_from(self, path): return '' @@ -50,10 +59,11 @@ class TestRuntime(Runtime): return [] def run_action(self, action: Action) -> Observation: - return NullObservation() + self.run_action_calls.append(action) + return NullObservation(content='') def call_tool_mcp(self, action): - return NullObservation() + return NullObservation(content='') @pytest.fixture @@ -80,6 +90,20 @@ def runtime(temp_dir): return runtime +def mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB, is_public=True): + repo = Repository( + id=123, full_name='owner/repo', git_provider=provider, is_public=is_public + ) + + async def mock_verify_repo_provider(*_args, **_kwargs): + return repo + + monkeypatch.setattr( + ProviderHandler, 'verify_repo_provider', mock_verify_repo_provider + ) + return repo + + @pytest.mark.asyncio async def test_export_latest_git_provider_tokens_no_user_id(temp_dir): """Test that no token export happens when user_id is not set""" @@ -183,3 +207,200 @@ async def test_export_latest_git_provider_tokens_token_update(runtime): # Verify that the new token was exported assert runtime.event_stream.secrets == {'github_token': new_token} + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_no_repo_with_user_id(temp_dir): + """Test that git init is run when no repository is selected and user_id is set""" + config = AppConfig() + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + runtime = TestRuntime( + config=config, event_stream=event_stream, sid='test', user_id='test_user' + ) + + # Call the function with no repository + result = await runtime.clone_or_init_repo(None, None, None) + + # Verify that git init was called + assert len(runtime.run_action_calls) == 1 + assert isinstance(runtime.run_action_calls[0], CmdRunAction) + assert runtime.run_action_calls[0].command == 'git init' + assert result == '' + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_no_repo_no_user_id_no_workspace_base(temp_dir): + """Test that git init is run when no repository is selected, no user_id, and no workspace_base""" + config = AppConfig() + config.workspace_base = None # Ensure workspace_base is not set + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + runtime = TestRuntime( + config=config, event_stream=event_stream, sid='test', user_id=None + ) + + # Call the function with no repository + result = await runtime.clone_or_init_repo(None, None, None) + + # Verify that git init was called + assert len(runtime.run_action_calls) == 1 + assert isinstance(runtime.run_action_calls[0], CmdRunAction) + assert runtime.run_action_calls[0].command == 'git init' + assert result == '' + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_dir): + """Test that git init is not run when no repository is selected, no user_id, but workspace_base is set""" + config = AppConfig() + config.workspace_base = '/some/path' # Set workspace_base + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + runtime = TestRuntime( + config=config, event_stream=event_stream, sid='test', user_id=None + ) + + # Call the function with no repository + result = await runtime.clone_or_init_repo(None, None, None) + + # Verify that git init was not called + assert len(runtime.run_action_calls) == 0 + assert result == '' + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_auth_error(temp_dir): + """Test that RuntimeError is raised when authentication fails""" + config = AppConfig() + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + runtime = TestRuntime( + config=config, event_stream=event_stream, sid='test', user_id='test_user' + ) + + # Mock the verify_repo_provider method to raise AuthenticationError + with patch.object( + ProviderHandler, + 'verify_repo_provider', + side_effect=AuthenticationError('Auth failed'), + ): + # Call the function with a repository + with pytest.raises(RuntimeError) as excinfo: + await runtime.clone_or_init_repo(None, 'owner/repo', None) + + # Verify the error message + assert 'Git provider authentication issue when cloning repo' in str( + excinfo.value + ) + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch): + config = AppConfig() + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + + github_token = 'github_test_token' + git_provider_tokens = MappingProxyType( + {ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))} + ) + + runtime = TestRuntime( + config=config, + event_stream=event_stream, + sid='test', + user_id='test_user', + git_provider_tokens=git_provider_tokens, + ) + + mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB) + + result = await runtime.clone_or_init_repo(git_provider_tokens, 'owner/repo', None) + + cmd = runtime.run_action_calls[0].command + assert f'git clone https://{github_token}@github.com/owner/repo.git repo' in cmd + assert result == 'repo' + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch): + """Test cloning a GitHub repository without a token""" + config = AppConfig() + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + + runtime = TestRuntime( + config=config, event_stream=event_stream, sid='test', user_id='test_user' + ) + + mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB) + result = await runtime.clone_or_init_repo(None, 'owner/repo', None) + + # Verify that git clone was called with the public URL + assert len(runtime.run_action_calls) == 1 + assert isinstance(runtime.run_action_calls[0], CmdRunAction) + + # Check that the command contains the correct URL format without token + cmd = runtime.run_action_calls[0].command + assert 'git clone https://github.com/owner/repo.git repo' in cmd + assert 'cd repo' in cmd + assert 'git checkout -b openhands-workspace-' in cmd + assert result == 'repo' + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch): + config = AppConfig() + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + + gitlab_token = 'gitlab_test_token' + git_provider_tokens = MappingProxyType( + {ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))} + ) + + runtime = TestRuntime( + config=config, + event_stream=event_stream, + sid='test', + user_id='test_user', + git_provider_tokens=git_provider_tokens, + ) + + mock_repo_and_patch(monkeypatch, provider=ProviderType.GITLAB) + + result = await runtime.clone_or_init_repo(git_provider_tokens, 'owner/repo', None) + + cmd = runtime.run_action_calls[0].command + assert ( + f'git clone https://oauth2:{gitlab_token}@gitlab.com/owner/repo.git repo' in cmd + ) + assert result == 'repo' + + +@pytest.mark.asyncio +async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch): + """Test cloning a repository with a specified branch""" + config = AppConfig() + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('abc', file_store) + + runtime = TestRuntime( + config=config, event_stream=event_stream, sid='test', user_id='test_user' + ) + + mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB) + result = await runtime.clone_or_init_repo(None, 'owner/repo', 'feature-branch') + + # Verify that git clone was called with the correct branch checkout + assert len(runtime.run_action_calls) == 1 + assert isinstance(runtime.run_action_calls[0], CmdRunAction) + + # Check that the command contains the correct branch checkout + cmd = runtime.run_action_calls[0].command + assert 'git clone https://github.com/owner/repo.git repo' in cmd + assert 'cd repo' in cmd + assert 'git checkout feature-branch' in cmd + assert 'git checkout -b' not in cmd # Should not create a new branch + assert result == 'repo'