import json
import os
from datetime import datetime
from typing import Any

import httpx
from pydantic import SecretStr

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.queries import (
    suggested_task_issue_graphql_query,
    suggested_task_pr_graphql_query,
)
from openhands.integrations.service_types import (
    BaseGitService,
    Branch,
    GitService,
    ProviderType,
    Repository,
    RequestMethod,
    SuggestedTask,
    TaskType,
    UnknownException,
    User,
)
from openhands.server.types import AppMode
from openhands.utils.import_utils import get_impl


class GitHubService(BaseGitService, GitService):
    """Default implementation of GitService for GitHub integration.

    TODO: This doesn't seem like a good candidate for the get_impl() pattern.
    Which abstract methods should we actually separate out and implement here?

    This is an extension point in OpenHands that allows applications to customize
    GitHub integration behavior. Applications can substitute their own
    implementation by:
    1. Creating a class that inherits from GitService
    2. Implementing all required methods
    3. Setting server_config.github_service_class to the fully qualified name of the class

    The class is instantiated via get_impl() in openhands.server.shared.
    """

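    # A minimal sketch of the substitution described above. The subclass and
    # module path here are illustrative, not part of OpenHands:
    #
    #     class MyGitHubService(GitHubService):
    #         async def get_latest_token(self) -> SecretStr | None:
    #             # e.g. fetch a fresh token from an external token manager
    #             return await my_token_store.fetch(self.user_id)
    #
    #     server_config.github_service_class = 'my_app.services.MyGitHubService'
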
    BASE_URL = 'https://api.github.com'
    token: SecretStr = SecretStr('')
    refresh = False

    def __init__(
        self,
        user_id: str | None = None,
        external_auth_id: str | None = None,
        external_auth_token: SecretStr | None = None,
        token: SecretStr | None = None,
        external_token_manager: bool = False,
        base_domain: str | None = None,
    ):
        self.user_id = user_id
        self.external_token_manager = external_token_manager

        if token:
            self.token = token

        # GitHub Enterprise Server exposes its REST API under /api/v3
        if base_domain and base_domain != 'github.com':
            self.BASE_URL = f'https://{base_domain}/api/v3'

        self.external_auth_id = external_auth_id
        self.external_auth_token = external_auth_token

    @property
    def provider(self) -> str:
        return ProviderType.GITHUB.value

    async def _get_github_headers(self) -> dict:
        """Retrieve the GitHub token from the settings store to construct the request headers."""
        if not self.token:
            self.token = await self.get_latest_token()

        return {
            'Authorization': f'Bearer {self.token.get_secret_value() if self.token else ""}',
            'Accept': 'application/vnd.github.v3+json',
        }

    def _has_token_expired(self, status_code: int) -> bool:
        return status_code == 401

    async def get_latest_token(self) -> SecretStr | None:
        return self.token

    async def _make_request(
        self,
        url: str,
        params: dict | None = None,
        method: RequestMethod = RequestMethod.GET,
    ) -> tuple[Any, dict]:
        try:
            async with httpx.AsyncClient() as client:
                github_headers = await self._get_github_headers()

                # Make initial request
                response = await self.execute_request(
                    client=client,
                    url=url,
                    headers=github_headers,
                    params=params,
                    method=method,
                )

                # Handle token refresh if needed
                if self.refresh and self._has_token_expired(response.status_code):
                    await self.get_latest_token()
                    github_headers = await self._get_github_headers()
                    response = await self.execute_request(
                        client=client,
                        url=url,
                        headers=github_headers,
                        params=params,
                        method=method,
                    )

                response.raise_for_status()
                headers = {}
                if 'Link' in response.headers:
                    headers['Link'] = response.headers['Link']

                return response.json(), headers

        except httpx.HTTPStatusError as e:
            raise self.handle_http_status_error(e)
        except httpx.HTTPError as e:
            raise self.handle_http_error(e)

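    # `_make_request` returns a `(parsed_json, headers)` tuple; only the
    # pagination-related `Link` header is propagated to callers. An
    # illustrative use:
    #
    #     data, headers = await self._make_request(f'{self.BASE_URL}/user/repos')
    #     has_next_page = 'rel="next"' in headers.get('Link', '')
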
    async def get_user(self) -> User:
        url = f'{self.BASE_URL}/user'
        response, _ = await self._make_request(url)

        return User(
            id=response.get('id'),
            login=response.get('login'),
            avatar_url=response.get('avatar_url'),
            company=response.get('company'),
            name=response.get('name'),
            email=response.get('email'),
        )

    async def verify_access(self) -> bool:
        """Verify that the token is valid by making a simple request."""
        url = self.BASE_URL
        await self._make_request(url)
        return True

    async def _fetch_paginated_repos(
        self, url: str, params: dict, max_repos: int, extract_key: str | None = None
    ) -> list[dict]:
        """Fetch repositories with pagination support.

        Args:
            url: The API endpoint URL
            params: Query parameters for the request
            max_repos: Maximum number of repositories to fetch
            extract_key: If provided, extract repositories from this key in the response

        Returns:
            List of repository dictionaries
        """
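        # GitHub signals further pages via the `Link` response header, e.g.:
        #   Link: <https://api.github.com/user/repos?page=2>; rel="next",
        #         <https://api.github.com/user/repos?page=5>; rel="last"
        # The loop below keeps requesting pages until rel="next" disappears
        # or max_repos is reached.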
        repos: list[dict] = []
        page = 1

        while len(repos) < max_repos:
            page_params = {**params, 'page': str(page)}
            response, headers = await self._make_request(url, page_params)

            # Extract repositories from response
            page_repos = response.get(extract_key, []) if extract_key else response

            if not page_repos:  # No more repositories
                break

            repos.extend(page_repos)
            page += 1

            # Check if we've reached the last page
            link_header = headers.get('Link', '')
            if 'rel="next"' not in link_header:
                break

        return repos[:max_repos]  # Trim to max_repos if needed

    def parse_pushed_at_date(self, repo: dict) -> datetime:
        """Parse a repo's `pushed_at` timestamp (e.g. '2024-01-15T10:30:00Z') for sorting."""
        ts = repo.get('pushed_at')
        return datetime.strptime(ts, '%Y-%m-%dT%H:%M:%SZ') if ts else datetime.min

    async def get_repositories(self, sort: str, app_mode: AppMode) -> list[Repository]:
        MAX_REPOS = 1000
        PER_PAGE = 100  # Maximum allowed by GitHub API
        all_repos: list[dict] = []

        if app_mode == AppMode.SAAS:
            # Get all installation IDs and fetch repos for each one
            installation_ids = await self.get_installation_ids()

            # Iterate through each installation ID
            for installation_id in installation_ids:
                params = {'per_page': str(PER_PAGE)}
                url = (
                    f'{self.BASE_URL}/user/installations/{installation_id}/repositories'
                )

                # Fetch repositories for this installation
                installation_repos = await self._fetch_paginated_repos(
                    url, params, MAX_REPOS - len(all_repos), extract_key='repositories'
                )

                all_repos.extend(installation_repos)

                # If we've already reached MAX_REPOS, no need to check other installations
                if len(all_repos) >= MAX_REPOS:
                    break

            if sort == 'pushed':
                all_repos.sort(key=self.parse_pushed_at_date, reverse=True)
        else:
            # Original behavior for non-SaaS mode
            params = {'per_page': str(PER_PAGE), 'sort': sort}
            url = f'{self.BASE_URL}/user/repos'

            # Fetch user repositories
            all_repos = await self._fetch_paginated_repos(url, params, MAX_REPOS)

        # Convert to Repository objects
        return [
            Repository(
                id=repo.get('id'),
                full_name=repo.get('full_name'),
                stargazers_count=repo.get('stargazers_count'),
                git_provider=ProviderType.GITHUB,
                is_public=not repo.get('private', True),
            )
            for repo in all_repos
        ]

    async def get_installation_ids(self) -> list[int]:
        url = f'{self.BASE_URL}/user/installations'
        response, _ = await self._make_request(url)
        installations = response.get('installations', [])
        return [i['id'] for i in installations]

    async def search_repositories(
        self, query: str, per_page: int, sort: str, order: str
    ) -> list[Repository]:
        url = f'{self.BASE_URL}/search/repositories'
        # Add is:public to the query to ensure we only search for public repositories
        query_with_visibility = f'{query} is:public'
        params = {
            'q': query_with_visibility,
            'per_page': per_page,
            'sort': sort,
            'order': order,
        }

        response, _ = await self._make_request(url, params)
        repo_items = response.get('items', [])

        repos = [
            Repository(
                id=repo.get('id'),
                full_name=repo.get('full_name'),
                stargazers_count=repo.get('stargazers_count'),
                git_provider=ProviderType.GITHUB,
                is_public=True,
            )
            for repo in repo_items
        ]

        return repos

    async def execute_graphql_query(
        self, query: str, variables: dict[str, Any]
    ) -> dict[str, Any]:
        """Execute a GraphQL query against the GitHub API."""
        try:
            async with httpx.AsyncClient() as client:
                github_headers = await self._get_github_headers()
                response = await client.post(
                    f'{self.BASE_URL}/graphql',
                    headers=github_headers,
                    json={'query': query, 'variables': variables},
                )
                response.raise_for_status()

                result = response.json()
                if 'errors' in result:
                    raise UnknownException(
                        f'GraphQL query error: {json.dumps(result["errors"])}'
                    )

                return dict(result)

        except httpx.HTTPStatusError as e:
            raise self.handle_http_status_error(e)
        except httpx.HTTPError as e:
            raise self.handle_http_error(e)

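    # Example (illustrative): fetching the authenticated viewer's login.
    #
    #     result = await service.execute_graphql_query(
    #         'query { viewer { login } }', {}
    #     )
    #     login = result['data']['viewer']['login']
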
    async def get_suggested_tasks(self) -> list[SuggestedTask]:
        """Get suggested tasks for the authenticated user across all repositories.

        Returns:
            - PRs authored by the user.
            - Issues assigned to the user.

        Note: Queries are split to avoid timeout issues.
        """
        # Get user info to use in queries
        user = await self.get_user()
        login = user.login
        tasks: list[SuggestedTask] = []
        variables = {'login': login}

        try:
            pr_response = await self.execute_graphql_query(
                suggested_task_pr_graphql_query, variables
            )
            pr_data = pr_response['data']['user']

            # Process pull requests
            for pr in pr_data['pullRequests']['nodes']:
                repo_name = pr['repository']['nameWithOwner']

                # Start with default task type
                task_type = TaskType.OPEN_PR

                # Check for specific states
                if pr['mergeable'] == 'CONFLICTING':
                    task_type = TaskType.MERGE_CONFLICTS
                elif (
                    pr['commits']['nodes']
                    and pr['commits']['nodes'][0]['commit']['statusCheckRollup']
                    and pr['commits']['nodes'][0]['commit']['statusCheckRollup'][
                        'state'
                    ]
                    == 'FAILURE'
                ):
                    task_type = TaskType.FAILING_CHECKS
                elif any(
                    review['state'] in ['CHANGES_REQUESTED', 'COMMENTED']
                    for review in pr['reviews']['nodes']
                ):
                    task_type = TaskType.UNRESOLVED_COMMENTS

                # Only add the task if it's not OPEN_PR
                if task_type != TaskType.OPEN_PR:
                    tasks.append(
                        SuggestedTask(
                            git_provider=ProviderType.GITHUB,
                            task_type=task_type,
                            repo=repo_name,
                            issue_number=pr['number'],
                            title=pr['title'],
                        )
                    )

        except Exception as e:
            logger.info(
                f'Error fetching suggested tasks for PRs: {e}',
                extra={
                    'signal': 'github_suggested_tasks',
                    'user_id': self.external_auth_id,
                },
            )

        try:
            # Execute issue query
            issue_response = await self.execute_graphql_query(
                suggested_task_issue_graphql_query, variables
            )
            issue_data = issue_response['data']['user']

            # Process issues
            for issue in issue_data['issues']['nodes']:
                repo_name = issue['repository']['nameWithOwner']
                tasks.append(
                    SuggestedTask(
                        git_provider=ProviderType.GITHUB,
                        task_type=TaskType.OPEN_ISSUE,
                        repo=repo_name,
                        issue_number=issue['number'],
                        title=issue['title'],
                    )
                )

            return tasks

        except Exception as e:
            logger.info(
                f'Error fetching suggested tasks for issues: {e}',
                extra={
                    'signal': 'github_suggested_tasks',
                    'user_id': self.external_auth_id,
                },
            )

        return tasks

    async def get_repository_details_from_repo_name(
        self, repository: str
    ) -> Repository:
        url = f'{self.BASE_URL}/repos/{repository}'
        repo, _ = await self._make_request(url)

        return Repository(
            id=repo.get('id'),
            full_name=repo.get('full_name'),
            stargazers_count=repo.get('stargazers_count'),
            git_provider=ProviderType.GITHUB,
            is_public=not repo.get('private', True),
        )

    async def get_branches(self, repository: str) -> list[Branch]:
        """Get branches for a repository."""
        url = f'{self.BASE_URL}/repos/{repository}/branches'

        # Set maximum branches to fetch (10 pages with 100 per page)
        MAX_BRANCHES = 1000
        PER_PAGE = 100

        all_branches: list[Branch] = []
        page = 1

        # Fetch up to 10 pages of branches
        while page <= 10 and len(all_branches) < MAX_BRANCHES:
            params = {'per_page': str(PER_PAGE), 'page': str(page)}
            response, headers = await self._make_request(url, params)

            if not response:  # No more branches
                break

            for branch_data in response:
                # Extract the last commit date if available
                last_push_date = None
                if branch_data.get('commit') and branch_data['commit'].get('commit'):
                    commit_info = branch_data['commit']['commit']
                    if commit_info.get('committer') and commit_info['committer'].get(
                        'date'
                    ):
                        last_push_date = commit_info['committer']['date']

                branch = Branch(
                    name=branch_data.get('name'),
                    commit_sha=branch_data.get('commit', {}).get('sha', ''),
                    protected=branch_data.get('protected', False),
                    last_push_date=last_push_date,
                )
                all_branches.append(branch)

            page += 1

            # Check if we've reached the last page
            link_header = headers.get('Link', '')
            if 'rel="next"' not in link_header:
                break

        return all_branches

    async def create_pr(
        self,
        repo_name: str,
        source_branch: str,
        target_branch: str,
        title: str,
        body: str | None = None,
        draft: bool = True,
    ) -> str:
        """Creates a PR using the user's credentials.

        Args:
            repo_name: The full name of the repository (owner/repo)
            source_branch: The name of the branch where your changes are implemented
            target_branch: The name of the branch you want the changes pulled into
            title: The title of the pull request
            body: The body/description of the pull request (optional)
            draft: Whether to create the PR as a draft (optional, defaults to True)

        Returns:
            - PR URL when successful
            - Error message when unsuccessful
        """
        try:
            url = f'{self.BASE_URL}/repos/{repo_name}/pulls'

            # Set default body if none provided
            if not body:
                body = f'Merging changes from {source_branch} into {target_branch}'

            # Prepare the request payload
            payload = {
                'title': title,
                'head': source_branch,
                'base': target_branch,
                'body': body,
                'draft': draft,
            }

            # Make the POST request to create the PR
            response, _ = await self._make_request(
                url=url, params=payload, method=RequestMethod.POST
            )

            # Return the HTML URL of the created PR
            if 'html_url' in response:
                return response['html_url']
            else:
                return f'PR created but URL not found in response: {response}'

        except Exception as e:
            return f'Error creating pull request: {str(e)}'
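    # Example (illustrative):
    #
    #     pr_url = await service.create_pr(
    #         'octocat/hello-world',
    #         source_branch='feature-x',
    #         target_branch='main',
    #         title='Add feature X',
    #     )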


github_service_cls = os.environ.get(
    'OPENHANDS_GITHUB_SERVICE_CLS',
    'openhands.integrations.github.github_service.GitHubService',
)
GithubServiceImpl = get_impl(GitHubService, github_service_cls)
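
# Minimal usage sketch (illustrative only; assumes GITHUB_TOKEN holds a valid
# personal access token and that AppMode exposes a non-SaaS member such as
# AppMode.OSS):
#
#     import asyncio
#
#     async def main() -> None:
#         service = GithubServiceImpl(
#             token=SecretStr(os.environ['GITHUB_TOKEN'])
#         )
#         print((await service.get_user()).login)
#         for repo in await service.get_repositories('pushed', AppMode.OSS):
#             print(repo.full_name)
#
#     asyncio.run(main())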