Add base_domain parameter for GitHub Enterprise support (#7754)

Co-authored-by: Tom Deckers <tdeckers@cisco.com>
Co-authored-by: Robert Brennan <accounts@rbren.io>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
Tom Deckers 2025-04-16 02:00:32 +02:00 committed by GitHub
parent d7e8f843ad
commit 7e14a512e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 203 additions and 49 deletions

View File

@ -12,11 +12,21 @@ from openhands.resolver.utils import extract_issue_references
class GithubIssueHandler(IssueHandlerInterface):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
"""Initialize a GitHub issue handler.
Args:
owner: The owner of the repository
repo: The name of the repository
token: The GitHub personal access token
username: Optional GitHub username
base_domain: The domain for GitHub Enterprise (default: "github.com")
"""
self.owner = owner
self.repo = repo
self.token = token
self.username = username
self.base_domain = base_domain
self.base_url = self.get_base_url()
self.download_url = self.get_download_url()
self.clone_url = self.get_clone_url()
@ -32,10 +42,13 @@ class GithubIssueHandler(IssueHandlerInterface):
}
def get_base_url(self) -> str:
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
if self.base_domain == "github.com":
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
else:
return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}'
def get_authorize_url(self) -> str:
return f'https://{self.username}:{self.token}@github.com/'
return f'https://{self.username}:{self.token}@{self.base_domain}/'
def get_branch_url(self, branch_name: str) -> str:
return self.get_base_url() + f'/branches/{branch_name}'
@ -49,13 +62,16 @@ class GithubIssueHandler(IssueHandlerInterface):
if self.username
else f'x-auth-token:{self.token}'
)
return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git'
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
def get_graphql_url(self) -> str:
return 'https://api.github.com/graphql'
if self.base_domain == "github.com":
return 'https://api.github.com/graphql'
else:
return f'https://{self.base_domain}/api/v3/graphql'
def get_compare_url(self, branch_name: str) -> str:
return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
return f'https://{self.base_domain}/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
def get_converted_issues(
self, issue_numbers: list[int] | None = None, comment_id: int | None = None
@ -220,7 +236,7 @@ class GithubIssueHandler(IssueHandlerInterface):
response.raise_for_status()
def get_pull_url(self, pr_number: int) -> str:
return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}'
return f'https://{self.base_domain}/{self.owner}/{self.repo}/pull/{pr_number}'
def get_default_branch_name(self) -> str:
response = httpx.get(f'{self.base_url}', headers=self.headers)
@ -286,11 +302,21 @@ class GithubIssueHandler(IssueHandlerInterface):
class GithubPRHandler(GithubIssueHandler):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
super().__init__(owner, repo, token, username)
self.download_url = (
f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
)
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
"""Initialize a GitHub PR handler.
Args:
owner: The owner of the repository
repo: The name of the repository
token: The GitHub personal access token
username: Optional GitHub username
base_domain: The domain for GitHub Enterprise (default: "github.com")
"""
super().__init__(owner, repo, token, username, base_domain)
if self.base_domain == "github.com":
self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
else:
self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls'
def download_pr_metadata(
self, pull_number: int, comment_id: int | None = None
@ -356,7 +382,7 @@ class GithubPRHandler(GithubIssueHandler):
variables = {'owner': self.owner, 'repo': self.repo, 'pr': pull_number}
url = 'https://api.github.com/graphql'
url = self.get_graphql_url()
headers = {
'Authorization': f'Bearer {self.token}',
'Content-Type': 'application/json',
@ -444,7 +470,10 @@ class GithubPRHandler(GithubIssueHandler):
self, pr_number: int, comment_id: int | None = None
) -> list[str] | None:
"""Download comments for a specific pull request from Github."""
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
if self.base_domain == "github.com":
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
else:
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
headers = {
'Authorization': f'token {self.token}',
'Accept': 'application/vnd.github.v3+json',
@ -513,7 +542,10 @@ class GithubPRHandler(GithubIssueHandler):
for issue_number in unique_issue_references:
try:
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
if self.base_domain == "github.com":
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
else:
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}'
headers = {
'Authorization': f'Bearer {self.token}',
'Accept': 'application/vnd.github.v3+json',

View File

@ -13,11 +13,28 @@ from openhands.resolver.utils import extract_issue_references
class GitlabIssueHandler(IssueHandlerInterface):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
def __init__(
self,
owner: str,
repo: str,
token: str,
username: str | None = None,
base_domain: str = 'gitlab.com',
):
"""Initialize a GitLab issue handler.
Args:
owner: The owner of the repository
repo: The name of the repository
token: The GitLab personal access token
username: Optional GitLab username
base_domain: The domain for GitLab Enterprise (default: "gitlab.com")
"""
self.owner = owner
self.repo = repo
self.token = token
self.username = username
self.base_domain = base_domain
self.base_url = self.get_base_url()
self.download_url = self.get_download_url()
self.clone_url = self.get_clone_url()
@ -34,10 +51,10 @@ class GitlabIssueHandler(IssueHandlerInterface):
def get_base_url(self) -> str:
project_path = quote(f'{self.owner}/{self.repo}', safe='')
return f'https://gitlab.com/api/v4/projects/{project_path}'
return f'https://{self.base_domain}/api/v4/projects/{project_path}'
def get_authorize_url(self) -> str:
return f'https://{self.username}:{self.token}@gitlab.com/'
return f'https://{self.username}:{self.token}@{self.base_domain}/'
def get_branch_url(self, branch_name: str) -> str:
return self.get_base_url() + f'/repository/branches/{branch_name}'
@ -49,13 +66,13 @@ class GitlabIssueHandler(IssueHandlerInterface):
username_and_token = self.token
if self.username:
username_and_token = f'{self.username}:{self.token}'
return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git'
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
def get_graphql_url(self) -> str:
return 'https://gitlab.com/api/graphql'
return f'https://{self.base_domain}/api/graphql'
def get_compare_url(self, branch_name: str) -> str:
return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
return f'https://{self.base_domain}/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
def get_converted_issues(
self, issue_numbers: list[int] | None = None, comment_id: int | None = None
@ -215,9 +232,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
response.raise_for_status()
def get_pull_url(self, pr_number: int) -> str:
return (
f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
)
return f'https://{self.base_domain}/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
def get_default_branch_name(self) -> str:
response = httpx.get(f'{self.base_url}', headers=self.headers)
@ -248,7 +263,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
response = httpx.get(
f'https://gitlab.com/api/v4/users?username={reviewer}',
f'https://{self.base_domain}/api/v4/users?username={reviewer}',
headers=self.headers,
)
response.raise_for_status()
@ -298,8 +313,24 @@ class GitlabIssueHandler(IssueHandlerInterface):
class GitlabPRHandler(GitlabIssueHandler):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
super().__init__(owner, repo, token, username)
def __init__(
self,
owner: str,
repo: str,
token: str,
username: str | None = None,
base_domain: str = 'gitlab.com',
):
"""Initialize a GitLab PR handler.
Args:
owner: The owner of the repository
repo: The name of the repository
token: The GitLab personal access token
username: Optional GitLab username
base_domain: The domain for GitLab Enterprise (default: "gitlab.com")
"""
super().__init__(owner, repo, token, username, base_domain)
self.download_url = f'{self.base_url}/merge_requests'
def download_pr_metadata(

View File

@ -66,6 +66,7 @@ async def resolve_issues(
issue_type: str,
repo_instruction: str | None,
issue_numbers: list[int] | None,
base_domain: str = 'github.com',
) -> None:
"""Resolve multiple github or gitlab issues.
@ -86,7 +87,7 @@ async def resolve_issues(
issue_numbers: List of issue numbers to resolve.
"""
issue_handler = issue_handler_factory(
issue_type, owner, repo, token, llm_config, platform
issue_type, owner, repo, token, llm_config, platform, username, base_domain
)
# Load dataset
@ -323,6 +324,12 @@ def main() -> None:
choices=['issue', 'pr'],
help='Type of issue to resolve, either open issue or pr comments.',
)
parser.add_argument(
'--base-domain',
type=str,
default='github.com',
help='Base domain for GitHub Enterprise (default: github.com)',
)
my_args = parser.parse_args()
@ -339,7 +346,7 @@ def main() -> None:
if not token:
raise ValueError('Token is required.')
platform = identify_token(token, my_args.selected_repo)
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
@ -394,6 +401,7 @@ def main() -> None:
issue_type=issue_type,
repo_instruction=repo_instruction,
issue_numbers=issue_numbers,
base_domain=my_args.base_domain,
)
)

View File

@ -322,24 +322,30 @@ def issue_handler_factory(
llm_config: LLMConfig,
platform: Platform,
username: str | None = None,
base_domain: str | None = None,
) -> ServiceContextIssue | ServiceContextPR:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
if issue_type == 'issue':
if platform == Platform.GITHUB:
return ServiceContextIssue(
GithubIssueHandler(owner, repo, token, username), llm_config
GithubIssueHandler(owner, repo, token, username, base_domain),
llm_config,
)
else: # platform == Platform.GITLAB
return ServiceContextIssue(
GitlabIssueHandler(owner, repo, token, username), llm_config
GitlabIssueHandler(owner, repo, token, username, base_domain),
llm_config,
)
elif issue_type == 'pr':
if platform == Platform.GITHUB:
return ServiceContextPR(
GithubPRHandler(owner, repo, token, username), llm_config
GithubPRHandler(owner, repo, token, username, base_domain), llm_config
)
else: # platform == Platform.GITLAB
return ServiceContextPR(
GitlabPRHandler(owner, repo, token, username), llm_config
GitlabPRHandler(owner, repo, token, username, base_domain), llm_config
)
else:
raise ValueError(f'Invalid issue type: {issue_type}')
@ -361,6 +367,7 @@ async def resolve_issue(
issue_number: int,
comment_id: int | None,
reset_logger: bool = False,
base_domain: str | None = None,
) -> None:
"""Resolve a single issue.
@ -379,11 +386,15 @@ async def resolve_issue(
repo_instruction: Repository instruction to use.
issue_number: Issue number to resolve.
comment_id: Optional ID of a specific comment to focus on.
reset_logger: Whether to reset the logger for multiprocessing.
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
"""
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
issue_handler = issue_handler_factory(
issue_type, owner, repo, token, llm_config, platform, username
issue_type, owner, repo, token, llm_config, platform, username, base_domain
)
# Load dataset
@ -629,6 +640,12 @@ def main() -> None:
type=lambda x: x.lower() == 'true',
help='Whether to run in experimental mode.',
)
parser.add_argument(
'--base-domain',
type=str,
default=None,
help='Base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)',
)
my_args = parser.parse_args()
@ -651,7 +668,7 @@ def main() -> None:
if not token:
raise ValueError('Token is required.')
platform = identify_token(token, my_args.selected_repo)
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
@ -708,6 +725,7 @@ def main() -> None:
repo_instruction=repo_instruction,
issue_number=my_args.issue_number,
comment_id=my_args.comment_id,
base_domain=my_args.base_domain,
)
)

