fix: Enhance GitHub repository search to include user organizations (#10324)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
sp.wack 2025-08-19 19:56:15 +04:00 committed by GitHub
parent 0297b3da18
commit aa6b454772
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 258 additions and 57 deletions

View File

@ -232,13 +232,16 @@ describe("RepositorySelectionForm", () => {
renderForm();
const dropdown = await screen.findByTestId("repo-dropdown");
const input = dropdown.querySelector('input[type="text"]') as HTMLInputElement;
const input = dropdown.querySelector(
'input[type="text"]',
) as HTMLInputElement;
expect(input).toBeInTheDocument();
await userEvent.type(input, "https://github.com/kubernetes/kubernetes");
expect(searchGitReposSpy).toHaveBeenLastCalledWith(
"kubernetes/kubernetes",
3,
"github",
);
});
@ -268,13 +271,16 @@ describe("RepositorySelectionForm", () => {
renderForm();
const dropdown = await screen.findByTestId("repo-dropdown");
const input = dropdown.querySelector('input[type="text"]') as HTMLInputElement;
const input = dropdown.querySelector(
'input[type="text"]',
) as HTMLInputElement;
expect(input).toBeInTheDocument();
await userEvent.type(input, "https://github.com/kubernetes/kubernetes");
expect(searchGitReposSpy).toHaveBeenLastCalledWith(
"kubernetes/kubernetes",
3,
"github",
);
});
});

View File

