[Fix]: Use str in place of Repository for repository param when creating new conversation (#8159)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
Rohit Malhotra 2025-05-02 11:17:04 -04:00 committed by GitHub
parent 7244e5df9f
commit 767d092f8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 579 additions and 135 deletions

View File

@ -858,14 +858,15 @@
"schema": {
"type": "object",
"properties": {
"selected_repository": {
"type": "object",
"repository": {
"type": "string",
"nullable": true,
"properties": {
"full_name": {
"type": "string"
}
}
"description": "Full name of the repository (e.g., owner/repo)"
},
"git_provider": {
"type": "string",
"nullable": true,
"description": "The Git provider (e.g., github or gitlab). If omitted, all configured providers are checked for the repository."
},
"selected_branch": {
"type": "string",

View File

@ -49,6 +49,7 @@ describe("HomeHeader", () => {
"gui",
undefined,
undefined,
undefined,
[],
undefined,
undefined,

View File

@ -165,12 +165,8 @@ describe("RepoConnector", () => {
expect(createConversationSpy).toHaveBeenCalledExactlyOnceWith(
"gui",
{
full_name: "rbren/polaris",
git_provider: "github",
id: 1,
is_public: true,
},
"rbren/polaris",
"github",
undefined,
[],
undefined,

View File

@ -89,7 +89,8 @@ describe("TaskCard", () => {
expect(createConversationSpy).toHaveBeenCalledWith(
"suggested_task",
MOCK_RESPOSITORIES[0],
MOCK_RESPOSITORIES[0].full_name,
MOCK_RESPOSITORIES[0].git_provider,
undefined,
[],
undefined,

View File

@ -13,7 +13,7 @@ import {
ConversationTrigger,
} from "./open-hands.types";
import { openHands } from "./open-hands-axios";
import { ApiSettings, PostApiSettings } from "#/types/settings";
import { ApiSettings, PostApiSettings, Provider } from "#/types/settings";
import { GitUser, GitRepository } from "#/types/git";
import { SuggestedTask } from "#/components/features/home/tasks/task.types";
@ -152,7 +152,8 @@ class OpenHands {
static async createConversation(
conversation_trigger: ConversationTrigger = "gui",
selectedRepository?: GitRepository,
selectedRepository?: string,
git_provider?: Provider,
initialUserMsg?: string,
imageUrls?: string[],
replayJson?: string,
@ -160,7 +161,8 @@ class OpenHands {
): Promise<Conversation> {
const body = {
conversation_trigger,
selected_repository: selectedRepository,
repository: selectedRepository,
git_provider,
selected_branch: undefined,
initial_user_msg: initialUserMsg,
image_urls: imageUrls,

View File

@ -24,13 +24,19 @@ export const useCreateConversation = () => {
conversation_trigger: ConversationTrigger;
q?: string;
selectedRepository?: GitRepository | null;
suggested_task?: SuggestedTask;
}) => {
if (variables.q) dispatch(setInitialPrompt(variables.q));
return OpenHands.createConversation(
variables.conversation_trigger,
variables.selectedRepository || undefined,
variables.selectedRepository
? variables.selectedRepository.full_name
: undefined,
variables.selectedRepository
? variables.selectedRepository.git_provider
: undefined,
variables.q,
files,
replayJson || undefined,

View File

@ -85,40 +85,41 @@ def create_runtime(
def initialize_repository_for_runtime(
runtime: Runtime,
selected_repository: str | None = None,
github_token: SecretStr | None = None,
runtime: Runtime, selected_repository: str | None = None
) -> str | None:
"""Initialize the repository for the runtime.
Args:
runtime: The runtime to initialize the repository for.
selected_repository: (optional) The GitHub repository to use.
github_token: (optional) The GitHub token to use.
Returns:
The repository directory path if a repository was cloned, None otherwise.
"""
# clone selected repository if provided
if github_token is None and 'GITHUB_TOKEN' in os.environ:
provider_tokens = {}
if 'GITHUB_TOKEN' in os.environ:
github_token = SecretStr(os.environ['GITHUB_TOKEN'])
provider_tokens[ProviderType.GITHUB] = ProviderToken(
token=SecretStr(github_token)
)
if 'GITLAB_TOKEN' in os.environ:
gitlab_token = SecretStr(os.environ['GITLAB_TOKEN'])
provider_tokens[ProviderType.GITLAB] = ProviderToken(
token=SecretStr(gitlab_token)
)
secret_store = (
SecretStore(
provider_tokens={
ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))
}
)
if github_token
else None
SecretStore(provider_tokens=provider_tokens) if provider_tokens else None
)
provider_tokens = secret_store.provider_tokens if secret_store else None
immutable_provider_tokens = secret_store.provider_tokens if secret_store else None
logger.debug(f'Selected repository {selected_repository}.')
repo_directory = call_async_from_sync(
runtime.clone_or_init_repo,
GENERAL_TIMEOUT,
provider_tokens,
immutable_provider_tokens,
selected_repository,
None,
)

View File

@ -390,6 +390,18 @@ class GitHubService(BaseGitService, GitService):
except Exception:
return []
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
url = f'{self.BASE_URL}/repos/{repository}'
repo, _ = await self._make_request(url)
return Repository(
id=repo.get('id'),
full_name=repo.get('full_name'),
stargazers_count=repo.get('stargazers_count'),
git_provider=ProviderType.GITHUB,
is_public=not repo.get('private', True),
)
github_service_cls = os.environ.get(
'OPENHANDS_GITHUB_SERVICE_CLS',

View File

@ -383,6 +383,23 @@ class GitLabService(BaseGitService, GitService):
return []
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
encoded_name = repository.replace("/", "%2F")
url = f'{self.BASE_URL}/projects/{encoded_name}'
repo, _ = await self._make_request(url)
return Repository(
id=repo.get('id'),
full_name=repo.get('path_with_namespace'),
stargazers_count=repo.get('star_count'),
git_provider=ProviderType.GITLAB,
is_public=repo.get('visibility') == 'public',
)
gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',
'openhands.integrations.gitlab.gitlab_service.GitLabService',

View File

@ -397,3 +397,22 @@ class ProviderHandler:
Map ProviderType value to the environment variable name in the runtime
"""
return f'{provider.value}_token'.lower()
async def verify_repo_provider(
self, repository: str, specified_provider: ProviderType | None = None
):
if specified_provider:
try:
service = self._get_service(specified_provider)
return await service.get_repository_details_from_repo_name(repository)
except Exception:
pass
for provider in self.provider_tokens:
try:
service = self._get_service(provider)
return await service.get_repository_details_from_repo_name(repository)
except Exception:
pass
raise AuthenticationError(f'Unable to access repo {repository}')

View File

@ -206,3 +206,8 @@ class GitService(Protocol):
async def get_suggested_tasks(self) -> list[SuggestedTask]:
"""Get suggested tasks for the authenticated user across all repositories"""
...
async def get_repository_details_from_repo_name(
self, repository: str
) -> Repository:
"""Gets all repository details from repository name"""

View File

@ -47,7 +47,7 @@ from openhands.integrations.provider import (
ProviderHandler,
ProviderType,
)
from openhands.integrations.service_types import Repository
from openhands.integrations.service_types import AuthenticationError
from openhands.microagent import (
BaseMicroagent,
load_microagents_from_dir,
@ -311,10 +311,23 @@ class Runtime(FileEditRuntimeMixin):
async def clone_or_init_repo(
self,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
selected_repository: str | Repository | None,
selected_repository: str | None,
selected_branch: str | None,
repository_provider: ProviderType = ProviderType.GITHUB,
) -> str:
repository = None
if selected_repository: # Determine provider from repo name
try:
provider_handler = ProviderHandler(
git_provider_tokens or MappingProxyType({})
)
repository = await provider_handler.verify_repo_provider(
selected_repository
)
except AuthenticationError:
raise RuntimeError(
'Git provider authentication issue when cloning repo'
)
if not selected_repository:
# In SaaS mode (indicated by user_id being set), always run git init
# In OSS mode, only run git init if workspace_base is not set
@ -332,36 +345,30 @@ class Runtime(FileEditRuntimeMixin):
)
return ''
# This satisfies mypy because param is optional, but `verify_repo_provider` guarentees this gets populated
if not repository:
return ''
provider = repository.git_provider
provider_domains = {
ProviderType.GITHUB: 'github.com',
ProviderType.GITLAB: 'gitlab.com',
}
chosen_provider = (
repository_provider
if isinstance(selected_repository, str)
else selected_repository.git_provider
)
domain = provider_domains[chosen_provider]
repository = (
selected_repository
if isinstance(selected_repository, str)
else selected_repository.full_name
)
domain = provider_domains[provider]
# Try to use token if available, otherwise use public URL
if git_provider_tokens and chosen_provider in git_provider_tokens:
git_token = git_provider_tokens[chosen_provider].token
if git_provider_tokens and provider in git_provider_tokens:
git_token = git_provider_tokens[provider].token
if git_token:
if chosen_provider == ProviderType.GITLAB:
remote_repo_url = f'https://oauth2:{git_token.get_secret_value()}@{domain}/{repository}.git'
if provider == ProviderType.GITLAB:
remote_repo_url = f'https://oauth2:{git_token.get_secret_value()}@{domain}/{selected_repository}.git'
else:
remote_repo_url = f'https://{git_token.get_secret_value()}@{domain}/{repository}.git'
remote_repo_url = f'https://{git_token.get_secret_value()}@{domain}/{selected_repository}.git'
else:
remote_repo_url = f'https://{domain}/{repository}.git'
remote_repo_url = f'https://{domain}/{selected_repository}.git'
else:
remote_repo_url = f'https://{domain}/{repository}.git'
remote_repo_url = f'https://{domain}/{selected_repository}.git'
if not remote_repo_url:
raise ValueError('Missing either Git token or valid repository')
@ -371,7 +378,7 @@ class Runtime(FileEditRuntimeMixin):
'info', 'STATUS$SETTING_UP_WORKSPACE', 'Setting up workspace...'
)
dir_name = repository.split('/')[-1]
dir_name = selected_repository.split('/')[-1]
# Generate a random branch name to avoid conflicts
random_str = ''.join(

View File

@ -12,8 +12,9 @@ from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
)
from openhands.integrations.service_types import Repository, SuggestedTask
from openhands.integrations.service_types import AuthenticationError, ProviderType, Repository, SuggestedTask
from openhands.runtime import get_runtime_cls
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
@ -29,9 +30,11 @@ from openhands.server.shared import (
)
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth import (
get_auth_type,
get_provider_tokens,
get_user_id,
)
from openhands.server.user_auth.user_auth import AuthType
from openhands.server.utils import get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import (
@ -48,7 +51,8 @@ app = APIRouter(prefix='/api')
class InitSessionRequest(BaseModel):
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI
selected_repository: Repository | None = None
repository: str | None = None
git_provider: ProviderType | None = None
selected_branch: str | None = None
initial_user_msg: str | None = None
image_urls: list[str] | None = None
@ -59,7 +63,7 @@ class InitSessionRequest(BaseModel):
async def _create_new_conversation(
user_id: str | None,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
selected_repository: Repository | None,
selected_repository: str | None,
selected_branch: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
@ -67,7 +71,7 @@ async def _create_new_conversation(
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
attach_convo_id: bool = False,
):
print("trigger", conversation_trigger)
logger.info(
'Creating conversation',
extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value},
@ -122,9 +126,7 @@ async def _create_new_conversation(
title=conversation_title,
user_id=user_id,
github_user_id=None,
selected_repository=selected_repository.full_name
if selected_repository
else selected_repository,
selected_repository=selected_repository,
selected_branch=selected_branch,
)
)
@ -161,6 +163,7 @@ async def new_conversation(
data: InitSessionRequest,
user_id: str = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
auth_type: AuthType | None = Depends(get_auth_type)
):
"""Initialize a new session or join an existing one.
@ -168,24 +171,33 @@ async def new_conversation(
using the returned conversation ID.
"""
logger.info('Initializing new conversation')
selected_repository = data.selected_repository
repository = data.repository
selected_branch = data.selected_branch
initial_user_msg = data.initial_user_msg
image_urls = data.image_urls or []
replay_json = data.replay_json
suggested_task = data.suggested_task
conversation_trigger = data.conversation_trigger
git_provider = data.git_provider
if suggested_task:
initial_user_msg = suggested_task.get_prompt_for_task()
conversation_trigger = ConversationTrigger.SUGGESTED_TASK
if auth_type == AuthType.BEARER:
conversation_trigger = ConversationTrigger.REMOTE_API_KEY
try:
if repository:
provider_handler = ProviderHandler(provider_tokens)
# Check against git_provider, otherwise check all provider apis
await provider_handler.verify_repo_provider(repository, git_provider)
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id=user_id,
git_provider_tokens=provider_tokens,
selected_repository=selected_repository,
selected_repository=repository,
selected_branch=selected_branch,
initial_user_msg=initial_user_msg,
image_urls=image_urls,
@ -215,6 +227,16 @@ async def new_conversation(
},
status_code=status.HTTP_400_BAD_REQUEST,
)
except AuthenticationError as e:
return JSONResponse(
content={
'status': 'error',
'message': str(e),
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR'
},
status_code=status.HTTP_400_BAD_REQUEST,
)
@app.get('/conversations')

View File

@ -86,7 +86,7 @@ class AgentSession:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
selected_repository: Repository | None = None,
selected_repository: str | None = None,
selected_branch: str | None = None,
initial_message: MessageAction | None = None,
replay_json: str | None = None,
@ -153,7 +153,7 @@ class AgentSession:
repo_directory = None
if self.runtime and runtime_connected and selected_repository:
repo_directory = selected_repository.full_name.split('/')[-1]
repo_directory = selected_repository.split('/')[-1]
self.memory = await self._create_memory(
selected_repository=selected_repository,
@ -265,7 +265,7 @@ class AgentSession:
config: AppConfig,
agent: Agent,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
selected_repository: Repository | None = None,
selected_repository: str | None = None,
selected_branch: str | None = None,
) -> bool:
"""Creates a runtime instance
@ -400,7 +400,7 @@ class AgentSession:
return controller
async def _create_memory(
self, selected_repository: Repository | None, repo_directory: str | None
self, selected_repository: str | None, repo_directory: str | None
) -> Memory:
memory = Memory(
event_stream=self.event_stream,
@ -415,13 +415,13 @@ class AgentSession:
# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroagent] = await call_sync_from_async(
self.runtime.get_microagents_from_selected_repo,
selected_repository.full_name if selected_repository else None,
selected_repository or None,
)
memory.load_user_workspace_microagents(microagents)
if selected_repository and repo_directory:
memory.set_repository_info(
selected_repository.full_name, repo_directory
selected_repository, repo_directory
)
return memory

View File

@ -1,7 +1,6 @@
from pydantic import Field
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.integrations.service_types import Repository
from openhands.storage.data_models.settings import Settings
@ -11,7 +10,7 @@ class ConversationInitData(Settings):
"""
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = Field(default=None, frozen=True)
selected_repository: Repository | None = Field(default=None)
selected_repository: str | None = Field(default=None)
replay_json: str | None = Field(default=None)
selected_branch: str | None = Field(default=None)

View File

@ -4,7 +4,7 @@ from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.integrations.service_types import ProviderType
from openhands.server.settings import Settings
from openhands.server.user_auth.user_auth import get_user_auth
from openhands.server.user_auth.user_auth import AuthType, get_user_auth
from openhands.storage.settings.settings_store import SettingsStore
@ -46,3 +46,8 @@ async def get_user_settings_store(request: Request) -> SettingsStore | None:
user_auth = await get_user_auth(request)
user_settings_store = await user_auth.get_user_settings_store()
return user_settings_store
async def get_auth_type(request: Request) -> AuthType | None:
user_auth = await get_user_auth(request)
return user_auth.get_auth_type()

View File

@ -51,6 +51,7 @@ class DefaultUserAuth(UserAuth):
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
return provider_tokens
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:
user_auth = DefaultUserAuth()

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from fastapi import Request
from pydantic import SecretStr
@ -12,6 +13,11 @@ from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.import_utils import get_impl
class AuthType(Enum):
COOKIE = "cookie"
BEARER = "bearer"
class UserAuth(ABC):
"""Extensible class encapsulating user Authentication"""
@ -45,6 +51,9 @@ class UserAuth(ABC):
self._settings = settings
return settings
def get_auth_type(self) -> AuthType | None:
return None
@classmethod
@abstractmethod
async def get_instance(cls, request: Request) -> UserAuth:

View File

@ -7,6 +7,7 @@ class ConversationTrigger(Enum):
RESOLVER = 'resolver'
GUI = 'gui'
SUGGESTED_TASK = 'suggested_task'
REMOTE_API_KEY = 'openhands_api'
@dataclass

View File

@ -1,5 +1,7 @@
"""Tests for the setup script."""
from unittest.mock import patch
from conftest import (
_load_runtime,
)
@ -7,14 +9,27 @@ from conftest import (
from openhands.core.setup import initialize_repository_for_runtime
from openhands.events.action import FileReadAction, FileWriteAction
from openhands.events.observation import FileReadObservation, FileWriteObservation
from openhands.integrations.service_types import ProviderType, Repository
def test_initialize_repository_for_runtime(temp_dir, runtime_cls, run_as_openhands):
"""Test that the initialize_repository_for_runtime function works."""
runtime, config = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
repository_dir = initialize_repository_for_runtime(
runtime, 'https://github.com/All-Hands-AI/OpenHands'
mock_repo = Repository(
id=1232,
full_name='All-Hands-AI/OpenHands',
git_provider=ProviderType.GITHUB,
is_public=True,
)
with patch(
'openhands.runtime.base.ProviderHandler.verify_repo_provider',
return_value=mock_repo,
):
repository_dir = initialize_repository_for_runtime(
runtime, selected_repository='All-Hands-AI/OpenHands'
)
assert repository_dir is not None
assert repository_dir == 'OpenHands'

View File

@ -1,14 +1,15 @@
import json
from contextlib import contextmanager
from datetime import datetime, timezone
from types import MappingProxyType
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.responses import JSONResponse
from openhands.integrations.service_types import (
AuthenticationError,
ProviderType,
Repository,
SuggestedTask,
TaskType,
)
@ -25,6 +26,7 @@ from openhands.server.routes.manage_conversations import (
update_conversation,
)
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth.user_auth import AuthType
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
ConversationTrigger,
@ -62,6 +64,28 @@ def _patch_store():
yield
def create_new_test_conversation(
test_request: InitSessionRequest, auth_type: AuthType | None = None
):
return new_conversation(
data=test_request,
user_id='test_user',
provider_tokens=MappingProxyType({'github': 'token123'}),
auth_type=auth_type,
)
@pytest.fixture
def provider_handler_mock():
with patch(
'openhands.server.routes.manage_conversations.ProviderHandler'
) as mock_cls:
mock_instance = MagicMock()
mock_instance.verify_repo_provider = AsyncMock(return_value=ProviderType.GITHUB)
mock_cls.return_value = mock_instance
yield mock_instance
@pytest.mark.asyncio
async def test_search_conversations():
with _patch_store():
@ -232,7 +256,7 @@ async def test_update_conversation():
@pytest.mark.asyncio
async def test_new_conversation_success():
async def test_new_conversation_success(provider_handler_mock):
"""Test successful creation of a new conversation."""
with _patch_store():
# Mock the _create_new_conversation function directly
@ -242,26 +266,16 @@ async def test_new_conversation_success():
# Set up the mock to return a conversation ID
mock_create_conversation.return_value = 'test_conversation_id'
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
selected_repository=test_repo,
repository='test/repo',
selected_branch='main',
initial_user_msg='Hello, agent!',
image_urls=['https://example.com/image.jpg'],
)
# Call new_conversation
response = await new_conversation(
data=test_request, user_id='test_user', provider_tokens={}
)
response = await create_new_test_conversation(test_request)
# Verify the response
assert isinstance(response, JSONResponse)
@ -275,7 +289,7 @@ async def test_new_conversation_success():
mock_create_conversation.assert_called_once()
call_args = mock_create_conversation.call_args[1]
assert call_args['user_id'] == 'test_user'
assert call_args['selected_repository'] == test_repo
assert call_args['selected_repository'] == 'test/repo'
assert call_args['selected_branch'] == 'main'
assert call_args['initial_user_msg'] == 'Hello, agent!'
assert call_args['image_urls'] == ['https://example.com/image.jpg']
@ -283,7 +297,7 @@ async def test_new_conversation_success():
@pytest.mark.asyncio
async def test_new_conversation_with_suggested_task():
async def test_new_conversation_with_suggested_task(provider_handler_mock):
"""Test creating a new conversation with a suggested task."""
with _patch_store():
# Mock the _create_new_conversation function directly
@ -301,14 +315,6 @@ async def test_new_conversation_with_suggested_task():
'Please fix the failing checks in PR #123'
)
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_task = SuggestedTask(
git_provider=ProviderType.GITHUB,
task_type=TaskType.FAILING_CHECKS,
@ -319,15 +325,13 @@ async def test_new_conversation_with_suggested_task():
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.SUGGESTED_TASK,
selected_repository=test_repo,
repository='test/repo',
selected_branch='main',
suggested_task=test_task,
)
# Call new_conversation
response = await new_conversation(
data=test_request, user_id='test_user', provider_tokens={}
)
response = await create_new_test_conversation(test_request)
# Verify the response
assert isinstance(response, JSONResponse)
@ -341,7 +345,7 @@ async def test_new_conversation_with_suggested_task():
mock_create_conversation.assert_called_once()
call_args = mock_create_conversation.call_args[1]
assert call_args['user_id'] == 'test_user'
assert call_args['selected_repository'] == test_repo
assert call_args['selected_repository'] == 'test/repo'
assert call_args['selected_branch'] == 'main'
assert (
call_args['initial_user_msg']
@ -357,7 +361,7 @@ async def test_new_conversation_with_suggested_task():
@pytest.mark.asyncio
async def test_new_conversation_missing_settings():
async def test_new_conversation_missing_settings(provider_handler_mock):
"""Test creating a new conversation when settings are missing."""
with _patch_store():
# Mock the _create_new_conversation function to raise MissingSettingsError
@ -369,25 +373,15 @@ async def test_new_conversation_missing_settings():
'Settings not found'
)
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
selected_repository=test_repo,
repository='test/repo',
selected_branch='main',
initial_user_msg='Hello, agent!',
)
# Call new_conversation
response = await new_conversation(
data=test_request, user_id='test_user', provider_tokens={}
)
response = await create_new_test_conversation(test_request)
# Verify the response
assert isinstance(response, JSONResponse)
@ -409,17 +403,9 @@ async def test_new_conversation_invalid_api_key():
'Error authenticating with the LLM provider. Please check your API key'
)
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
selected_repository=test_repo,
repo='test/repo',
selected_branch='main',
initial_user_msg='Hello, agent!',
)
@ -496,3 +482,120 @@ async def test_delete_conversation():
mock_runtime_cls.delete.assert_called_once_with(
'some_conversation_id'
)
@pytest.mark.asyncio
async def test_new_conversation_with_bearer_auth(provider_handler_mock):
"""Test creating a new conversation with bearer authentication."""
with _patch_store():
# Mock the _create_new_conversation function
with patch(
'openhands.server.routes.manage_conversations._create_new_conversation'
) as mock_create_conversation:
# Set up the mock to return a conversation ID
mock_create_conversation.return_value = 'test_conversation_id'
# Create the request object
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI, # This should be overridden
repository='test/repo',
selected_branch='main',
initial_user_msg='Hello, agent!',
)
# Call new_conversation with auth_type=BEARER
response = await create_new_test_conversation(test_request, AuthType.BEARER)
# Verify the response
assert isinstance(response, JSONResponse)
assert response.status_code == 200
# Verify that _create_new_conversation was called with REMOTE_API_KEY trigger
mock_create_conversation.assert_called_once()
call_args = mock_create_conversation.call_args[1]
assert (
call_args['conversation_trigger'] == ConversationTrigger.REMOTE_API_KEY
)
@pytest.mark.asyncio
async def test_new_conversation_with_null_repository():
"""Test creating a new conversation with null repository."""
with _patch_store():
# Mock the _create_new_conversation function
with patch(
'openhands.server.routes.manage_conversations._create_new_conversation'
) as mock_create_conversation:
# Set up the mock to return a conversation ID
mock_create_conversation.return_value = 'test_conversation_id'
# Create the request object with null repository
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
repository=None, # Explicitly set to None
selected_branch=None,
initial_user_msg='Hello, agent!',
)
# Call new_conversation
response = await create_new_test_conversation(test_request)
# Verify the response
assert isinstance(response, JSONResponse)
assert response.status_code == 200
# Verify that _create_new_conversation was called with None repository
mock_create_conversation.assert_called_once()
call_args = mock_create_conversation.call_args[1]
assert call_args['selected_repository'] is None
@pytest.mark.asyncio
async def test_new_conversation_with_provider_authentication_error(
provider_handler_mock,
):
provider_handler_mock.verify_repo_provider = AsyncMock(
side_effect=AuthenticationError('auth error')
)
"""Test creating a new conversation when provider authentication fails."""
with _patch_store():
# Mock the _create_new_conversation function
with patch(
'openhands.server.routes.manage_conversations._create_new_conversation'
) as mock_create_conversation:
# Set up the mock to return a conversation ID
mock_create_conversation.return_value = 'test_conversation_id'
# Create the request object
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
repository='test/repo',
selected_branch='main',
initial_user_msg='Hello, agent!',
)
# Call new_conversation
response = await new_conversation(
data=test_request,
user_id='test_user',
provider_tokens={'github': 'token123'},
auth_type=None,
)
# Verify the response
assert isinstance(response, JSONResponse)
assert response.status_code == 400
assert json.loads(response.body.decode('utf-8')) == {
'status': 'error',
'message': 'auth error',
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR',
}
# Verify that verify_repo_provider was called with the repository
provider_handler_mock.verify_repo_provider.assert_called_once_with(
'test/repo', None
)
# Verify that _create_new_conversation was not called
mock_create_conversation.assert_not_called()

View File

@ -1,4 +1,5 @@
from types import MappingProxyType
from unittest.mock import MagicMock, patch
import pytest
from pydantic import SecretStr
@ -8,7 +9,8 @@ from openhands.events.action import Action
from openhands.events.action.commands import CmdRunAction
from openhands.events.observation import NullObservation, Observation
from openhands.events.stream import EventStream
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError, Repository
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store
@ -16,6 +18,13 @@ from openhands.storage import get_file_store
class TestRuntime(Runtime):
"""A concrete implementation of Runtime for testing"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.run_action_calls = []
self._execute_shell_fn_git_handler = MagicMock(
return_value=MagicMock(exit_code=0, stdout='', stderr='')
)
async def connect(self):
pass
@ -23,22 +32,22 @@ class TestRuntime(Runtime):
pass
def browse(self, action):
return NullObservation()
return NullObservation(content='')
def browse_interactive(self, action):
return NullObservation()
return NullObservation(content='')
def run(self, action):
return NullObservation()
return NullObservation(content='')
def run_ipython(self, action):
return NullObservation()
return NullObservation(content='')
def read(self, action):
return NullObservation()
return NullObservation(content='')
def write(self, action):
return NullObservation()
return NullObservation(content='')
def copy_from(self, path):
return ''
@ -50,10 +59,11 @@ class TestRuntime(Runtime):
return []
def run_action(self, action: Action) -> Observation:
return NullObservation()
self.run_action_calls.append(action)
return NullObservation(content='')
def call_tool_mcp(self, action):
return NullObservation()
return NullObservation(content='')
@pytest.fixture
@ -80,6 +90,20 @@ def runtime(temp_dir):
return runtime
def mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB, is_public=True):
repo = Repository(
id=123, full_name='owner/repo', git_provider=provider, is_public=is_public
)
async def mock_verify_repo_provider(*_args, **_kwargs):
return repo
monkeypatch.setattr(
ProviderHandler, 'verify_repo_provider', mock_verify_repo_provider
)
return repo
@pytest.mark.asyncio
async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
"""Test that no token export happens when user_id is not set"""
@ -183,3 +207,200 @@ async def test_export_latest_git_provider_tokens_token_update(runtime):
# Verify that the new token was exported
assert runtime.event_stream.secrets == {'github_token': new_token}
@pytest.mark.asyncio
async def test_clone_or_init_repo_no_repo_with_user_id(temp_dir):
"""Test that git init is run when no repository is selected and user_id is set"""
config = AppConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)
# Call the function with no repository
result = await runtime.clone_or_init_repo(None, None, None)
# Verify that git init was called
assert len(runtime.run_action_calls) == 1
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
assert runtime.run_action_calls[0].command == 'git init'
assert result == ''
@pytest.mark.asyncio
async def test_clone_or_init_repo_no_repo_no_user_id_no_workspace_base(temp_dir):
"""Test that git init is run when no repository is selected, no user_id, and no workspace_base"""
config = AppConfig()
config.workspace_base = None # Ensure workspace_base is not set
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)
# Call the function with no repository
result = await runtime.clone_or_init_repo(None, None, None)
# Verify that git init was called
assert len(runtime.run_action_calls) == 1
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
assert runtime.run_action_calls[0].command == 'git init'
assert result == ''
@pytest.mark.asyncio
async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_dir):
"""Test that git init is not run when no repository is selected, no user_id, but workspace_base is set"""
config = AppConfig()
config.workspace_base = '/some/path' # Set workspace_base
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)
# Call the function with no repository
result = await runtime.clone_or_init_repo(None, None, None)
# Verify that git init was not called
assert len(runtime.run_action_calls) == 0
assert result == ''
@pytest.mark.asyncio
async def test_clone_or_init_repo_auth_error(temp_dir):
"""Test that RuntimeError is raised when authentication fails"""
config = AppConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)
# Mock the verify_repo_provider method to raise AuthenticationError
with patch.object(
ProviderHandler,
'verify_repo_provider',
side_effect=AuthenticationError('Auth failed'),
):
# Call the function with a repository
with pytest.raises(RuntimeError) as excinfo:
await runtime.clone_or_init_repo(None, 'owner/repo', None)
# Verify the error message
assert 'Git provider authentication issue when cloning repo' in str(
excinfo.value
)
@pytest.mark.asyncio
async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch):
config = AppConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
github_token = 'github_test_token'
git_provider_tokens = MappingProxyType(
{ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))}
)
runtime = TestRuntime(
config=config,
event_stream=event_stream,
sid='test',
user_id='test_user',
git_provider_tokens=git_provider_tokens,
)
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB)
result = await runtime.clone_or_init_repo(git_provider_tokens, 'owner/repo', None)
cmd = runtime.run_action_calls[0].command
assert f'git clone https://{github_token}@github.com/owner/repo.git repo' in cmd
assert result == 'repo'
@pytest.mark.asyncio
async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch):
"""Test cloning a GitHub repository without a token"""
config = AppConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB)
result = await runtime.clone_or_init_repo(None, 'owner/repo', None)
# Verify that git clone was called with the public URL
assert len(runtime.run_action_calls) == 1
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
# Check that the command contains the correct URL format without token
cmd = runtime.run_action_calls[0].command
assert 'git clone https://github.com/owner/repo.git repo' in cmd
assert 'cd repo' in cmd
assert 'git checkout -b openhands-workspace-' in cmd
assert result == 'repo'
@pytest.mark.asyncio
async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch):
config = AppConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
gitlab_token = 'gitlab_test_token'
git_provider_tokens = MappingProxyType(
{ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))}
)
runtime = TestRuntime(
config=config,
event_stream=event_stream,
sid='test',
user_id='test_user',
git_provider_tokens=git_provider_tokens,
)
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITLAB)
result = await runtime.clone_or_init_repo(git_provider_tokens, 'owner/repo', None)
cmd = runtime.run_action_calls[0].command
assert (
f'git clone https://oauth2:{gitlab_token}@gitlab.com/owner/repo.git repo' in cmd
)
assert result == 'repo'
@pytest.mark.asyncio
async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch):
"""Test cloning a repository with a specified branch"""
config = AppConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB)
result = await runtime.clone_or_init_repo(None, 'owner/repo', 'feature-branch')
# Verify that git clone was called with the correct branch checkout
assert len(runtime.run_action_calls) == 1
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
# Check that the command contains the correct branch checkout
cmd = runtime.run_action_calls[0].command
assert 'git clone https://github.com/owner/repo.git repo' in cmd
assert 'cd repo' in cmd
assert 'git checkout feature-branch' in cmd
assert 'git checkout -b' not in cmd # Should not create a new branch
assert result == 'repo'