mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Fix]: Use str in place of Repository for repository param when creating new conversation (#8159)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
parent
7244e5df9f
commit
767d092f8f
15
docs/static/openapi.json
vendored
15
docs/static/openapi.json
vendored
@ -858,14 +858,15 @@
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"selected_repository": {
|
||||
"type": "object",
|
||||
"repository": {
|
||||
"type": "string",
|
||||
"nullable": true,
|
||||
"properties": {
|
||||
"full_name": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
"description": "Full name of the repository (e.g., owner/repo)"
|
||||
},
|
||||
"git_provider": {
|
||||
"type": "string",
|
||||
"nullable": true,
|
||||
"description": "The Git provider (e.g., github or gitlab). If omitted, all configured providers are checked for the repository."
|
||||
},
|
||||
"selected_branch": {
|
||||
"type": "string",
|
||||
|
||||
@ -49,6 +49,7 @@ describe("HomeHeader", () => {
|
||||
"gui",
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
[],
|
||||
undefined,
|
||||
undefined,
|
||||
|
||||
@ -165,12 +165,8 @@ describe("RepoConnector", () => {
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledExactlyOnceWith(
|
||||
"gui",
|
||||
{
|
||||
full_name: "rbren/polaris",
|
||||
git_provider: "github",
|
||||
id: 1,
|
||||
is_public: true,
|
||||
},
|
||||
"rbren/polaris",
|
||||
"github",
|
||||
undefined,
|
||||
[],
|
||||
undefined,
|
||||
|
||||
@ -89,7 +89,8 @@ describe("TaskCard", () => {
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledWith(
|
||||
"suggested_task",
|
||||
MOCK_RESPOSITORIES[0],
|
||||
MOCK_RESPOSITORIES[0].full_name,
|
||||
MOCK_RESPOSITORIES[0].git_provider,
|
||||
undefined,
|
||||
[],
|
||||
undefined,
|
||||
|
||||
@ -13,7 +13,7 @@ import {
|
||||
ConversationTrigger,
|
||||
} from "./open-hands.types";
|
||||
import { openHands } from "./open-hands-axios";
|
||||
import { ApiSettings, PostApiSettings } from "#/types/settings";
|
||||
import { ApiSettings, PostApiSettings, Provider } from "#/types/settings";
|
||||
import { GitUser, GitRepository } from "#/types/git";
|
||||
import { SuggestedTask } from "#/components/features/home/tasks/task.types";
|
||||
|
||||
@ -152,7 +152,8 @@ class OpenHands {
|
||||
|
||||
static async createConversation(
|
||||
conversation_trigger: ConversationTrigger = "gui",
|
||||
selectedRepository?: GitRepository,
|
||||
selectedRepository?: string,
|
||||
git_provider?: Provider,
|
||||
initialUserMsg?: string,
|
||||
imageUrls?: string[],
|
||||
replayJson?: string,
|
||||
@ -160,7 +161,8 @@ class OpenHands {
|
||||
): Promise<Conversation> {
|
||||
const body = {
|
||||
conversation_trigger,
|
||||
selected_repository: selectedRepository,
|
||||
repository: selectedRepository,
|
||||
git_provider,
|
||||
selected_branch: undefined,
|
||||
initial_user_msg: initialUserMsg,
|
||||
image_urls: imageUrls,
|
||||
|
||||
@ -24,13 +24,19 @@ export const useCreateConversation = () => {
|
||||
conversation_trigger: ConversationTrigger;
|
||||
q?: string;
|
||||
selectedRepository?: GitRepository | null;
|
||||
|
||||
suggested_task?: SuggestedTask;
|
||||
}) => {
|
||||
if (variables.q) dispatch(setInitialPrompt(variables.q));
|
||||
|
||||
return OpenHands.createConversation(
|
||||
variables.conversation_trigger,
|
||||
variables.selectedRepository || undefined,
|
||||
variables.selectedRepository
|
||||
? variables.selectedRepository.full_name
|
||||
: undefined,
|
||||
variables.selectedRepository
|
||||
? variables.selectedRepository.git_provider
|
||||
: undefined,
|
||||
variables.q,
|
||||
files,
|
||||
replayJson || undefined,
|
||||
|
||||
@ -85,40 +85,41 @@ def create_runtime(
|
||||
|
||||
|
||||
def initialize_repository_for_runtime(
|
||||
runtime: Runtime,
|
||||
selected_repository: str | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
runtime: Runtime, selected_repository: str | None = None
|
||||
) -> str | None:
|
||||
"""Initialize the repository for the runtime.
|
||||
|
||||
Args:
|
||||
runtime: The runtime to initialize the repository for.
|
||||
selected_repository: (optional) The GitHub repository to use.
|
||||
github_token: (optional) The GitHub token to use.
|
||||
|
||||
Returns:
|
||||
The repository directory path if a repository was cloned, None otherwise.
|
||||
"""
|
||||
# clone selected repository if provided
|
||||
if github_token is None and 'GITHUB_TOKEN' in os.environ:
|
||||
provider_tokens = {}
|
||||
if 'GITHUB_TOKEN' in os.environ:
|
||||
github_token = SecretStr(os.environ['GITHUB_TOKEN'])
|
||||
provider_tokens[ProviderType.GITHUB] = ProviderToken(
|
||||
token=SecretStr(github_token)
|
||||
)
|
||||
|
||||
if 'GITLAB_TOKEN' in os.environ:
|
||||
gitlab_token = SecretStr(os.environ['GITLAB_TOKEN'])
|
||||
provider_tokens[ProviderType.GITLAB] = ProviderToken(
|
||||
token=SecretStr(gitlab_token)
|
||||
)
|
||||
|
||||
secret_store = (
|
||||
SecretStore(
|
||||
provider_tokens={
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))
|
||||
}
|
||||
)
|
||||
if github_token
|
||||
else None
|
||||
SecretStore(provider_tokens=provider_tokens) if provider_tokens else None
|
||||
)
|
||||
provider_tokens = secret_store.provider_tokens if secret_store else None
|
||||
immutable_provider_tokens = secret_store.provider_tokens if secret_store else None
|
||||
|
||||
logger.debug(f'Selected repository {selected_repository}.')
|
||||
repo_directory = call_async_from_sync(
|
||||
runtime.clone_or_init_repo,
|
||||
GENERAL_TIMEOUT,
|
||||
provider_tokens,
|
||||
immutable_provider_tokens,
|
||||
selected_repository,
|
||||
None,
|
||||
)
|
||||
|
||||
@ -390,6 +390,18 @@ class GitHubService(BaseGitService, GitService):
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
|
||||
url = f'{self.BASE_URL}/repos/{repository}'
|
||||
repo, _ = await self._make_request(url)
|
||||
|
||||
return Repository(
|
||||
id=repo.get('id'),
|
||||
full_name=repo.get('full_name'),
|
||||
stargazers_count=repo.get('stargazers_count'),
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=not repo.get('private', True),
|
||||
)
|
||||
|
||||
|
||||
github_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITHUB_SERVICE_CLS',
|
||||
|
||||
@ -383,6 +383,23 @@ class GitLabService(BaseGitService, GitService):
|
||||
return []
|
||||
|
||||
|
||||
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
|
||||
encoded_name = repository.replace("/", "%2F")
|
||||
|
||||
url = f'{self.BASE_URL}/projects/{encoded_name}'
|
||||
repo, _ = await self._make_request(url)
|
||||
|
||||
return Repository(
|
||||
id=repo.get('id'),
|
||||
full_name=repo.get('path_with_namespace'),
|
||||
stargazers_count=repo.get('star_count'),
|
||||
git_provider=ProviderType.GITLAB,
|
||||
is_public=repo.get('visibility') == 'public',
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
'openhands.integrations.gitlab.gitlab_service.GitLabService',
|
||||
|
||||
@ -397,3 +397,22 @@ class ProviderHandler:
|
||||
Map ProviderType value to the environment variable name in the runtime
|
||||
"""
|
||||
return f'{provider.value}_token'.lower()
|
||||
|
||||
async def verify_repo_provider(
|
||||
self, repository: str, specified_provider: ProviderType | None = None
|
||||
):
|
||||
if specified_provider:
|
||||
try:
|
||||
service = self._get_service(specified_provider)
|
||||
return await service.get_repository_details_from_repo_name(repository)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self._get_service(provider)
|
||||
return await service.get_repository_details_from_repo_name(repository)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise AuthenticationError(f'Unable to access repo {repository}')
|
||||
|
||||
@ -206,3 +206,8 @@ class GitService(Protocol):
|
||||
async def get_suggested_tasks(self) -> list[SuggestedTask]:
|
||||
"""Get suggested tasks for the authenticated user across all repositories"""
|
||||
...
|
||||
|
||||
async def get_repository_details_from_repo_name(
|
||||
self, repository: str
|
||||
) -> Repository:
|
||||
"""Gets all repository details from repository name"""
|
||||
|
||||
@ -47,7 +47,7 @@ from openhands.integrations.provider import (
|
||||
ProviderHandler,
|
||||
ProviderType,
|
||||
)
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.microagent import (
|
||||
BaseMicroagent,
|
||||
load_microagents_from_dir,
|
||||
@ -311,10 +311,23 @@ class Runtime(FileEditRuntimeMixin):
|
||||
async def clone_or_init_repo(
|
||||
self,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
selected_repository: str | Repository | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
repository_provider: ProviderType = ProviderType.GITHUB,
|
||||
) -> str:
|
||||
repository = None
|
||||
if selected_repository: # Determine provider from repo name
|
||||
try:
|
||||
provider_handler = ProviderHandler(
|
||||
git_provider_tokens or MappingProxyType({})
|
||||
)
|
||||
repository = await provider_handler.verify_repo_provider(
|
||||
selected_repository
|
||||
)
|
||||
except AuthenticationError:
|
||||
raise RuntimeError(
|
||||
'Git provider authentication issue when cloning repo'
|
||||
)
|
||||
|
||||
if not selected_repository:
|
||||
# In SaaS mode (indicated by user_id being set), always run git init
|
||||
# In OSS mode, only run git init if workspace_base is not set
|
||||
@ -332,36 +345,30 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
return ''
|
||||
|
||||
# This satisfies mypy because param is optional, but `verify_repo_provider` guarentees this gets populated
|
||||
if not repository:
|
||||
return ''
|
||||
|
||||
provider = repository.git_provider
|
||||
provider_domains = {
|
||||
ProviderType.GITHUB: 'github.com',
|
||||
ProviderType.GITLAB: 'gitlab.com',
|
||||
}
|
||||
|
||||
chosen_provider = (
|
||||
repository_provider
|
||||
if isinstance(selected_repository, str)
|
||||
else selected_repository.git_provider
|
||||
)
|
||||
|
||||
domain = provider_domains[chosen_provider]
|
||||
repository = (
|
||||
selected_repository
|
||||
if isinstance(selected_repository, str)
|
||||
else selected_repository.full_name
|
||||
)
|
||||
domain = provider_domains[provider]
|
||||
|
||||
# Try to use token if available, otherwise use public URL
|
||||
if git_provider_tokens and chosen_provider in git_provider_tokens:
|
||||
git_token = git_provider_tokens[chosen_provider].token
|
||||
if git_provider_tokens and provider in git_provider_tokens:
|
||||
git_token = git_provider_tokens[provider].token
|
||||
if git_token:
|
||||
if chosen_provider == ProviderType.GITLAB:
|
||||
remote_repo_url = f'https://oauth2:{git_token.get_secret_value()}@{domain}/{repository}.git'
|
||||
if provider == ProviderType.GITLAB:
|
||||
remote_repo_url = f'https://oauth2:{git_token.get_secret_value()}@{domain}/{selected_repository}.git'
|
||||
else:
|
||||
remote_repo_url = f'https://{git_token.get_secret_value()}@{domain}/{repository}.git'
|
||||
remote_repo_url = f'https://{git_token.get_secret_value()}@{domain}/{selected_repository}.git'
|
||||
else:
|
||||
remote_repo_url = f'https://{domain}/{repository}.git'
|
||||
remote_repo_url = f'https://{domain}/{selected_repository}.git'
|
||||
else:
|
||||
remote_repo_url = f'https://{domain}/{repository}.git'
|
||||
remote_repo_url = f'https://{domain}/{selected_repository}.git'
|
||||
|
||||
if not remote_repo_url:
|
||||
raise ValueError('Missing either Git token or valid repository')
|
||||
@ -371,7 +378,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
'info', 'STATUS$SETTING_UP_WORKSPACE', 'Setting up workspace...'
|
||||
)
|
||||
|
||||
dir_name = repository.split('/')[-1]
|
||||
dir_name = selected_repository.split('/')[-1]
|
||||
|
||||
# Generate a random branch name to avoid conflicts
|
||||
random_str = ''.join(
|
||||
|
||||
@ -12,8 +12,9 @@ from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
)
|
||||
from openhands.integrations.service_types import Repository, SuggestedTask
|
||||
from openhands.integrations.service_types import AuthenticationError, ProviderType, Repository, SuggestedTask
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
@ -29,9 +30,11 @@ from openhands.server.shared import (
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth import (
|
||||
get_auth_type,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
from openhands.server.utils import get_conversation_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
@ -48,7 +51,8 @@ app = APIRouter(prefix='/api')
|
||||
|
||||
class InitSessionRequest(BaseModel):
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI
|
||||
selected_repository: Repository | None = None
|
||||
repository: str | None = None
|
||||
git_provider: ProviderType | None = None
|
||||
selected_branch: str | None = None
|
||||
initial_user_msg: str | None = None
|
||||
image_urls: list[str] | None = None
|
||||
@ -59,7 +63,7 @@ class InitSessionRequest(BaseModel):
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
selected_repository: Repository | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
initial_user_msg: str | None,
|
||||
image_urls: list[str] | None,
|
||||
@ -67,7 +71,7 @@ async def _create_new_conversation(
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_convo_id: bool = False,
|
||||
):
|
||||
print("trigger", conversation_trigger)
|
||||
|
||||
logger.info(
|
||||
'Creating conversation',
|
||||
extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value},
|
||||
@ -122,9 +126,7 @@ async def _create_new_conversation(
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
github_user_id=None,
|
||||
selected_repository=selected_repository.full_name
|
||||
if selected_repository
|
||||
else selected_repository,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
)
|
||||
@ -161,6 +163,7 @@ async def new_conversation(
|
||||
data: InitSessionRequest,
|
||||
user_id: str = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
auth_type: AuthType | None = Depends(get_auth_type)
|
||||
):
|
||||
"""Initialize a new session or join an existing one.
|
||||
|
||||
@ -168,24 +171,33 @@ async def new_conversation(
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
selected_repository = data.selected_repository
|
||||
repository = data.repository
|
||||
selected_branch = data.selected_branch
|
||||
initial_user_msg = data.initial_user_msg
|
||||
image_urls = data.image_urls or []
|
||||
replay_json = data.replay_json
|
||||
suggested_task = data.suggested_task
|
||||
conversation_trigger = data.conversation_trigger
|
||||
git_provider = data.git_provider
|
||||
|
||||
if suggested_task:
|
||||
initial_user_msg = suggested_task.get_prompt_for_task()
|
||||
conversation_trigger = ConversationTrigger.SUGGESTED_TASK
|
||||
|
||||
if auth_type == AuthType.BEARER:
|
||||
conversation_trigger = ConversationTrigger.REMOTE_API_KEY
|
||||
|
||||
try:
|
||||
if repository:
|
||||
provider_handler = ProviderHandler(provider_tokens)
|
||||
# Check against git_provider, otherwise check all provider apis
|
||||
await provider_handler.verify_repo_provider(repository, git_provider)
|
||||
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
user_id=user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=selected_repository,
|
||||
selected_repository=repository,
|
||||
selected_branch=selected_branch,
|
||||
initial_user_msg=initial_user_msg,
|
||||
image_urls=image_urls,
|
||||
@ -215,6 +227,16 @@ async def new_conversation(
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': str(e),
|
||||
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR'
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/conversations')
|
||||
|
||||
@ -86,7 +86,7 @@ class AgentSession:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
selected_repository: Repository | None = None,
|
||||
selected_repository: str | None = None,
|
||||
selected_branch: str | None = None,
|
||||
initial_message: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
@ -153,7 +153,7 @@ class AgentSession:
|
||||
|
||||
repo_directory = None
|
||||
if self.runtime and runtime_connected and selected_repository:
|
||||
repo_directory = selected_repository.full_name.split('/')[-1]
|
||||
repo_directory = selected_repository.split('/')[-1]
|
||||
|
||||
self.memory = await self._create_memory(
|
||||
selected_repository=selected_repository,
|
||||
@ -265,7 +265,7 @@ class AgentSession:
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
selected_repository: Repository | None = None,
|
||||
selected_repository: str | None = None,
|
||||
selected_branch: str | None = None,
|
||||
) -> bool:
|
||||
"""Creates a runtime instance
|
||||
@ -400,7 +400,7 @@ class AgentSession:
|
||||
return controller
|
||||
|
||||
async def _create_memory(
|
||||
self, selected_repository: Repository | None, repo_directory: str | None
|
||||
self, selected_repository: str | None, repo_directory: str | None
|
||||
) -> Memory:
|
||||
memory = Memory(
|
||||
event_stream=self.event_stream,
|
||||
@ -415,13 +415,13 @@ class AgentSession:
|
||||
# loads microagents from repo/.openhands/microagents
|
||||
microagents: list[BaseMicroagent] = await call_sync_from_async(
|
||||
self.runtime.get_microagents_from_selected_repo,
|
||||
selected_repository.full_name if selected_repository else None,
|
||||
selected_repository or None,
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
|
||||
if selected_repository and repo_directory:
|
||||
memory.set_repository_info(
|
||||
selected_repository.full_name, repo_directory
|
||||
selected_repository, repo_directory
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@ -11,7 +10,7 @@ class ConversationInitData(Settings):
|
||||
"""
|
||||
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = Field(default=None, frozen=True)
|
||||
selected_repository: Repository | None = Field(default=None)
|
||||
selected_repository: str | None = Field(default=None)
|
||||
replay_json: str | None = Field(default=None)
|
||||
selected_branch: str | None = Field(default=None)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from pydantic import SecretStr
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
from openhands.server.user_auth.user_auth import AuthType, get_user_auth
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
@ -46,3 +46,8 @@ async def get_user_settings_store(request: Request) -> SettingsStore | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
user_settings_store = await user_auth.get_user_settings_store()
|
||||
return user_settings_store
|
||||
|
||||
|
||||
async def get_auth_type(request: Request) -> AuthType | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
return user_auth.get_auth_type()
|
||||
@ -51,6 +51,7 @@ class DefaultUserAuth(UserAuth):
|
||||
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
|
||||
return provider_tokens
|
||||
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
user_auth = DefaultUserAuth()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
@ -12,6 +13,11 @@ from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class AuthType(Enum):
|
||||
COOKIE = "cookie"
|
||||
BEARER = "bearer"
|
||||
|
||||
|
||||
class UserAuth(ABC):
|
||||
"""Extensible class encapsulating user Authentication"""
|
||||
|
||||
@ -45,6 +51,9 @@ class UserAuth(ABC):
|
||||
self._settings = settings
|
||||
return settings
|
||||
|
||||
def get_auth_type(self) -> AuthType | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
|
||||
@ -7,6 +7,7 @@ class ConversationTrigger(Enum):
|
||||
RESOLVER = 'resolver'
|
||||
GUI = 'gui'
|
||||
SUGGESTED_TASK = 'suggested_task'
|
||||
REMOTE_API_KEY = 'openhands_api'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Tests for the setup script."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from conftest import (
|
||||
_load_runtime,
|
||||
)
|
||||
@ -7,14 +9,27 @@ from conftest import (
|
||||
from openhands.core.setup import initialize_repository_for_runtime
|
||||
from openhands.events.action import FileReadAction, FileWriteAction
|
||||
from openhands.events.observation import FileReadObservation, FileWriteObservation
|
||||
from openhands.integrations.service_types import ProviderType, Repository
|
||||
|
||||
|
||||
def test_initialize_repository_for_runtime(temp_dir, runtime_cls, run_as_openhands):
|
||||
"""Test that the initialize_repository_for_runtime function works."""
|
||||
runtime, config = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
|
||||
repository_dir = initialize_repository_for_runtime(
|
||||
runtime, 'https://github.com/All-Hands-AI/OpenHands'
|
||||
mock_repo = Repository(
|
||||
id=1232,
|
||||
full_name='All-Hands-AI/OpenHands',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
with patch(
|
||||
'openhands.runtime.base.ProviderHandler.verify_repo_provider',
|
||||
return_value=mock_repo,
|
||||
):
|
||||
repository_dir = initialize_repository_for_runtime(
|
||||
runtime, selected_repository='All-Hands-AI/OpenHands'
|
||||
)
|
||||
|
||||
assert repository_dir is not None
|
||||
assert repository_dir == 'OpenHands'
|
||||
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
ProviderType,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
TaskType,
|
||||
)
|
||||
@ -25,6 +26,7 @@ from openhands.server.routes.manage_conversations import (
|
||||
update_conversation,
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
@ -62,6 +64,28 @@ def _patch_store():
|
||||
yield
|
||||
|
||||
|
||||
def create_new_test_conversation(
|
||||
test_request: InitSessionRequest, auth_type: AuthType | None = None
|
||||
):
|
||||
return new_conversation(
|
||||
data=test_request,
|
||||
user_id='test_user',
|
||||
provider_tokens=MappingProxyType({'github': 'token123'}),
|
||||
auth_type=auth_type,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_handler_mock():
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.ProviderHandler'
|
||||
) as mock_cls:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.verify_repo_provider = AsyncMock(return_value=ProviderType.GITHUB)
|
||||
mock_cls.return_value = mock_instance
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_conversations():
|
||||
with _patch_store():
|
||||
@ -232,7 +256,7 @@ async def test_update_conversation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_success():
|
||||
async def test_new_conversation_success(provider_handler_mock):
|
||||
"""Test successful creation of a new conversation."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function directly
|
||||
@ -242,26 +266,16 @@ async def test_new_conversation_success():
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
selected_repository=test_repo,
|
||||
repository='test/repo',
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
image_urls=['https://example.com/image.jpg'],
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request, user_id='test_user', provider_tokens={}
|
||||
)
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
@ -275,7 +289,7 @@ async def test_new_conversation_success():
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['user_id'] == 'test_user'
|
||||
assert call_args['selected_repository'] == test_repo
|
||||
assert call_args['selected_repository'] == 'test/repo'
|
||||
assert call_args['selected_branch'] == 'main'
|
||||
assert call_args['initial_user_msg'] == 'Hello, agent!'
|
||||
assert call_args['image_urls'] == ['https://example.com/image.jpg']
|
||||
@ -283,7 +297,7 @@ async def test_new_conversation_success():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_with_suggested_task():
|
||||
async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
"""Test creating a new conversation with a suggested task."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function directly
|
||||
@ -301,14 +315,6 @@ async def test_new_conversation_with_suggested_task():
|
||||
'Please fix the failing checks in PR #123'
|
||||
)
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_task = SuggestedTask(
|
||||
git_provider=ProviderType.GITHUB,
|
||||
task_type=TaskType.FAILING_CHECKS,
|
||||
@ -319,15 +325,13 @@ async def test_new_conversation_with_suggested_task():
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.SUGGESTED_TASK,
|
||||
selected_repository=test_repo,
|
||||
repository='test/repo',
|
||||
selected_branch='main',
|
||||
suggested_task=test_task,
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request, user_id='test_user', provider_tokens={}
|
||||
)
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
@ -341,7 +345,7 @@ async def test_new_conversation_with_suggested_task():
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['user_id'] == 'test_user'
|
||||
assert call_args['selected_repository'] == test_repo
|
||||
assert call_args['selected_repository'] == 'test/repo'
|
||||
assert call_args['selected_branch'] == 'main'
|
||||
assert (
|
||||
call_args['initial_user_msg']
|
||||
@ -357,7 +361,7 @@ async def test_new_conversation_with_suggested_task():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_missing_settings():
|
||||
async def test_new_conversation_missing_settings(provider_handler_mock):
|
||||
"""Test creating a new conversation when settings are missing."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function to raise MissingSettingsError
|
||||
@ -369,25 +373,15 @@ async def test_new_conversation_missing_settings():
|
||||
'Settings not found'
|
||||
)
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
selected_repository=test_repo,
|
||||
repository='test/repo',
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request, user_id='test_user', provider_tokens={}
|
||||
)
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
@ -409,17 +403,9 @@ async def test_new_conversation_invalid_api_key():
|
||||
'Error authenticating with the LLM provider. Please check your API key'
|
||||
)
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
selected_repository=test_repo,
|
||||
repo='test/repo',
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
)
|
||||
@ -496,3 +482,120 @@ async def test_delete_conversation():
|
||||
mock_runtime_cls.delete.assert_called_once_with(
|
||||
'some_conversation_id'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
"""Test creating a new conversation with bearer authentication."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
|
||||
# Create the request object
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI, # This should be overridden
|
||||
repository='test/repo',
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
)
|
||||
|
||||
# Call new_conversation with auth_type=BEARER
|
||||
response = await create_new_test_conversation(test_request, AuthType.BEARER)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that _create_new_conversation was called with REMOTE_API_KEY trigger
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert (
|
||||
call_args['conversation_trigger'] == ConversationTrigger.REMOTE_API_KEY
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_with_null_repository():
|
||||
"""Test creating a new conversation with null repository."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
|
||||
# Create the request object with null repository
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
repository=None, # Explicitly set to None
|
||||
selected_branch=None,
|
||||
initial_user_msg='Hello, agent!',
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that _create_new_conversation was called with None repository
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['selected_repository'] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_with_provider_authentication_error(
|
||||
provider_handler_mock,
|
||||
):
|
||||
provider_handler_mock.verify_repo_provider = AsyncMock(
|
||||
side_effect=AuthenticationError('auth error')
|
||||
)
|
||||
|
||||
"""Test creating a new conversation when provider authentication fails."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
|
||||
# Create the request object
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
repository='test/repo',
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request,
|
||||
user_id='test_user',
|
||||
provider_tokens={'github': 'token123'},
|
||||
auth_type=None,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 400
|
||||
assert json.loads(response.body.decode('utf-8')) == {
|
||||
'status': 'error',
|
||||
'message': 'auth error',
|
||||
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR',
|
||||
}
|
||||
|
||||
# Verify that verify_repo_provider was called with the repository
|
||||
provider_handler_mock.verify_repo_provider.assert_called_once_with(
|
||||
'test/repo', None
|
||||
)
|
||||
|
||||
# Verify that _create_new_conversation was not called
|
||||
mock_create_conversation.assert_not_called()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
@ -8,7 +9,8 @@ from openhands.events.action import Action
|
||||
from openhands.events.action.commands import CmdRunAction
|
||||
from openhands.events.observation import NullObservation, Observation
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError, Repository
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
@ -16,6 +18,13 @@ from openhands.storage import get_file_store
|
||||
class TestRuntime(Runtime):
|
||||
"""A concrete implementation of Runtime for testing"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.run_action_calls = []
|
||||
self._execute_shell_fn_git_handler = MagicMock(
|
||||
return_value=MagicMock(exit_code=0, stdout='', stderr='')
|
||||
)
|
||||
|
||||
async def connect(self):
|
||||
pass
|
||||
|
||||
@ -23,22 +32,22 @@ class TestRuntime(Runtime):
|
||||
pass
|
||||
|
||||
def browse(self, action):
|
||||
return NullObservation()
|
||||
return NullObservation(content='')
|
||||
|
||||
def browse_interactive(self, action):
|
||||
return NullObservation()
|
||||
return NullObservation(content='')
|
||||
|
||||
def run(self, action):
|
||||
return NullObservation()
|
||||
return NullObservation(content='')
|
||||
|
||||
def run_ipython(self, action):
|
||||
return NullObservation()
|
||||
return NullObservation(content='')
|
||||
|
||||
def read(self, action):
|
||||
return NullObservation()
|
||||
return NullObservation(content='')
|
||||
|
||||
def write(self, action):
|
||||
return NullObservation()
|
||||
return NullObservation(content='')
|
||||
|
||||
def copy_from(self, path):
|
||||
return ''
|
||||
@ -50,10 +59,11 @@ class TestRuntime(Runtime):
|
||||
return []
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
return NullObservation()
|
||||
self.run_action_calls.append(action)
|
||||
return NullObservation(content='')
|
||||
|
||||
def call_tool_mcp(self, action):
|
||||
return NullObservation()
|
||||
return NullObservation(content='')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -80,6 +90,20 @@ def runtime(temp_dir):
|
||||
return runtime
|
||||
|
||||
|
||||
def mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB, is_public=True):
|
||||
repo = Repository(
|
||||
id=123, full_name='owner/repo', git_provider=provider, is_public=is_public
|
||||
)
|
||||
|
||||
async def mock_verify_repo_provider(*_args, **_kwargs):
|
||||
return repo
|
||||
|
||||
monkeypatch.setattr(
|
||||
ProviderHandler, 'verify_repo_provider', mock_verify_repo_provider
|
||||
)
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
|
||||
"""Test that no token export happens when user_id is not set"""
|
||||
@ -183,3 +207,200 @@ async def test_export_latest_git_provider_tokens_token_update(runtime):
|
||||
|
||||
# Verify that the new token was exported
|
||||
assert runtime.event_stream.secrets == {'github_token': new_token}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_no_repo_with_user_id(temp_dir):
|
||||
"""Test that git init is run when no repository is selected and user_id is set"""
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(
|
||||
config=config, event_stream=event_stream, sid='test', user_id='test_user'
|
||||
)
|
||||
|
||||
# Call the function with no repository
|
||||
result = await runtime.clone_or_init_repo(None, None, None)
|
||||
|
||||
# Verify that git init was called
|
||||
assert len(runtime.run_action_calls) == 1
|
||||
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
|
||||
assert runtime.run_action_calls[0].command == 'git init'
|
||||
assert result == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_no_repo_no_user_id_no_workspace_base(temp_dir):
|
||||
"""Test that git init is run when no repository is selected, no user_id, and no workspace_base"""
|
||||
config = AppConfig()
|
||||
config.workspace_base = None # Ensure workspace_base is not set
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(
|
||||
config=config, event_stream=event_stream, sid='test', user_id=None
|
||||
)
|
||||
|
||||
# Call the function with no repository
|
||||
result = await runtime.clone_or_init_repo(None, None, None)
|
||||
|
||||
# Verify that git init was called
|
||||
assert len(runtime.run_action_calls) == 1
|
||||
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
|
||||
assert runtime.run_action_calls[0].command == 'git init'
|
||||
assert result == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_dir):
|
||||
"""Test that git init is not run when no repository is selected, no user_id, but workspace_base is set"""
|
||||
config = AppConfig()
|
||||
config.workspace_base = '/some/path' # Set workspace_base
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(
|
||||
config=config, event_stream=event_stream, sid='test', user_id=None
|
||||
)
|
||||
|
||||
# Call the function with no repository
|
||||
result = await runtime.clone_or_init_repo(None, None, None)
|
||||
|
||||
# Verify that git init was not called
|
||||
assert len(runtime.run_action_calls) == 0
|
||||
assert result == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_auth_error(temp_dir):
|
||||
"""Test that RuntimeError is raised when authentication fails"""
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(
|
||||
config=config, event_stream=event_stream, sid='test', user_id='test_user'
|
||||
)
|
||||
|
||||
# Mock the verify_repo_provider method to raise AuthenticationError
|
||||
with patch.object(
|
||||
ProviderHandler,
|
||||
'verify_repo_provider',
|
||||
side_effect=AuthenticationError('Auth failed'),
|
||||
):
|
||||
# Call the function with a repository
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
await runtime.clone_or_init_repo(None, 'owner/repo', None)
|
||||
|
||||
# Verify the error message
|
||||
assert 'Git provider authentication issue when cloning repo' in str(
|
||||
excinfo.value
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch):
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
github_token = 'github_test_token'
|
||||
git_provider_tokens = MappingProxyType(
|
||||
{ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))}
|
||||
)
|
||||
|
||||
runtime = TestRuntime(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid='test',
|
||||
user_id='test_user',
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
)
|
||||
|
||||
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB)
|
||||
|
||||
result = await runtime.clone_or_init_repo(git_provider_tokens, 'owner/repo', None)
|
||||
|
||||
cmd = runtime.run_action_calls[0].command
|
||||
assert f'git clone https://{github_token}@github.com/owner/repo.git repo' in cmd
|
||||
assert result == 'repo'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch):
|
||||
"""Test cloning a GitHub repository without a token"""
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
runtime = TestRuntime(
|
||||
config=config, event_stream=event_stream, sid='test', user_id='test_user'
|
||||
)
|
||||
|
||||
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB)
|
||||
result = await runtime.clone_or_init_repo(None, 'owner/repo', None)
|
||||
|
||||
# Verify that git clone was called with the public URL
|
||||
assert len(runtime.run_action_calls) == 1
|
||||
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
|
||||
|
||||
# Check that the command contains the correct URL format without token
|
||||
cmd = runtime.run_action_calls[0].command
|
||||
assert 'git clone https://github.com/owner/repo.git repo' in cmd
|
||||
assert 'cd repo' in cmd
|
||||
assert 'git checkout -b openhands-workspace-' in cmd
|
||||
assert result == 'repo'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch):
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
gitlab_token = 'gitlab_test_token'
|
||||
git_provider_tokens = MappingProxyType(
|
||||
{ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))}
|
||||
)
|
||||
|
||||
runtime = TestRuntime(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid='test',
|
||||
user_id='test_user',
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
)
|
||||
|
||||
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITLAB)
|
||||
|
||||
result = await runtime.clone_or_init_repo(git_provider_tokens, 'owner/repo', None)
|
||||
|
||||
cmd = runtime.run_action_calls[0].command
|
||||
assert (
|
||||
f'git clone https://oauth2:{gitlab_token}@gitlab.com/owner/repo.git repo' in cmd
|
||||
)
|
||||
assert result == 'repo'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch):
|
||||
"""Test cloning a repository with a specified branch"""
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
runtime = TestRuntime(
|
||||
config=config, event_stream=event_stream, sid='test', user_id='test_user'
|
||||
)
|
||||
|
||||
mock_repo_and_patch(monkeypatch, provider=ProviderType.GITHUB)
|
||||
result = await runtime.clone_or_init_repo(None, 'owner/repo', 'feature-branch')
|
||||
|
||||
# Verify that git clone was called with the correct branch checkout
|
||||
assert len(runtime.run_action_calls) == 1
|
||||
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
|
||||
|
||||
# Check that the command contains the correct branch checkout
|
||||
cmd = runtime.run_action_calls[0].command
|
||||
assert 'git clone https://github.com/owner/repo.git repo' in cmd
|
||||
assert 'cd repo' in cmd
|
||||
assert 'git checkout feature-branch' in cmd
|
||||
assert 'git checkout -b' not in cmd # Should not create a new branch
|
||||
assert result == 'repo'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user