Add branch picker to homepage (#8259)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Robert Brennan
2025-05-05 11:06:22 -04:00
committed by GitHub
parent 0aec96baec
commit 3e224faea6
24 changed files with 494 additions and 62 deletions

View File

@@ -53,6 +53,7 @@ describe("HomeHeader", () => {
[],
undefined,
undefined,
undefined,
);
// expect to be redirected to /conversations/:conversationId

View File

@@ -171,6 +171,7 @@ describe("RepoConnector", () => {
[],
undefined,
undefined,
undefined,
);
});

View File

@@ -95,6 +95,7 @@ describe("TaskCard", () => {
[],
undefined,
MOCK_TASK_1,
undefined,
);
});
});

View File

@@ -14,7 +14,7 @@ import {
} from "./open-hands.types";
import { openHands } from "./open-hands-axios";
import { ApiSettings, PostApiSettings, Provider } from "#/types/settings";
import { GitUser, GitRepository } from "#/types/git";
import { GitUser, GitRepository, Branch } from "#/types/git";
import { SuggestedTask } from "#/components/features/home/tasks/task.types";
class OpenHands {
@@ -158,12 +158,13 @@ class OpenHands {
imageUrls?: string[],
replayJson?: string,
suggested_task?: SuggestedTask,
selected_branch?: string,
): Promise<Conversation> {
const body = {
conversation_trigger,
repository: selectedRepository,
git_provider,
selected_branch: undefined,
selected_branch,
initial_user_msg: initialUserMsg,
image_urls: imageUrls,
replay_json: replayJson,
@@ -316,6 +317,14 @@ class OpenHands {
return data;
}
static async getRepositoryBranches(repository: string): Promise<Branch[]> {
const { data } = await openHands.get<Branch[]>(
`/api/user/repository/branches?repository=${encodeURIComponent(repository)}`,
);
return data;
}
}
export default OpenHands;

View File

@@ -4,6 +4,7 @@ import { RepositorySelectionForm } from "./repo-selection-form";
// Create mock functions
const mockUseUserRepositories = vi.fn();
const mockUseRepositoryBranches = vi.fn();
const mockUseCreateConversation = vi.fn();
const mockUseIsCreatingConversation = vi.fn();
const mockUseTranslation = vi.fn();
@@ -16,6 +17,12 @@ mockUseUserRepositories.mockReturnValue({
isError: false,
});
mockUseRepositoryBranches.mockReturnValue({
data: [],
isLoading: false,
isError: false,
});
mockUseCreateConversation.mockReturnValue({
mutate: vi.fn(),
isPending: false,
@@ -47,6 +54,10 @@ vi.mock("#/hooks/query/use-user-repositories", () => ({
useUserRepositories: () => mockUseUserRepositories(),
}));
vi.mock("#/hooks/query/use-repository-branches", () => ({
useRepositoryBranches: () => mockUseRepositoryBranches(),
}));
vi.mock("#/hooks/mutation/use-create-conversation", () => ({
useCreateConversation: () => mockUseCreateConversation(),
}));

View File

@@ -1,79 +1,42 @@
import React from "react";
import { useTranslation } from "react-i18next";
import { Spinner } from "@heroui/react";
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
import { useUserRepositories } from "#/hooks/query/use-user-repositories";
import { useRepositoryBranches } from "#/hooks/query/use-repository-branches";
import { useIsCreatingConversation } from "#/hooks/use-is-creating-conversation";
import { GitRepository } from "#/types/git";
import { Branch, GitRepository } from "#/types/git";
import { BrandButton } from "../settings/brand-button";
import { SettingsDropdownInput } from "../settings/settings-dropdown-input";
import {
RepositoryDropdown,
RepositoryLoadingState,
RepositoryErrorState,
BranchDropdown,
BranchLoadingState,
BranchErrorState,
} from "./repository-selection";
interface RepositorySelectionFormProps {
onRepoSelection: (repoTitle: string | null) => void;
}
// Loading state component
function RepositoryLoadingState() {
const { t } = useTranslation();
return (
<div
data-testid="repo-dropdown-loading"
className="flex items-center gap-2 max-w-[500px] h-10 px-3 bg-tertiary border border-[#717888] rounded"
>
<Spinner size="sm" />
<span className="text-sm">{t("HOME$LOADING_REPOSITORIES")}</span>
</div>
);
}
// Error state component
function RepositoryErrorState() {
const { t } = useTranslation();
return (
<div
data-testid="repo-dropdown-error"
className="flex items-center gap-2 max-w-[500px] h-10 px-3 bg-tertiary border border-[#717888] rounded text-red-500"
>
<span className="text-sm">{t("HOME$FAILED_TO_LOAD_REPOSITORIES")}</span>
</div>
);
}
// Repository dropdown component
interface RepositoryDropdownProps {
items: { key: React.Key; label: string }[];
onSelectionChange: (key: React.Key | null) => void;
onInputChange: (value: string) => void;
}
function RepositoryDropdown({
items,
onSelectionChange,
onInputChange,
}: RepositoryDropdownProps) {
return (
<SettingsDropdownInput
testId="repo-dropdown"
name="repo-dropdown"
placeholder="Select a repo"
items={items}
wrapperClassName="max-w-[500px]"
onSelectionChange={onSelectionChange}
onInputChange={onInputChange}
/>
);
}
export function RepositorySelectionForm({
onRepoSelection,
}: RepositorySelectionFormProps) {
const [selectedRepository, setSelectedRepository] =
React.useState<GitRepository | null>(null);
const [selectedBranch, setSelectedBranch] = React.useState<Branch | null>(
null,
);
const {
data: repositories,
isLoading: isLoadingRepositories,
isError: isRepositoriesError,
} = useUserRepositories();
const {
data: branches,
isLoading: isLoadingBranches,
isError: isBranchesError,
} = useRepositoryBranches(selectedRepository?.full_name || null);
const {
mutate: createConversation,
isPending,
@@ -82,6 +45,27 @@ export function RepositorySelectionForm({
const isCreatingConversationElsewhere = useIsCreatingConversation();
const { t } = useTranslation();
// Auto-select main or master branch if it exists
React.useEffect(() => {
if (
branches &&
branches.length > 0 &&
!selectedBranch &&
!isLoadingBranches
) {
// Look for main or master branch
const mainBranch = branches.find((branch) => branch.name === "main");
const masterBranch = branches.find((branch) => branch.name === "master");
// Select main if it exists, otherwise select master if it exists
if (mainBranch) {
setSelectedBranch(mainBranch);
} else if (masterBranch) {
setSelectedBranch(masterBranch);
}
}
}, [branches, selectedBranch, isLoadingBranches]);
// We check for isSuccess because the app might require time to render
// into the new conversation screen after the conversation is created.
const isCreatingConversation =
@@ -92,6 +76,11 @@ export function RepositorySelectionForm({
label: repo.full_name,
}));
const branchesItems = branches?.map((branch) => ({
key: branch.name,
label: branch.name,
}));
const handleRepoSelection = (key: React.Key | null) => {
const selectedRepo = repositories?.find(
(repo) => repo.id.toString() === key,
@@ -99,15 +88,28 @@ export function RepositorySelectionForm({
if (selectedRepo) onRepoSelection(selectedRepo.full_name);
setSelectedRepository(selectedRepo || null);
setSelectedBranch(null); // Reset branch selection when repo changes
};
const handleInputChange = (value: string) => {
const handleBranchSelection = (key: React.Key | null) => {
const selectedBranchObj = branches?.find((branch) => branch.name === key);
setSelectedBranch(selectedBranchObj || null);
};
const handleRepoInputChange = (value: string) => {
if (value === "") {
setSelectedRepository(null);
setSelectedBranch(null);
onRepoSelection(null);
}
};
const handleBranchInputChange = (value: string) => {
if (value === "") {
setSelectedBranch(null);
}
};
// Render the appropriate UI based on the loading/error state
const renderRepositorySelector = () => {
if (isLoadingRepositories) {
@@ -122,15 +124,49 @@ export function RepositorySelectionForm({
<RepositoryDropdown
items={repositoriesItems || []}
onSelectionChange={handleRepoSelection}
onInputChange={handleInputChange}
onInputChange={handleRepoInputChange}
/>
);
};
// Render the appropriate UI for branch selector based on the loading/error state
const renderBranchSelector = () => {
if (!selectedRepository) {
return (
<BranchDropdown
items={[]}
onSelectionChange={() => {}}
onInputChange={() => {}}
isDisabled
/>
);
}
if (isLoadingBranches) {
return <BranchLoadingState />;
}
if (isBranchesError) {
return <BranchErrorState />;
}
return (
<BranchDropdown
items={branchesItems || []}
onSelectionChange={handleBranchSelection}
onInputChange={handleBranchInputChange}
isDisabled={false}
selectedKey={selectedBranch?.name}
/>
);
};
return (
<>
<div className="flex flex-col gap-4">
{renderRepositorySelector()}
{renderBranchSelector()}
<BrandButton
testId="repo-launch-button"
variant="primary"
@@ -145,12 +181,13 @@ export function RepositorySelectionForm({
createConversation({
selectedRepository,
conversation_trigger: "gui",
selected_branch: selectedBranch?.name,
})
}
>
{!isCreatingConversation && "Launch"}
{isCreatingConversation && t("HOME$LOADING")}
</BrandButton>
</>
</div>
);
}

View File

@@ -0,0 +1,32 @@
import React from "react";
import { SettingsDropdownInput } from "../../settings/settings-dropdown-input";
export interface BranchDropdownProps {
items: { key: React.Key; label: string }[];
onSelectionChange: (key: React.Key | null) => void;
onInputChange: (value: string) => void;
isDisabled: boolean;
selectedKey?: string;
}
export function BranchDropdown({
items,
onSelectionChange,
onInputChange,
isDisabled,
selectedKey,
}: BranchDropdownProps) {
return (
<SettingsDropdownInput
testId="branch-dropdown"
name="branch-dropdown"
placeholder="Select a branch"
items={items}
wrapperClassName="max-w-[500px]"
onSelectionChange={onSelectionChange}
onInputChange={onInputChange}
isDisabled={isDisabled}
selectedKey={selectedKey}
/>
);
}

View File

@@ -0,0 +1,14 @@
import React from "react";
import { useTranslation } from "react-i18next";
export function BranchErrorState() {
const { t } = useTranslation();
return (
<div
data-testid="branch-dropdown-error"
className="flex items-center gap-2 max-w-[500px] h-10 px-3 bg-tertiary border border-[#717888] rounded text-red-500"
>
<span className="text-sm">{t("HOME$FAILED_TO_LOAD_BRANCHES")}</span>
</div>
);
}

View File

@@ -0,0 +1,16 @@
import React from "react";
import { useTranslation } from "react-i18next";
import { Spinner } from "@heroui/react";
export function BranchLoadingState() {
const { t } = useTranslation();
return (
<div
data-testid="branch-dropdown-loading"
className="flex items-center gap-2 max-w-[500px] h-10 px-3 bg-tertiary border border-[#717888] rounded"
>
<Spinner size="sm" />
<span className="text-sm">{t("HOME$LOADING_BRANCHES")}</span>
</div>
);
}

View File

@@ -0,0 +1,6 @@
export { RepositoryDropdown } from "#/components/features/home/repository-selection/repository-dropdown";
export { RepositoryLoadingState } from "#/components/features/home/repository-selection/repository-loading-state";
export { RepositoryErrorState } from "#/components/features/home/repository-selection/repository-error-state";
export { BranchDropdown } from "#/components/features/home/repository-selection/branch-dropdown";
export { BranchLoadingState } from "#/components/features/home/repository-selection/branch-loading-state";
export { BranchErrorState } from "#/components/features/home/repository-selection/branch-error-state";

View File

@@ -0,0 +1,26 @@
import React from "react";
import { SettingsDropdownInput } from "../../settings/settings-dropdown-input";
export interface RepositoryDropdownProps {
items: { key: React.Key; label: string }[];
onSelectionChange: (key: React.Key | null) => void;
onInputChange: (value: string) => void;
}
export function RepositoryDropdown({
items,
onSelectionChange,
onInputChange,
}: RepositoryDropdownProps) {
return (
<SettingsDropdownInput
testId="repo-dropdown"
name="repo-dropdown"
placeholder="Select a repo"
items={items}
wrapperClassName="max-w-[500px]"
onSelectionChange={onSelectionChange}
onInputChange={onInputChange}
/>
);
}

View File

@@ -0,0 +1,14 @@
import React from "react";
import { useTranslation } from "react-i18next";
export function RepositoryErrorState() {
const { t } = useTranslation();
return (
<div
data-testid="repo-dropdown-error"
className="flex items-center gap-2 max-w-[500px] h-10 px-3 bg-tertiary border border-[#717888] rounded text-red-500"
>
<span className="text-sm">{t("HOME$FAILED_TO_LOAD_REPOSITORIES")}</span>
</div>
);
}

View File

@@ -0,0 +1,16 @@
import React from "react";
import { useTranslation } from "react-i18next";
import { Spinner } from "@heroui/react";
export function RepositoryLoadingState() {
const { t } = useTranslation();
return (
<div
data-testid="repo-dropdown-loading"
className="flex items-center gap-2 max-w-[500px] h-10 px-3 bg-tertiary border border-[#717888] rounded"
>
<Spinner size="sm" />
<span className="text-sm">{t("HOME$LOADING_REPOSITORIES")}</span>
</div>
);
}

View File

@@ -13,6 +13,7 @@ interface SettingsDropdownInputProps {
showOptionalTag?: boolean;
isDisabled?: boolean;
defaultSelectedKey?: string;
selectedKey?: string;
isClearable?: boolean;
onSelectionChange?: (key: React.Key | null) => void;
onInputChange?: (value: string) => void;
@@ -28,6 +29,7 @@ export function SettingsDropdownInput({
showOptionalTag,
isDisabled,
defaultSelectedKey,
selectedKey,
isClearable,
onSelectionChange,
onInputChange,
@@ -46,6 +48,7 @@ export function SettingsDropdownInput({
name={name}
defaultItems={items}
defaultSelectedKey={defaultSelectedKey}
selectedKey={selectedKey}
onSelectionChange={onSelectionChange}
onInputChange={onInputChange}
isClearable={isClearable}

View File

@@ -24,7 +24,7 @@ export const useCreateConversation = () => {
conversation_trigger: ConversationTrigger;
q?: string;
selectedRepository?: GitRepository | null;
selected_branch?: string;
suggested_task?: SuggestedTask;
}) => {
if (variables.q) dispatch(setInitialPrompt(variables.q));
@@ -41,6 +41,7 @@ export const useCreateConversation = () => {
files,
replayJson || undefined,
variables.suggested_task || undefined,
variables.selected_branch,
);
},
onSuccess: async ({ conversation_id: conversationId }, { q }) => {

View File

@@ -0,0 +1,14 @@
import { useQuery } from "@tanstack/react-query";
import OpenHands from "#/api/open-hands";
import { Branch } from "#/types/git";
export const useRepositoryBranches = (repository: string | null) =>
useQuery<Branch[]>({
queryKey: ["repository", repository, "branches"],
queryFn: async () => {
if (!repository) return [];
return OpenHands.getRepositoryBranches(repository);
},
enabled: !!repository,
staleTime: 1000 * 60 * 5, // 5 minutes
});

View File

@@ -8,6 +8,8 @@ export enum I18nKey {
HOME$LOADING = "HOME$LOADING",
HOME$LOADING_REPOSITORIES = "HOME$LOADING_REPOSITORIES",
HOME$FAILED_TO_LOAD_REPOSITORIES = "HOME$FAILED_TO_LOAD_REPOSITORIES",
HOME$LOADING_BRANCHES = "HOME$LOADING_BRANCHES",
HOME$FAILED_TO_LOAD_BRANCHES = "HOME$FAILED_TO_LOAD_BRANCHES",
HOME$OPEN_ISSUE = "HOME$OPEN_ISSUE",
HOME$FIX_FAILING_CHECKS = "HOME$FIX_FAILING_CHECKS",
HOME$RESOLVE_MERGE_CONFLICTS = "HOME$RESOLVE_MERGE_CONFLICTS",

View File

@@ -119,6 +119,36 @@
"tr": "Depolar yüklenemedi",
"de": "Fehler beim Laden der Repositories"
},
"HOME$LOADING_BRANCHES": {
"en": "Loading branches...",
"ja": "ブランチを読み込み中...",
"zh-CN": "正在加载分支...",
"zh-TW": "正在加載分支...",
"ko-KR": "브랜치 로딩 중...",
"no": "Laster inn branches...",
"it": "Caricamento dei branch...",
"pt": "Carregando branches...",
"es": "Cargando ramas...",
"ar": "جاري تحميل الفروع...",
"fr": "Chargement des branches...",
"tr": "Dallar yükleniyor...",
"de": "Lade Branches..."
},
"HOME$FAILED_TO_LOAD_BRANCHES": {
"en": "Failed to load branches",
"ja": "ブランチの読み込みに失敗しました",
"zh-CN": "加载分支失败",
"zh-TW": "加載分支失敗",
"ko-KR": "브랜치 로딩 실패",
"no": "Kunne ikke laste inn branches",
"it": "Impossibile caricare i branch",
"pt": "Falha ao carregar branches",
"es": "Error al cargar ramas",
"ar": "فشل في تحميل الفروع",
"fr": "Échec du chargement des branches",
"tr": "Dallar yüklenemedi",
"de": "Fehler beim Laden der Branches"
},
"HOME$OPEN_ISSUE": {
"en": "Open issue",
"ja": "オープンな課題",

View File

@@ -15,6 +15,13 @@ interface GitUser {
email: string | null;
}
interface Branch {
name: string;
commit_sha: string;
protected: boolean;
last_push_date?: string;
}
interface GitRepository {
id: number;
full_name: string;

View File

@@ -8,6 +8,7 @@ from pydantic import SecretStr
from openhands.integrations.service_types import (
BaseGitService,
Branch,
GitService,
ProviderType,
Repository,
@@ -385,6 +386,52 @@ class GitHubService(BaseGitService, GitService):
is_public=not repo.get('private', True),
)
async def get_branches(self, repository: str) -> list[Branch]:
"""Get branches for a repository"""
url = f'{self.BASE_URL}/repos/{repository}/branches'
# Set maximum branches to fetch (10 pages with 100 per page)
MAX_BRANCHES = 1000
PER_PAGE = 100
all_branches: list[Branch] = []
page = 1
# Fetch up to 10 pages of branches
while page <= 10 and len(all_branches) < MAX_BRANCHES:
params = {'per_page': str(PER_PAGE), 'page': str(page)}
response, headers = await self._make_request(url, params)
if not response: # No more branches
break
for branch_data in response:
# Extract the last commit date if available
last_push_date = None
if branch_data.get('commit') and branch_data['commit'].get('commit'):
commit_info = branch_data['commit']['commit']
if commit_info.get('committer') and commit_info['committer'].get(
'date'
):
last_push_date = commit_info['committer']['date']
branch = Branch(
name=branch_data.get('name'),
commit_sha=branch_data.get('commit', {}).get('sha', ''),
protected=branch_data.get('protected', False),
last_push_date=last_push_date,
)
all_branches.append(branch)
page += 1
# Check if we've reached the last page
link_header = headers.get('Link', '')
if 'rel="next"' not in link_header:
break
return all_branches
github_service_cls = os.environ.get(
'OPENHANDS_GITHUB_SERVICE_CLS',

View File

@@ -6,6 +6,7 @@ from pydantic import SecretStr
from openhands.integrations.service_types import (
BaseGitService,
Branch,
GitService,
ProviderType,
Repository,
@@ -398,6 +399,44 @@ class GitLabService(BaseGitService, GitService):
is_public=repo.get('visibility') == 'public',
)
async def get_branches(self, repository: str) -> list[Branch]:
"""Get branches for a repository"""
encoded_name = repository.replace('/', '%2F')
url = f'{self.BASE_URL}/projects/{encoded_name}/repository/branches'
# Set maximum branches to fetch (10 pages with 100 per page)
MAX_BRANCHES = 1000
PER_PAGE = 100
all_branches: list[Branch] = []
page = 1
# Fetch up to 10 pages of branches
while page <= 10 and len(all_branches) < MAX_BRANCHES:
params = {'per_page': str(PER_PAGE), 'page': str(page)}
response, headers = await self._make_request(url, params)
if not response: # No more branches
break
for branch_data in response:
branch = Branch(
name=branch_data.get('name'),
commit_sha=branch_data.get('commit', {}).get('id', ''),
protected=branch_data.get('protected', False),
last_push_date=branch_data.get('commit', {}).get('committed_date'),
)
all_branches.append(branch)
page += 1
# Check if we've reached the last page
link_header = headers.get('Link', '')
if 'rel="next"' not in link_header:
break
return all_branches
gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',

View File

@@ -18,6 +18,7 @@ from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.service_types import (
AuthenticationError,
Branch,
GitService,
ProviderType,
Repository,
@@ -305,3 +306,56 @@ class ProviderHandler:
pass
raise AuthenticationError(f'Unable to access repo {repository}')
async def get_branches(
self, repository: str, specified_provider: ProviderType | None = None
) -> list[Branch]:
"""
Get branches for a repository
Args:
repository: The repository name
specified_provider: Optional provider type to use
Returns:
A list of branches for the repository
"""
all_branches: list[Branch] = []
if specified_provider:
try:
service = self._get_service(specified_provider)
branches = await service.get_branches(repository)
return branches
except Exception as e:
logger.warning(
f'Error fetching branches from {specified_provider}: {e}'
)
for provider in self.provider_tokens:
try:
service = self._get_service(provider)
branches = await service.get_branches(repository)
all_branches.extend(branches)
# If we found branches, no need to check other providers
if all_branches:
break
except Exception as e:
logger.warning(f'Error fetching branches from {provider}: {e}')
# Sort branches by last push date (newest first)
all_branches.sort(
key=lambda b: b.last_push_date if b.last_push_date else '', reverse=True
)
# Move main/master branch to the top if it exists
main_branches = []
other_branches = []
for branch in all_branches:
if branch.name.lower() in ['main', 'master']:
main_branches.append(branch)
else:
other_branches.append(branch)
return main_branches + other_branches

View File

@@ -91,6 +91,13 @@ class User(BaseModel):
email: str | None = None
class Branch(BaseModel):
name: str
commit_sha: str
protected: bool
last_push_date: str | None = None # ISO 8601 format date string
class Repository(BaseModel):
id: int
full_name: str
@@ -211,3 +218,6 @@ class GitService(Protocol):
self, repository: str
) -> Repository:
"""Gets all repository details from repository name"""
async def get_branches(self, repository: str) -> list[Branch]:
"""Get branches for a repository"""

View File

@@ -8,6 +8,7 @@ from openhands.integrations.provider import (
)
from openhands.integrations.service_types import (
AuthenticationError,
Branch,
Repository,
SuggestedTask,
UnknownException,
@@ -165,3 +166,43 @@ async def get_suggested_tasks(
content='No providers set.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
@app.get('/repository/branches', response_model=list[Branch])
async def get_repository_branches(
repository: str,
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
"""Get branches for a repository.
Args:
repository: The repository name in the format 'owner/repo'
Returns:
A list of branches for the repository
"""
if provider_tokens:
client = ProviderHandler(
provider_tokens=provider_tokens, external_auth_token=access_token
)
try:
branches: list[Branch] = await client.get_branches(repository)
return branches
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='Git provider token required. (such as GitHub).',
status_code=status.HTTP_401_UNAUTHORIZED,
)