mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add base_domain parameter for GitHub Enterprise support (#7754)
Co-authored-by: Tom Deckers <tdeckers@cisco.com> Co-authored-by: Robert Brennan <accounts@rbren.io> Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
parent
d7e8f843ad
commit
7e14a512e0
@ -12,11 +12,21 @@ from openhands.resolver.utils import extract_issue_references
|
||||
|
||||
|
||||
class GithubIssueHandler(IssueHandlerInterface):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
|
||||
"""Initialize a GitHub issue handler.
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
token: The GitHub personal access token
|
||||
username: Optional GitHub username
|
||||
base_domain: The domain for GitHub Enterprise (default: "github.com")
|
||||
"""
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.base_domain = base_domain
|
||||
self.base_url = self.get_base_url()
|
||||
self.download_url = self.get_download_url()
|
||||
self.clone_url = self.get_clone_url()
|
||||
@ -32,10 +42,13 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
}
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
|
||||
if self.base_domain == "github.com":
|
||||
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
|
||||
else:
|
||||
return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}'
|
||||
|
||||
def get_authorize_url(self) -> str:
|
||||
return f'https://{self.username}:{self.token}@github.com/'
|
||||
return f'https://{self.username}:{self.token}@{self.base_domain}/'
|
||||
|
||||
def get_branch_url(self, branch_name: str) -> str:
|
||||
return self.get_base_url() + f'/branches/{branch_name}'
|
||||
@ -49,13 +62,16 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
if self.username
|
||||
else f'x-auth-token:{self.token}'
|
||||
)
|
||||
return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git'
|
||||
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
|
||||
|
||||
def get_graphql_url(self) -> str:
|
||||
return 'https://api.github.com/graphql'
|
||||
if self.base_domain == "github.com":
|
||||
return 'https://api.github.com/graphql'
|
||||
else:
|
||||
return f'https://{self.base_domain}/api/v3/graphql'
|
||||
|
||||
def get_compare_url(self, branch_name: str) -> str:
|
||||
return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
|
||||
return f'https://{self.base_domain}/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
|
||||
|
||||
def get_converted_issues(
|
||||
self, issue_numbers: list[int] | None = None, comment_id: int | None = None
|
||||
@ -220,7 +236,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
response.raise_for_status()
|
||||
|
||||
def get_pull_url(self, pr_number: int) -> str:
|
||||
return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}'
|
||||
return f'https://{self.base_domain}/{self.owner}/{self.repo}/pull/{pr_number}'
|
||||
|
||||
def get_default_branch_name(self) -> str:
|
||||
response = httpx.get(f'{self.base_url}', headers=self.headers)
|
||||
@ -286,11 +302,21 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
|
||||
|
||||
class GithubPRHandler(GithubIssueHandler):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
|
||||
super().__init__(owner, repo, token, username)
|
||||
self.download_url = (
|
||||
f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
|
||||
)
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
|
||||
"""Initialize a GitHub PR handler.
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
token: The GitHub personal access token
|
||||
username: Optional GitHub username
|
||||
base_domain: The domain for GitHub Enterprise (default: "github.com")
|
||||
"""
|
||||
super().__init__(owner, repo, token, username, base_domain)
|
||||
if self.base_domain == "github.com":
|
||||
self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
|
||||
else:
|
||||
self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls'
|
||||
|
||||
def download_pr_metadata(
|
||||
self, pull_number: int, comment_id: int | None = None
|
||||
@ -356,7 +382,7 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
|
||||
variables = {'owner': self.owner, 'repo': self.repo, 'pr': pull_number}
|
||||
|
||||
url = 'https://api.github.com/graphql'
|
||||
url = self.get_graphql_url()
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.token}',
|
||||
'Content-Type': 'application/json',
|
||||
@ -444,7 +470,10 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
self, pr_number: int, comment_id: int | None = None
|
||||
) -> list[str] | None:
|
||||
"""Download comments for a specific pull request from Github."""
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
|
||||
if self.base_domain == "github.com":
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
|
||||
else:
|
||||
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
|
||||
headers = {
|
||||
'Authorization': f'token {self.token}',
|
||||
'Accept': 'application/vnd.github.v3+json',
|
||||
@ -513,7 +542,10 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
|
||||
for issue_number in unique_issue_references:
|
||||
try:
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
|
||||
if self.base_domain == "github.com":
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
|
||||
else:
|
||||
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}'
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.token}',
|
||||
'Accept': 'application/vnd.github.v3+json',
|
||||
|
||||
@ -13,11 +13,28 @@ from openhands.resolver.utils import extract_issue_references
|
||||
|
||||
|
||||
class GitlabIssueHandler(IssueHandlerInterface):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str | None = None,
|
||||
base_domain: str = 'gitlab.com',
|
||||
):
|
||||
"""Initialize a GitLab issue handler.
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
token: The GitLab personal access token
|
||||
username: Optional GitLab username
|
||||
base_domain: The domain for GitLab Enterprise (default: "gitlab.com")
|
||||
"""
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.base_domain = base_domain
|
||||
self.base_url = self.get_base_url()
|
||||
self.download_url = self.get_download_url()
|
||||
self.clone_url = self.get_clone_url()
|
||||
@ -34,10 +51,10 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
project_path = quote(f'{self.owner}/{self.repo}', safe='')
|
||||
return f'https://gitlab.com/api/v4/projects/{project_path}'
|
||||
return f'https://{self.base_domain}/api/v4/projects/{project_path}'
|
||||
|
||||
def get_authorize_url(self) -> str:
|
||||
return f'https://{self.username}:{self.token}@gitlab.com/'
|
||||
return f'https://{self.username}:{self.token}@{self.base_domain}/'
|
||||
|
||||
def get_branch_url(self, branch_name: str) -> str:
|
||||
return self.get_base_url() + f'/repository/branches/{branch_name}'
|
||||
@ -49,13 +66,13 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
username_and_token = self.token
|
||||
if self.username:
|
||||
username_and_token = f'{self.username}:{self.token}'
|
||||
return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git'
|
||||
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
|
||||
|
||||
def get_graphql_url(self) -> str:
|
||||
return 'https://gitlab.com/api/graphql'
|
||||
return f'https://{self.base_domain}/api/graphql'
|
||||
|
||||
def get_compare_url(self, branch_name: str) -> str:
|
||||
return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
|
||||
return f'https://{self.base_domain}/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
|
||||
|
||||
def get_converted_issues(
|
||||
self, issue_numbers: list[int] | None = None, comment_id: int | None = None
|
||||
@ -215,9 +232,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
response.raise_for_status()
|
||||
|
||||
def get_pull_url(self, pr_number: int) -> str:
|
||||
return (
|
||||
f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
|
||||
)
|
||||
return f'https://{self.base_domain}/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
|
||||
|
||||
def get_default_branch_name(self) -> str:
|
||||
response = httpx.get(f'{self.base_url}', headers=self.headers)
|
||||
@ -248,7 +263,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
|
||||
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||
response = httpx.get(
|
||||
f'https://gitlab.com/api/v4/users?username={reviewer}',
|
||||
f'https://{self.base_domain}/api/v4/users?username={reviewer}',
|
||||
headers=self.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -298,8 +313,24 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
|
||||
|
||||
class GitlabPRHandler(GitlabIssueHandler):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None):
|
||||
super().__init__(owner, repo, token, username)
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str | None = None,
|
||||
base_domain: str = 'gitlab.com',
|
||||
):
|
||||
"""Initialize a GitLab PR handler.
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
token: The GitLab personal access token
|
||||
username: Optional GitLab username
|
||||
base_domain: The domain for GitLab Enterprise (default: "gitlab.com")
|
||||
"""
|
||||
super().__init__(owner, repo, token, username, base_domain)
|
||||
self.download_url = f'{self.base_url}/merge_requests'
|
||||
|
||||
def download_pr_metadata(
|
||||
|
||||
@ -66,6 +66,7 @@ async def resolve_issues(
|
||||
issue_type: str,
|
||||
repo_instruction: str | None,
|
||||
issue_numbers: list[int] | None,
|
||||
base_domain: str = 'github.com',
|
||||
) -> None:
|
||||
"""Resolve multiple github or gitlab issues.
|
||||
|
||||
@ -86,7 +87,7 @@ async def resolve_issues(
|
||||
issue_numbers: List of issue numbers to resolve.
|
||||
"""
|
||||
issue_handler = issue_handler_factory(
|
||||
issue_type, owner, repo, token, llm_config, platform
|
||||
issue_type, owner, repo, token, llm_config, platform, username, base_domain
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
@ -323,6 +324,12 @@ def main() -> None:
|
||||
choices=['issue', 'pr'],
|
||||
help='Type of issue to resolve, either open issue or pr comments.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--base-domain',
|
||||
type=str,
|
||||
default='github.com',
|
||||
help='Base domain for GitHub Enterprise (default: github.com)',
|
||||
)
|
||||
|
||||
my_args = parser.parse_args()
|
||||
|
||||
@ -339,7 +346,7 @@ def main() -> None:
|
||||
if not token:
|
||||
raise ValueError('Token is required.')
|
||||
|
||||
platform = identify_token(token, my_args.selected_repo)
|
||||
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
|
||||
if platform == Platform.INVALID:
|
||||
raise ValueError('Token is invalid.')
|
||||
|
||||
@ -394,6 +401,7 @@ def main() -> None:
|
||||
issue_type=issue_type,
|
||||
repo_instruction=repo_instruction,
|
||||
issue_numbers=issue_numbers,
|
||||
base_domain=my_args.base_domain,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -322,24 +322,30 @@ def issue_handler_factory(
|
||||
llm_config: LLMConfig,
|
||||
platform: Platform,
|
||||
username: str | None = None,
|
||||
base_domain: str | None = None,
|
||||
) -> ServiceContextIssue | ServiceContextPR:
|
||||
# Determine default base_domain based on platform
|
||||
if base_domain is None:
|
||||
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
|
||||
if issue_type == 'issue':
|
||||
if platform == Platform.GITHUB:
|
||||
return ServiceContextIssue(
|
||||
GithubIssueHandler(owner, repo, token, username), llm_config
|
||||
GithubIssueHandler(owner, repo, token, username, base_domain),
|
||||
llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextIssue(
|
||||
GitlabIssueHandler(owner, repo, token, username), llm_config
|
||||
GitlabIssueHandler(owner, repo, token, username, base_domain),
|
||||
llm_config,
|
||||
)
|
||||
elif issue_type == 'pr':
|
||||
if platform == Platform.GITHUB:
|
||||
return ServiceContextPR(
|
||||
GithubPRHandler(owner, repo, token, username), llm_config
|
||||
GithubPRHandler(owner, repo, token, username, base_domain), llm_config
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextPR(
|
||||
GitlabPRHandler(owner, repo, token, username), llm_config
|
||||
GitlabPRHandler(owner, repo, token, username, base_domain), llm_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid issue type: {issue_type}')
|
||||
@ -361,6 +367,7 @@ async def resolve_issue(
|
||||
issue_number: int,
|
||||
comment_id: int | None,
|
||||
reset_logger: bool = False,
|
||||
base_domain: str | None = None,
|
||||
) -> None:
|
||||
"""Resolve a single issue.
|
||||
|
||||
@ -379,11 +386,15 @@ async def resolve_issue(
|
||||
repo_instruction: Repository instruction to use.
|
||||
issue_number: Issue number to resolve.
|
||||
comment_id: Optional ID of a specific comment to focus on.
|
||||
|
||||
reset_logger: Whether to reset the logger for multiprocessing.
|
||||
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
|
||||
"""
|
||||
# Determine default base_domain based on platform
|
||||
if base_domain is None:
|
||||
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
|
||||
|
||||
issue_handler = issue_handler_factory(
|
||||
issue_type, owner, repo, token, llm_config, platform, username
|
||||
issue_type, owner, repo, token, llm_config, platform, username, base_domain
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
@ -629,6 +640,12 @@ def main() -> None:
|
||||
type=lambda x: x.lower() == 'true',
|
||||
help='Whether to run in experimental mode.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--base-domain',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)',
|
||||
)
|
||||
|
||||
my_args = parser.parse_args()
|
||||
|
||||
@ -651,7 +668,7 @@ def main() -> None:
|
||||
if not token:
|
||||
raise ValueError('Token is required.')
|
||||
|
||||
platform = identify_token(token, my_args.selected_repo)
|
||||
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
|
||||
if platform == Platform.INVALID:
|
||||
raise ValueError('Token is invalid.')
|
||||
|
||||
@ -708,6 +725,7 @@ def main() -> None:
|
||||
repo_instruction=repo_instruction,
|
||||
issue_number=my_args.issue_number,
|
||||
comment_id=my_args.comment_id,
|
||||
base_domain=my_args.base_domain,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -235,6 +235,7 @@ def send_pull_request(
|
||||
target_branch: str | None = None,
|
||||
reviewer: str | None = None,
|
||||
pr_title: str | None = None,
|
||||
base_domain: str | None = None,
|
||||
) -> str:
|
||||
"""Send a pull request to a GitHub or Gitlab repository.
|
||||
|
||||
@ -250,18 +251,25 @@ def send_pull_request(
|
||||
target_branch: The target branch to create the pull request against (defaults to repository default branch)
|
||||
reviewer: The GitHub or Gitlab username of the reviewer to assign
|
||||
pr_title: Custom title for the pull request (optional)
|
||||
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
|
||||
"""
|
||||
if pr_type not in ['branch', 'draft', 'ready']:
|
||||
raise ValueError(f'Invalid pr_type: {pr_type}')
|
||||
|
||||
# Determine default base_domain based on platform
|
||||
if base_domain is None:
|
||||
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
|
||||
|
||||
handler = None
|
||||
if platform == Platform.GITHUB:
|
||||
handler = ServiceContextIssue(
|
||||
GithubIssueHandler(issue.owner, issue.repo, token, username), None
|
||||
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
|
||||
None,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
handler = ServiceContextIssue(
|
||||
GitlabIssueHandler(issue.owner, issue.repo, token, username), None
|
||||
GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain),
|
||||
None,
|
||||
)
|
||||
|
||||
# Create a new branch with a unique name
|
||||
@ -363,6 +371,7 @@ def update_existing_pull_request(
|
||||
llm_config: LLMConfig,
|
||||
comment_message: str | None = None,
|
||||
additional_message: str | None = None,
|
||||
base_domain: str | None = None,
|
||||
) -> str:
|
||||
"""Update an existing pull request with the new patches.
|
||||
|
||||
@ -375,17 +384,24 @@ def update_existing_pull_request(
|
||||
llm_config: The LLM configuration to use for summarizing changes.
|
||||
comment_message: The main message to post as a comment on the PR.
|
||||
additional_message: The additional messages to post as a comment on the PR in json list format.
|
||||
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
|
||||
"""
|
||||
# Set up headers and base URL for GitHub or GitLab API
|
||||
|
||||
# Determine default base_domain based on platform
|
||||
if base_domain is None:
|
||||
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
|
||||
|
||||
handler = None
|
||||
if platform == Platform.GITHUB:
|
||||
handler = ServiceContextIssue(
|
||||
GithubIssueHandler(issue.owner, issue.repo, token, username), llm_config
|
||||
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
|
||||
llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
handler = ServiceContextIssue(
|
||||
GitlabIssueHandler(issue.owner, issue.repo, token, username), llm_config
|
||||
GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain),
|
||||
llm_config,
|
||||
)
|
||||
|
||||
branch_name = issue.head_branch
|
||||
@ -468,7 +484,11 @@ def process_single_issue(
|
||||
target_branch: str | None = None,
|
||||
reviewer: str | None = None,
|
||||
pr_title: str | None = None,
|
||||
base_domain: str | None = None,
|
||||
) -> None:
|
||||
# Determine default base_domain based on platform
|
||||
if base_domain is None:
|
||||
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
|
||||
if not resolver_output.success and not send_on_failure:
|
||||
logger.info(
|
||||
f'Issue {resolver_output.issue.number} was not successfully resolved. Skipping PR creation.'
|
||||
@ -507,6 +527,7 @@ def process_single_issue(
|
||||
patch_dir=patched_repo_dir,
|
||||
additional_message=resolver_output.result_explanation,
|
||||
llm_config=llm_config,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
else:
|
||||
send_pull_request(
|
||||
@ -521,6 +542,7 @@ def process_single_issue(
|
||||
target_branch=target_branch,
|
||||
reviewer=reviewer,
|
||||
pr_title=pr_title,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
|
||||
@ -532,7 +554,11 @@ def process_all_successful_issues(
|
||||
pr_type: str,
|
||||
llm_config: LLMConfig,
|
||||
fork_owner: str | None,
|
||||
base_domain: str | None = None,
|
||||
) -> None:
|
||||
# Determine default base_domain based on platform
|
||||
if base_domain is None:
|
||||
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
|
||||
output_path = os.path.join(output_dir, 'output.jsonl')
|
||||
for resolver_output in load_all_resolver_outputs(output_path):
|
||||
if resolver_output.success:
|
||||
@ -548,6 +574,9 @@ def process_all_successful_issues(
|
||||
fork_owner,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
base_domain,
|
||||
)
|
||||
|
||||
|
||||
@ -633,6 +662,12 @@ def main() -> None:
|
||||
help='Custom title for the pull request',
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--base-domain',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)',
|
||||
)
|
||||
my_args = parser.parse_args()
|
||||
|
||||
token = my_args.token or os.getenv('GITHUB_TOKEN') or os.getenv('GITLAB_TOKEN')
|
||||
@ -642,7 +677,7 @@ def main() -> None:
|
||||
)
|
||||
username = my_args.username if my_args.username else os.getenv('GIT_USERNAME')
|
||||
|
||||
platform = identify_token(token)
|
||||
platform = identify_token(token, None, my_args.base_domain)
|
||||
if platform == Platform.INVALID:
|
||||
raise ValueError('Token is invalid.')
|
||||
|
||||
@ -667,6 +702,7 @@ def main() -> None:
|
||||
my_args.pr_type,
|
||||
llm_config,
|
||||
my_args.fork_owner,
|
||||
my_args.base_domain,
|
||||
)
|
||||
else:
|
||||
if not my_args.issue_number.isdigit():
|
||||
@ -689,6 +725,7 @@ def main() -> None:
|
||||
my_args.target_branch,
|
||||
my_args.reviewer,
|
||||
my_args.pr_title,
|
||||
my_args.base_domain,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -20,22 +20,31 @@ class Platform(Enum):
|
||||
GITLAB = 2
|
||||
|
||||
|
||||
def identify_token(token: str, selected_repo: str | None = None) -> Platform:
|
||||
def identify_token(
|
||||
token: str, selected_repo: str | None = None, base_domain: str = 'github.com'
|
||||
) -> Platform:
|
||||
"""
|
||||
Identifies whether a token belongs to GitHub or GitLab.
|
||||
|
||||
Parameters:
|
||||
token (str): The personal access token to check.
|
||||
selected_repo (str): Repository in format "owner/repo" for GitHub Actions token validation.
|
||||
base_domain (str): The base domain for GitHub Enterprise (default: "github.com").
|
||||
|
||||
Returns:
|
||||
Platform: "GitHub" if the token is valid for GitHub,
|
||||
"GitLab" if the token is valid for GitLab,
|
||||
"Invalid" if the token is not recognized by either.
|
||||
"""
|
||||
# Determine GitHub API base URL based on domain
|
||||
if base_domain == 'github.com':
|
||||
github_api_base = 'https://api.github.com'
|
||||
else:
|
||||
github_api_base = f'https://{base_domain}/api/v3'
|
||||
|
||||
# Try GitHub Actions token format (Bearer) with repo endpoint if repo is provided
|
||||
if selected_repo:
|
||||
github_repo_url = f'https://api.github.com/repos/{selected_repo}'
|
||||
github_repo_url = f'{github_api_base}/repos/{selected_repo}'
|
||||
github_bearer_headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Accept': 'application/vnd.github+json',
|
||||
@ -51,7 +60,7 @@ def identify_token(token: str, selected_repo: str | None = None) -> Platform:
|
||||
logger.error(f'Error connecting to GitHub API (selected_repo check): {e}')
|
||||
|
||||
# Try GitHub PAT format (token)
|
||||
github_url = 'https://api.github.com/user'
|
||||
github_url = f'{github_api_base}/user'
|
||||
github_headers = {'Authorization': f'token {token}'}
|
||||
|
||||
try:
|
||||
@ -61,7 +70,6 @@ def identify_token(token: str, selected_repo: str | None = None) -> Platform:
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f'Error connecting to GitHub API: {e}')
|
||||
|
||||
# Try GitLab token
|
||||
gitlab_url = 'https://gitlab.com/api/v4/user'
|
||||
gitlab_headers = {'Authorization': f'Bearer {token}'}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
from unittest.mock import ANY, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -884,6 +884,7 @@ def test_process_single_pr_update(
|
||||
patch_dir=f'{mock_output_dir}/patches/pr_1',
|
||||
additional_message='[Test success 1]',
|
||||
llm_config=mock_llm_config,
|
||||
base_domain='github.com',
|
||||
)
|
||||
|
||||
|
||||
@ -965,6 +966,7 @@ def test_process_single_issue(
|
||||
target_branch=None,
|
||||
reviewer=None,
|
||||
pr_title=None,
|
||||
base_domain='github.com',
|
||||
)
|
||||
|
||||
|
||||
@ -1126,6 +1128,9 @@ def test_process_all_successful_issues(
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'github.com',
|
||||
),
|
||||
call(
|
||||
'output_dir',
|
||||
@ -1138,6 +1143,9 @@ def test_process_all_successful_issues(
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'github.com',
|
||||
),
|
||||
]
|
||||
)
|
||||
@ -1260,7 +1268,7 @@ def test_main(
|
||||
# Run main function
|
||||
main()
|
||||
|
||||
mock_identify_token.assert_called_with('mock_token')
|
||||
mock_identify_token.assert_called_with('mock_token', None, ANY)
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=mock_args.llm_model,
|
||||
@ -1282,6 +1290,7 @@ def test_main(
|
||||
mock_args.target_branch,
|
||||
mock_args.reviewer,
|
||||
mock_args.pr_title,
|
||||
ANY,
|
||||
)
|
||||
|
||||
# Other assertions
|
||||
@ -1301,6 +1310,7 @@ def test_main(
|
||||
'draft',
|
||||
llm_config,
|
||||
None,
|
||||
ANY,
|
||||
)
|
||||
|
||||
# Test for invalid issue number
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
from unittest.mock import ANY, MagicMock, call, patch
|
||||
from urllib.parse import quote
|
||||
|
||||
import pytest
|
||||
@ -785,6 +785,7 @@ def test_process_single_pr_update(
|
||||
patch_dir=f'{mock_output_dir}/patches/pr_1',
|
||||
additional_message='[Test success 1]',
|
||||
llm_config=mock_llm_config,
|
||||
base_domain='gitlab.com',
|
||||
)
|
||||
|
||||
|
||||
@ -866,6 +867,7 @@ def test_process_single_issue(
|
||||
target_branch=None,
|
||||
reviewer=None,
|
||||
pr_title=None,
|
||||
base_domain='gitlab.com',
|
||||
)
|
||||
|
||||
|
||||
@ -1027,6 +1029,9 @@ def test_process_all_successful_issues(
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'gitlab.com',
|
||||
),
|
||||
call(
|
||||
'output_dir',
|
||||
@ -1039,6 +1044,9 @@ def test_process_all_successful_issues(
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'gitlab.com',
|
||||
),
|
||||
]
|
||||
)
|
||||
@ -1162,7 +1170,7 @@ def test_main(
|
||||
# Run main function
|
||||
main()
|
||||
|
||||
mock_identify_token.assert_called_with('mock_token')
|
||||
mock_identify_token.assert_called_with('mock_token', None, ANY)
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=mock_args.llm_model,
|
||||
@ -1184,6 +1192,7 @@ def test_main(
|
||||
mock_args.target_branch,
|
||||
mock_args.reviewer,
|
||||
mock_args.pr_title,
|
||||
ANY,
|
||||
)
|
||||
|
||||
# Other assertions
|
||||
@ -1203,6 +1212,7 @@ def test_main(
|
||||
'draft',
|
||||
llm_config,
|
||||
None,
|
||||
ANY,
|
||||
)
|
||||
|
||||
# Test for invalid issue number
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user