View File

@ -235,6 +235,7 @@ def send_pull_request(
target_branch: str | None = None,
reviewer: str | None = None,
pr_title: str | None = None,
base_domain: str | None = None,
) -> str:
"""Send a pull request to a GitHub or Gitlab repository.
@ -250,18 +251,25 @@ def send_pull_request(
target_branch: The target branch to create the pull request against (defaults to repository default branch)
reviewer: The GitHub or Gitlab username of the reviewer to assign
pr_title: Custom title for the pull request (optional)
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
"""
if pr_type not in ['branch', 'draft', 'ready']:
raise ValueError(f'Invalid pr_type: {pr_type}')
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username), None
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
None,
)
else: # platform == Platform.GITLAB
handler = ServiceContextIssue(
GitlabIssueHandler(issue.owner, issue.repo, token, username), None
GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain),
None,
)
# Create a new branch with a unique name
@ -363,6 +371,7 @@ def update_existing_pull_request(
llm_config: LLMConfig,
comment_message: str | None = None,
additional_message: str | None = None,
base_domain: str | None = None,
) -> str:
"""Update an existing pull request with the new patches.
@ -375,17 +384,24 @@ def update_existing_pull_request(
llm_config: The LLM configuration to use for summarizing changes.
comment_message: The main message to post as a comment on the PR.
additional_message: The additional messages to post as a comment on the PR in json list format.
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
"""
# Set up headers and base URL for GitHub or GitLab API
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username), llm_config
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
llm_config,
)
else: # platform == Platform.GITLAB
handler = ServiceContextIssue(
GitlabIssueHandler(issue.owner, issue.repo, token, username), llm_config
GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain),
llm_config,
)
branch_name = issue.head_branch
@ -468,7 +484,11 @@ def process_single_issue(
target_branch: str | None = None,
reviewer: str | None = None,
pr_title: str | None = None,
base_domain: str | None = None,
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
if not resolver_output.success and not send_on_failure:
logger.info(
f'Issue {resolver_output.issue.number} was not successfully resolved. Skipping PR creation.'
@ -507,6 +527,7 @@ def process_single_issue(
patch_dir=patched_repo_dir,
additional_message=resolver_output.result_explanation,
llm_config=llm_config,
base_domain=base_domain,
)
else:
send_pull_request(
@ -521,6 +542,7 @@ def process_single_issue(
target_branch=target_branch,
reviewer=reviewer,
pr_title=pr_title,
base_domain=base_domain,
)
@ -532,7 +554,11 @@ def process_all_successful_issues(
pr_type: str,
llm_config: LLMConfig,
fork_owner: str | None,
base_domain: str | None = None,
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
output_path = os.path.join(output_dir, 'output.jsonl')
for resolver_output in load_all_resolver_outputs(output_path):
if resolver_output.success:
@ -548,6 +574,9 @@ def process_all_successful_issues(
fork_owner,
False,
None,
None,
None,
base_domain,
)
@ -633,6 +662,12 @@ def main() -> None:
help='Custom title for the pull request',
default=None,
)
parser.add_argument(
'--base-domain',
type=str,
default=None,
help='Base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)',
)
my_args = parser.parse_args()
token = my_args.token or os.getenv('GITHUB_TOKEN') or os.getenv('GITLAB_TOKEN')
@ -642,7 +677,7 @@ def main() -> None:
)
username = my_args.username if my_args.username else os.getenv('GIT_USERNAME')
platform = identify_token(token)
platform = identify_token(token, None, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
@ -667,6 +702,7 @@ def main() -> None:
my_args.pr_type,
llm_config,
my_args.fork_owner,
my_args.base_domain,
)
else:
if not my_args.issue_number.isdigit():
@ -689,6 +725,7 @@ def main() -> None:
my_args.target_branch,
my_args.reviewer,
my_args.pr_title,
my_args.base_domain,
)

View File

@ -20,22 +20,31 @@ class Platform(Enum):
GITLAB = 2
def identify_token(token: str, selected_repo: str | None = None) -> Platform:
def identify_token(
token: str, selected_repo: str | None = None, base_domain: str = 'github.com'
) -> Platform:
"""
Identifies whether a token belongs to GitHub or GitLab.
Parameters:
token (str): The personal access token to check.
selected_repo (str): Repository in format "owner/repo" for GitHub Actions token validation.
base_domain (str): The base domain for GitHub Enterprise (default: "github.com").
Returns:
Platform: "GitHub" if the token is valid for GitHub,
"GitLab" if the token is valid for GitLab,
"Invalid" if the token is not recognized by either.
"""
# Determine GitHub API base URL based on domain
if base_domain == 'github.com':
github_api_base = 'https://api.github.com'
else:
github_api_base = f'https://{base_domain}/api/v3'
# Try GitHub Actions token format (Bearer) with repo endpoint if repo is provided
if selected_repo:
github_repo_url = f'https://api.github.com/repos/{selected_repo}'
github_repo_url = f'{github_api_base}/repos/{selected_repo}'
github_bearer_headers = {
'Authorization': f'Bearer {token}',
'Accept': 'application/vnd.github+json',
@ -51,7 +60,7 @@ def identify_token(token: str, selected_repo: str | None = None) -> Platform:
logger.error(f'Error connecting to GitHub API (selected_repo check): {e}')
# Try GitHub PAT format (token)
github_url = 'https://api.github.com/user'
github_url = f'{github_api_base}/user'
github_headers = {'Authorization': f'token {token}'}
try:
@ -61,7 +70,6 @@ def identify_token(token: str, selected_repo: str | None = None) -> Platform:
except httpx.HTTPError as e:
logger.error(f'Error connecting to GitHub API: {e}')
# Try GitLab token
gitlab_url = 'https://gitlab.com/api/v4/user'
gitlab_headers = {'Authorization': f'Bearer {token}'}

View File

@ -1,6 +1,6 @@
import os
import tempfile
from unittest.mock import MagicMock, call, patch
from unittest.mock import ANY, MagicMock, call, patch
import pytest
@ -884,6 +884,7 @@ def test_process_single_pr_update(
patch_dir=f'{mock_output_dir}/patches/pr_1',
additional_message='[Test success 1]',
llm_config=mock_llm_config,
base_domain='github.com',
)
@ -965,6 +966,7 @@ def test_process_single_issue(
target_branch=None,
reviewer=None,
pr_title=None,
base_domain='github.com',
)
@ -1126,6 +1128,9 @@ def test_process_all_successful_issues(
None,
False,
None,
None,
None,
'github.com',
),
call(
'output_dir',
@ -1138,6 +1143,9 @@ def test_process_all_successful_issues(
None,
False,
None,
None,
None,
'github.com',
),
]
)
@ -1260,7 +1268,7 @@ def test_main(
# Run main function
main()
mock_identify_token.assert_called_with('mock_token')
mock_identify_token.assert_called_with('mock_token', None, ANY)
llm_config = LLMConfig(
model=mock_args.llm_model,
@ -1282,6 +1290,7 @@ def test_main(
mock_args.target_branch,
mock_args.reviewer,
mock_args.pr_title,
ANY,
)
# Other assertions
@ -1301,6 +1310,7 @@ def test_main(
'draft',
llm_config,
None,
ANY,
)
# Test for invalid issue number

View File

@ -1,6 +1,6 @@
import os
import tempfile
from unittest.mock import MagicMock, call, patch
from unittest.mock import ANY, MagicMock, call, patch
from urllib.parse import quote
import pytest
@ -785,6 +785,7 @@ def test_process_single_pr_update(
patch_dir=f'{mock_output_dir}/patches/pr_1',
additional_message='[Test success 1]',
llm_config=mock_llm_config,
base_domain='gitlab.com',
)
@ -866,6 +867,7 @@ def test_process_single_issue(
target_branch=None,
reviewer=None,
pr_title=None,
base_domain='gitlab.com',
)
@ -1027,6 +1029,9 @@ def test_process_all_successful_issues(
None,
False,
None,
None,
None,
'gitlab.com',
),
call(
'output_dir',
@ -1039,6 +1044,9 @@ def test_process_all_successful_issues(
None,
False,
None,
None,
None,
'gitlab.com',
),
]
)
@ -1162,7 +1170,7 @@ def test_main(
# Run main function
main()
mock_identify_token.assert_called_with('mock_token')
mock_identify_token.assert_called_with('mock_token', None, ANY)
llm_config = LLMConfig(
model=mock_args.llm_model,
@ -1184,6 +1192,7 @@ def test_main(
mock_args.target_branch,
mock_args.reviewer,
mock_args.pr_title,
ANY,
)
# Other assertions
@ -1203,6 +1212,7 @@ def test_main(
'draft',
llm_config,
None,
ANY,
)
# Test for invalid issue number