@ -1,7 +1,9 @@
import { useCallback, useMemo, useRef } from "react";
import { useCallback, useMemo, useState } from "react";
import { useTranslation } from "react-i18next";
import { Provider } from "../../types/settings";
import { useGitRepositories } from "../../hooks/query/use-git-repositories";
import { useSearchRepositories } from "../../hooks/query/use-search-repositories";
import { useDebounce } from "../../hooks/use-debounce";
import OpenHands from "../../api/open-hands";
import { GitRepository } from "../../types/git";
import {
@ -19,10 +21,6 @@ export interface GitRepositoryDropdownProps {
onChange?: (repository?: GitRepository) => void;
}
interface SearchCache {
[key: string]: GitRepository[];
}
export function GitRepositoryDropdown({
provider,
value,
@ -33,6 +31,20 @@ export function GitRepositoryDropdown({
onChange,
}: GitRepositoryDropdownProps) {
const { t } = useTranslation();
const [searchInput, setSearchInput] = useState("");
const debouncedSearchInput = useDebounce(searchInput, 300);
// Process search input to handle URLs
const processedSearchInput = useMemo(() => {
if (debouncedSearchInput.startsWith("https://")) {
const match = debouncedSearchInput.match(
/https:\/\/[^/]+\/([^/]+\/[^/]+)/,
);
return match ? match[1] : debouncedSearchInput;
}
return debouncedSearchInput;
}, [debouncedSearchInput]);
const {
data,
fetchNextPage,
@ -45,6 +57,10 @@ export function GitRepositoryDropdown({
enabled: !disabled,
});
// Search query for processed input (handles URLs)
const { data: searchData, isLoading: isSearchLoading } =
useSearchRepositories(processedSearchInput, provider);
const allOptions: AsyncSelectOption[] = useMemo(
() =>
data?.pages
@ -58,75 +74,83 @@ export function GitRepositoryDropdown({
[data],
);
// Keep track of search results
const searchCache = useRef<SearchCache>({});
const searchOptions: AsyncSelectOption[] = useMemo(
() =>
searchData
? searchData.map((repo) => ({
value: repo.id,
label: repo.full_name,
}))
: [],
[searchData],
);
const selectedOption = useMemo(() => {
// First check in loaded pages
const option = allOptions.find((opt) => opt.value === value);
if (option) return option;
// If not found, check in search cache
const repo = Object.values(searchCache.current)
.flat()
.find((r) => r.id === value);
if (repo) {
return {
value: repo.id,
label: repo.full_name,
};
}
// If not found, check in search results
const searchOption = searchOptions.find((opt) => opt.value === value);
if (searchOption) return searchOption;
return null;
}, [allOptions, value]);
}, [allOptions, searchOptions, value]);
const loadOptions = useCallback(
async (inputValue: string): Promise<AsyncSelectOption[]> => {
// Update search input to trigger debounced search
setSearchInput(inputValue);
// If empty input, show all loaded options
if (!inputValue.trim()) {
return allOptions;
}
// If it looks like a URL, extract the repo name and search
// For very short inputs, do local filtering
if (inputValue.length < 2) {
return allOptions.filter((option) =>
option.label.toLowerCase().includes(inputValue.toLowerCase()),
);
}
// Handle URL inputs by performing direct search
if (inputValue.startsWith("https://")) {
const match = inputValue.match(/https:\/\/[^/]+\/([^/]+\/[^/]+)/);
if (match) {
const repoName = match[1];
const searchResults = await OpenHands.searchGitRepositories(
repoName,
3,
);
// Cache the search results
searchCache.current[repoName] = searchResults;
return searchResults.map((repo) => ({
value: repo.id,
label: repo.full_name,
}));
try {
// Perform direct search for URL-based inputs
const repositories = await OpenHands.searchGitRepositories(
repoName,
3,
provider,
);
return repositories.map((repo) => ({
value: repo.full_name,
label: repo.full_name,
data: repo,
}));
} catch (error) {
// Fall back to local filtering if search fails
return allOptions.filter((option) =>
option.label.toLowerCase().includes(repoName.toLowerCase()),
);
}
}
}
// For any other input, search via API
if (inputValue.length >= 2) {
// Only search if at least 2 characters
const searchResults = await OpenHands.searchGitRepositories(
inputValue,
10,
);
// Cache the search results
searchCache.current[inputValue] = searchResults;
return searchResults.map((repo) => ({
value: repo.id,
label: repo.full_name,
}));
// For regular text inputs, use hook-based search results if available
if (searchOptions.length > 0 && processedSearchInput === inputValue) {
return searchOptions;
}
// For very short inputs, do local filtering
// Fallback to local filtering while search is loading
return allOptions.filter((option) =>
option.label.toLowerCase().includes(inputValue.toLowerCase()),
);
},
[allOptions],
[allOptions, searchOptions, processedSearchInput, provider],
);
const handleChange = (option: AsyncSelectOption | null) => {
@ -142,9 +166,7 @@ export function GitRepositoryDropdown({
// If not found, check in search results
if (!repo) {
repo = Object.values(searchCache.current)
.flat()
.find((r) => r.id === option.value);
repo = searchData?.find((r) => r.id === option.value);
}
onChange?.(repo);
@ -167,7 +189,7 @@ export function GitRepositoryDropdown({
errorMessage={errorMessage}
disabled={disabled}
isClearable={false}
isLoading={isLoading || isLoading || isFetchingNextPage}
isLoading={isLoading || isFetchingNextPage || isSearchLoading}
cacheOptions
defaultOptions={allOptions}
onChange={handleChange}

View File

@ -321,6 +321,36 @@ class GitHubService(BaseGitService, GitService, InstallationsService):
installations = response.get('installations', [])
return [str(i['id']) for i in installations]
async def get_user_organizations(self) -> list[str]:
"""Get list of organization logins that the user is a member of."""
url = f'{self.BASE_URL}/user/orgs'
try:
response, _ = await self._make_request(url)
orgs = [org['login'] for org in response]
return orgs
except Exception as e:
logger.warning(f'Failed to get user organizations: {e}')
return []
def _fuzzy_match_org_name(self, query: str, org_name: str) -> bool:
"""Check if query fuzzy matches organization name."""
query_lower = query.lower().replace('-', '').replace('_', '').replace(' ', '')
org_lower = org_name.lower().replace('-', '').replace('_', '').replace(' ', '')
# Exact match after normalization
if query_lower == org_lower:
return True
# Query is a substring of org name
if query_lower in org_lower:
return True
# Org name is a substring of query (less common but possible)
if org_lower in query_lower:
return True
return False
async def search_repositories(
self, query: str, per_page: int, sort: str, order: str, public: bool
) -> list[Repository]:
@ -341,21 +371,68 @@ class GitHubService(BaseGitService, GitService, InstallationsService):
# Add is:public to the query to ensure we only search for public repositories
params['q'] = f'in:name {org}/{repo_name} is:public'
# Perhaps we should go through all orgs and the search for repos under every org
# Currently it will only search user repos, and org repos when '/' is in the name
# Handle private repository searches
if not public and '/' in query:
org, repo_query = query.split('/', 1)
query_with_user = f'org:{org} in:name {repo_query}'
params['q'] = query_with_user
elif not public:
# Expand search scope to include user's repositories and organizations they're a member of
user = await self.get_user()
params['q'] = f'in:name {query} user:{user.login}'
user_orgs = await self.get_user_organizations()
# Search in user repos and org repos separately
all_repos = []
# Search in user repositories
user_query = f'{query} user:{user.login}'
user_params = params.copy()
user_params['q'] = user_query
try:
user_response, _ = await self._make_request(url, user_params)
user_items = user_response.get('items', [])
all_repos.extend(user_items)
except Exception as e:
logger.warning(f'User search failed: {e}')
# Search for repos named "query" in each organization
for org in user_orgs:
org_query = f'{query} org:{org}'
org_params = params.copy()
org_params['q'] = org_query
try:
org_response, _ = await self._make_request(url, org_params)
org_items = org_response.get('items', [])
all_repos.extend(org_items)
except Exception as e:
logger.warning(f'Org {org} search failed: {e}')
# Also search for top repos from orgs that match the query name
for org in user_orgs:
if self._fuzzy_match_org_name(query, org):
org_repos_query = f'org:{org}'
org_repos_params = params.copy()
org_repos_params['q'] = org_repos_query
org_repos_params['sort'] = 'stars'
org_repos_params['per_page'] = 2 # Limit to first 2 repos
try:
org_repos_response, _ = await self._make_request(
url, org_repos_params
)
org_repo_items = org_repos_response.get('items', [])
all_repos.extend(org_repo_items)
except Exception as e:
logger.warning(f'Org repos search for {org} failed: {e}')
return [self._parse_repository(repo) for repo in all_repos]
# Default case (public search or slash query)
response, _ = await self._make_request(url, params)
repo_items = response.get('items', [])
repos = [self._parse_repository(repo) for repo in repo_items]
return repos
return [self._parse_repository(repo) for repo in repo_items]
async def execute_graphql_query(
self, query: str, variables: dict[str, Any]

View File

@ -10,6 +10,7 @@ from openhands.integrations.service_types import (
OwnerType,
ProviderType,
Repository,
User,
)
from openhands.server.types import AppMode
@ -244,3 +245,98 @@ async def test_github_get_repositories_owner_type_fallback():
# Verify all repositories default to USER owner_type
for repo in repositories:
assert repo.owner_type == OwnerType.USER
@pytest.mark.asyncio
async def test_github_search_repositories_with_organizations():
"""Test that search_repositories includes user organizations in the search scope."""
service = GitHubService(user_id='test-user', token=SecretStr('test-token'))
# Mock user data
mock_user = User(
id='123', login='testuser', avatar_url='https://example.com/avatar.jpg'
)
# Mock search response
mock_search_response = {
'items': [
{
'id': 1,
'name': 'OpenHands',
'full_name': 'All-Hands-AI/OpenHands',
'private': False,
'html_url': 'https://github.com/All-Hands-AI/OpenHands',
'clone_url': 'https://github.com/All-Hands-AI/OpenHands.git',
'pushed_at': '2023-01-01T00:00:00Z',
'owner': {'login': 'All-Hands-AI', 'type': 'Organization'},
}
]
}
with (
patch.object(service, 'get_user', return_value=mock_user),
patch.object(
service,
'get_user_organizations',
return_value=['All-Hands-AI', 'example-org'],
),
patch.object(
service, '_make_request', return_value=(mock_search_response, {})
) as mock_request,
):
repositories = await service.search_repositories(
query='openhands', per_page=10, sort='stars', order='desc', public=False
)
# Verify that separate requests were made for user and each organization
assert mock_request.call_count == 3
# Check the calls made
calls = mock_request.call_args_list
# First call should be for user repositories
user_call = calls[0]
user_params = user_call[0][1] # Second argument is params
assert user_params['q'] == 'openhands user:testuser'
# Second call should be for first organization
org1_call = calls[1]
org1_params = org1_call[0][1]
assert org1_params['q'] == 'openhands org:All-Hands-AI'
# Third call should be for second organization
org2_call = calls[2]
org2_params = org2_call[0][1]
assert org2_params['q'] == 'openhands org:example-org'
# Verify repositories are returned (3 copies since each call returns the same mock response)
assert len(repositories) == 3
assert all(repo.full_name == 'All-Hands-AI/OpenHands' for repo in repositories)
@pytest.mark.asyncio
async def test_github_get_user_organizations():
"""Test that get_user_organizations fetches user's organizations."""
service = GitHubService(user_id='test-user', token=SecretStr('test-token'))
mock_orgs_response = [
{'login': 'All-Hands-AI', 'id': 1},
{'login': 'example-org', 'id': 2},
]
with patch.object(service, '_make_request', return_value=(mock_orgs_response, {})):
orgs = await service.get_user_organizations()
assert orgs == ['All-Hands-AI', 'example-org']
@pytest.mark.asyncio
async def test_github_get_user_organizations_error_handling():
"""Test that get_user_organizations handles errors gracefully."""
service = GitHubService(user_id='test-user', token=SecretStr('test-token'))
with patch.object(service, '_make_request', side_effect=Exception('API Error')):
orgs = await service.get_user_organizations()
# Should return empty list on error
assert orgs == []