diff --git a/openhands/resolver/interfaces/github.py b/openhands/resolver/interfaces/github.py index bd34bf2916..85746b0cda 100644 --- a/openhands/resolver/interfaces/github.py +++ b/openhands/resolver/interfaces/github.py @@ -12,11 +12,21 @@ from openhands.resolver.utils import extract_issue_references class GithubIssueHandler(IssueHandlerInterface): - def __init__(self, owner: str, repo: str, token: str, username: str | None = None): + def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"): + """Initialize a GitHub issue handler. + + Args: + owner: The owner of the repository + repo: The name of the repository + token: The GitHub personal access token + username: Optional GitHub username + base_domain: The domain for GitHub Enterprise (default: "github.com") + """ self.owner = owner self.repo = repo self.token = token self.username = username + self.base_domain = base_domain self.base_url = self.get_base_url() self.download_url = self.get_download_url() self.clone_url = self.get_clone_url() @@ -32,10 +42,13 @@ class GithubIssueHandler(IssueHandlerInterface): } def get_base_url(self) -> str: - return f'https://api.github.com/repos/{self.owner}/{self.repo}' + if self.base_domain == "github.com": + return f'https://api.github.com/repos/{self.owner}/{self.repo}' + else: + return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}' def get_authorize_url(self) -> str: - return f'https://{self.username}:{self.token}@github.com/' + return f'https://{self.username}:{self.token}@{self.base_domain}/' def get_branch_url(self, branch_name: str) -> str: return self.get_base_url() + f'/branches/{branch_name}' @@ -49,13 +62,16 @@ class GithubIssueHandler(IssueHandlerInterface): if self.username else f'x-auth-token:{self.token}' ) - return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git' + return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git' def get_graphql_url(self) -> str: - return 'https://api.github.com/graphql' + if self.base_domain == "github.com": + return 'https://api.github.com/graphql' + else: + return f'https://{self.base_domain}/api/v3/graphql' def get_compare_url(self, branch_name: str) -> str: - return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1' + return f'https://{self.base_domain}/{self.owner}/{self.repo}/compare/{branch_name}?expand=1' def get_converted_issues( self, issue_numbers: list[int] | None = None, comment_id: int | None = None @@ -220,7 +236,7 @@ class GithubIssueHandler(IssueHandlerInterface): response.raise_for_status() def get_pull_url(self, pr_number: int) -> str: - return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}' + return f'https://{self.base_domain}/{self.owner}/{self.repo}/pull/{pr_number}' def get_default_branch_name(self) -> str: response = httpx.get(f'{self.base_url}', headers=self.headers) @@ -286,11 +302,21 @@ class GithubIssueHandler(IssueHandlerInterface): class GithubPRHandler(GithubIssueHandler): - def __init__(self, owner: str, repo: str, token: str, username: str | None = None): - super().__init__(owner, repo, token, username) - self.download_url = ( - f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls' - ) + def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"): + """Initialize a GitHub PR handler. + + Args: + owner: The owner of the repository + repo: The name of the repository + token: The GitHub personal access token + username: Optional GitHub username + base_domain: The domain for GitHub Enterprise (default: "github.com") + """ + super().__init__(owner, repo, token, username, base_domain) + if self.base_domain == "github.com": + self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls' + else: + self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls' def download_pr_metadata( self, pull_number: int, comment_id: int | None = None @@ -356,7 +382,7 @@ class GithubPRHandler(GithubIssueHandler): variables = {'owner': self.owner, 'repo': self.repo, 'pr': pull_number} - url = 'https://api.github.com/graphql' + url = self.get_graphql_url() headers = { 'Authorization': f'Bearer {self.token}', 'Content-Type': 'application/json', @@ -444,7 +470,10 @@ class GithubPRHandler(GithubIssueHandler): self, pr_number: int, comment_id: int | None = None ) -> list[str] | None: """Download comments for a specific pull request from Github.""" - url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments' + if self.base_domain == "github.com": + url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments' + else: + url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments' headers = { 'Authorization': f'token {self.token}', 'Accept': 'application/vnd.github.v3+json', @@ -513,7 +542,10 @@ class GithubPRHandler(GithubIssueHandler): for issue_number in unique_issue_references: try: - url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}' + if self.base_domain == "github.com": + url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}' + else: + url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}' headers = { 'Authorization': f'Bearer {self.token}', 'Accept': 'application/vnd.github.v3+json', diff --git a/openhands/resolver/interfaces/gitlab.py b/openhands/resolver/interfaces/gitlab.py index 2545388758..22ed3c3e06 100644 --- a/openhands/resolver/interfaces/gitlab.py +++ b/openhands/resolver/interfaces/gitlab.py @@ -13,11 +13,28 @@ from openhands.resolver.utils import extract_issue_references class GitlabIssueHandler(IssueHandlerInterface): - def __init__(self, owner: str, repo: str, token: str, username: str | None = None): + def __init__( + self, + owner: str, + repo: str, + token: str, + username: str | None = None, + base_domain: str = 'gitlab.com', + ): + """Initialize a GitLab issue handler. + + Args: + owner: The owner of the repository + repo: The name of the repository + token: The GitLab personal access token + username: Optional GitLab username + base_domain: The domain for GitLab Enterprise (default: "gitlab.com") + """ self.owner = owner self.repo = repo self.token = token self.username = username + self.base_domain = base_domain self.base_url = self.get_base_url() self.download_url = self.get_download_url() self.clone_url = self.get_clone_url() @@ -34,10 +51,10 @@ class GitlabIssueHandler(IssueHandlerInterface): def get_base_url(self) -> str: project_path = quote(f'{self.owner}/{self.repo}', safe='') - return f'https://gitlab.com/api/v4/projects/{project_path}' + return f'https://{self.base_domain}/api/v4/projects/{project_path}' def get_authorize_url(self) -> str: - return f'https://{self.username}:{self.token}@gitlab.com/' + return f'https://{self.username}:{self.token}@{self.base_domain}/' def get_branch_url(self, branch_name: str) -> str: return self.get_base_url() + f'/repository/branches/{branch_name}' @@ -49,13 +66,13 @@ class GitlabIssueHandler(IssueHandlerInterface): username_and_token = self.token if self.username: username_and_token = f'{self.username}:{self.token}' - return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git' + return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git' def get_graphql_url(self) -> str: - return 'https://gitlab.com/api/graphql' + return f'https://{self.base_domain}/api/graphql' def get_compare_url(self, branch_name: str) -> str: - return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}' + return f'https://{self.base_domain}/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}' def get_converted_issues( self, issue_numbers: list[int] | None = None, comment_id: int | None = None @@ -215,9 +232,7 @@ class GitlabIssueHandler(IssueHandlerInterface): response.raise_for_status() def get_pull_url(self, pr_number: int) -> str: - return ( - f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}' - ) + return f'https://{self.base_domain}/{self.owner}/{self.repo}/-/merge_requests/{pr_number}' def get_default_branch_name(self) -> str: response = httpx.get(f'{self.base_url}', headers=self.headers) @@ -248,7 +263,7 @@ class GitlabIssueHandler(IssueHandlerInterface): def request_reviewers(self, reviewer: str, pr_number: int) -> None: response = httpx.get( - f'https://gitlab.com/api/v4/users?username={reviewer}', + f'https://{self.base_domain}/api/v4/users?username={reviewer}', headers=self.headers, ) response.raise_for_status() @@ -298,8 +313,24 @@ class GitlabIssueHandler(IssueHandlerInterface): class GitlabPRHandler(GitlabIssueHandler): - def __init__(self, owner: str, repo: str, token: str, username: str | None = None): - super().__init__(owner, repo, token, username) + def __init__( + self, + owner: str, + repo: str, + token: str, + username: str | None = None, + base_domain: str = 'gitlab.com', + ): + """Initialize a GitLab PR handler. + + Args: + owner: The owner of the repository + repo: The name of the repository + token: The GitLab personal access token + username: Optional GitLab username + base_domain: The domain for GitLab Enterprise (default: "gitlab.com") + """ + super().__init__(owner, repo, token, username, base_domain) self.download_url = f'{self.base_url}/merge_requests' def download_pr_metadata( diff --git a/openhands/resolver/resolve_all_issues.py b/openhands/resolver/resolve_all_issues.py index 1a5754f9a3..e4b295dbaf 100644 --- a/openhands/resolver/resolve_all_issues.py +++ b/openhands/resolver/resolve_all_issues.py @@ -66,6 +66,7 @@ async def resolve_issues( issue_type: str, repo_instruction: str | None, issue_numbers: list[int] | None, + base_domain: str = 'github.com', ) -> None: """Resolve multiple github or gitlab issues. @@ -86,7 +87,7 @@ async def resolve_issues( issue_numbers: List of issue numbers to resolve. """ issue_handler = issue_handler_factory( - issue_type, owner, repo, token, llm_config, platform + issue_type, owner, repo, token, llm_config, platform, username, base_domain ) # Load dataset @@ -323,6 +324,12 @@ def main() -> None: choices=['issue', 'pr'], help='Type of issue to resolve, either open issue or pr comments.', ) + parser.add_argument( + '--base-domain', + type=str, + default='github.com', + help='Base domain for GitHub Enterprise (default: github.com)', + ) my_args = parser.parse_args() @@ -339,7 +346,7 @@ def main() -> None: if not token: raise ValueError('Token is required.') - platform = identify_token(token, my_args.selected_repo) + platform = identify_token(token, my_args.selected_repo, my_args.base_domain) if platform == Platform.INVALID: raise ValueError('Token is invalid.') @@ -394,6 +401,7 @@ def main() -> None: issue_type=issue_type, repo_instruction=repo_instruction, issue_numbers=issue_numbers, + base_domain=my_args.base_domain, ) ) diff --git a/openhands/resolver/resolve_issue.py b/openhands/resolver/resolve_issue.py index 2aafd7f8c5..4a6691cc93 100644 --- a/openhands/resolver/resolve_issue.py +++ b/openhands/resolver/resolve_issue.py @@ -322,24 +322,30 @@ def issue_handler_factory( llm_config: LLMConfig, platform: Platform, username: str | None = None, + base_domain: str | None = None, ) -> ServiceContextIssue | ServiceContextPR: + # Determine default base_domain based on platform + if base_domain is None: + base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com' if issue_type == 'issue': if platform == Platform.GITHUB: return ServiceContextIssue( - GithubIssueHandler(owner, repo, token, username), llm_config + GithubIssueHandler(owner, repo, token, username, base_domain), + llm_config, ) else: # platform == Platform.GITLAB return ServiceContextIssue( - GitlabIssueHandler(owner, repo, token, username), llm_config + GitlabIssueHandler(owner, repo, token, username, base_domain), + llm_config, ) elif issue_type == 'pr': if platform == Platform.GITHUB: return ServiceContextPR( - GithubPRHandler(owner, repo, token, username), llm_config + GithubPRHandler(owner, repo, token, username, base_domain), llm_config ) else: # platform == Platform.GITLAB return ServiceContextPR( - GitlabPRHandler(owner, repo, token, username), llm_config + GitlabPRHandler(owner, repo, token, username, base_domain), llm_config ) else: raise ValueError(f'Invalid issue type: {issue_type}') @@ -361,6 +367,7 @@ async def resolve_issue( issue_number: int, comment_id: int | None, reset_logger: bool = False, + base_domain: str | None = None, ) -> None: """Resolve a single issue. @@ -379,11 +386,15 @@ async def resolve_issue( repo_instruction: Repository instruction to use. issue_number: Issue number to resolve. comment_id: Optional ID of a specific comment to focus on. - reset_logger: Whether to reset the logger for multiprocessing. + base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab) """ + # Determine default base_domain based on platform + if base_domain is None: + base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com' + issue_handler = issue_handler_factory( - issue_type, owner, repo, token, llm_config, platform, username + issue_type, owner, repo, token, llm_config, platform, username, base_domain ) # Load dataset @@ -629,6 +640,12 @@ def main() -> None: type=lambda x: x.lower() == 'true', help='Whether to run in experimental mode.', ) + parser.add_argument( + '--base-domain', + type=str, + default=None, + help='Base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)', + ) my_args = parser.parse_args() @@ -651,7 +668,7 @@ def main() -> None: if not token: raise ValueError('Token is required.') - platform = identify_token(token, my_args.selected_repo) + platform = identify_token(token, my_args.selected_repo, my_args.base_domain) if platform == Platform.INVALID: raise ValueError('Token is invalid.') @@ -708,6 +725,7 @@ def main() -> None: repo_instruction=repo_instruction, issue_number=my_args.issue_number, comment_id=my_args.comment_id, + base_domain=my_args.base_domain, ) ) diff --git a/openhands/resolver/send_pull_request.py b/openhands/resolver/send_pull_request.py index e9513d256f..384b440b70 100644 --- a/openhands/resolver/send_pull_request.py +++ b/openhands/resolver/send_pull_request.py @@ -235,6 +235,7 @@ def send_pull_request( target_branch: str | None = None, reviewer: str | None = None, pr_title: str | None = None, + base_domain: str | None = None, ) -> str: """Send a pull request to a GitHub or Gitlab repository. @@ -250,18 +251,25 @@ def send_pull_request( target_branch: The target branch to create the pull request against (defaults to repository default branch) reviewer: The GitHub or Gitlab username of the reviewer to assign pr_title: Custom title for the pull request (optional) + base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab) """ if pr_type not in ['branch', 'draft', 'ready']: raise ValueError(f'Invalid pr_type: {pr_type}') + # Determine default base_domain based on platform + if base_domain is None: + base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com' + handler = None if platform == Platform.GITHUB: handler = ServiceContextIssue( - GithubIssueHandler(issue.owner, issue.repo, token, username), None + GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain), + None, ) else: # platform == Platform.GITLAB handler = ServiceContextIssue( - GitlabIssueHandler(issue.owner, issue.repo, token, username), None + GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain), + None, ) # Create a new branch with a unique name @@ -363,6 +371,7 @@ def update_existing_pull_request( llm_config: LLMConfig, comment_message: str | None = None, additional_message: str | None = None, + base_domain: str | None = None, ) -> str: """Update an existing pull request with the new patches. @@ -375,17 +384,24 @@ def update_existing_pull_request( llm_config: The LLM configuration to use for summarizing changes. comment_message: The main message to post as a comment on the PR. additional_message: The additional messages to post as a comment on the PR in json list format. + base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab) """ # Set up headers and base URL for GitHub or GitLab API + # Determine default base_domain based on platform + if base_domain is None: + base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com' + handler = None if platform == Platform.GITHUB: handler = ServiceContextIssue( - GithubIssueHandler(issue.owner, issue.repo, token, username), llm_config + GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain), + llm_config, ) else: # platform == Platform.GITLAB handler = ServiceContextIssue( - GitlabIssueHandler(issue.owner, issue.repo, token, username), llm_config + GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain), + llm_config, ) branch_name = issue.head_branch @@ -468,7 +484,11 @@ def process_single_issue( target_branch: str | None = None, reviewer: str | None = None, pr_title: str | None = None, + base_domain: str | None = None, ) -> None: + # Determine default base_domain based on platform + if base_domain is None: + base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com' if not resolver_output.success and not send_on_failure: logger.info( f'Issue {resolver_output.issue.number} was not successfully resolved. Skipping PR creation.' @@ -507,6 +527,7 @@ def process_single_issue( patch_dir=patched_repo_dir, additional_message=resolver_output.result_explanation, llm_config=llm_config, + base_domain=base_domain, ) else: send_pull_request( @@ -521,6 +542,7 @@ def process_single_issue( target_branch=target_branch, reviewer=reviewer, pr_title=pr_title, + base_domain=base_domain, ) @@ -532,7 +554,11 @@ def process_all_successful_issues( pr_type: str, llm_config: LLMConfig, fork_owner: str | None, + base_domain: str | None = None, ) -> None: + # Determine default base_domain based on platform + if base_domain is None: + base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com' output_path = os.path.join(output_dir, 'output.jsonl') for resolver_output in load_all_resolver_outputs(output_path): if resolver_output.success: @@ -548,6 +574,9 @@ def process_all_successful_issues( fork_owner, False, None, + None, + None, + base_domain, ) @@ -633,6 +662,12 @@ def main() -> None: help='Custom title for the pull request', default=None, ) + parser.add_argument( + '--base-domain', + type=str, + default=None, + help='Base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)', + ) my_args = parser.parse_args() token = my_args.token or os.getenv('GITHUB_TOKEN') or os.getenv('GITLAB_TOKEN') @@ -642,7 +677,7 @@ def main() -> None: ) username = my_args.username if my_args.username else os.getenv('GIT_USERNAME') - platform = identify_token(token) + platform = identify_token(token, None, my_args.base_domain) if platform == Platform.INVALID: raise ValueError('Token is invalid.') @@ -667,6 +702,7 @@ def main() -> None: my_args.pr_type, llm_config, my_args.fork_owner, + my_args.base_domain, ) else: if not my_args.issue_number.isdigit(): @@ -689,6 +725,7 @@ def main() -> None: my_args.target_branch, my_args.reviewer, my_args.pr_title, + my_args.base_domain, ) diff --git a/openhands/resolver/utils.py b/openhands/resolver/utils.py index f9e8276a5a..4552e9e951 100644 --- a/openhands/resolver/utils.py +++ b/openhands/resolver/utils.py @@ -20,22 +20,31 @@ class Platform(Enum): GITLAB = 2 -def identify_token(token: str, selected_repo: str | None = None) -> Platform: +def identify_token( + token: str, selected_repo: str | None = None, base_domain: str = 'github.com' +) -> Platform: """ Identifies whether a token belongs to GitHub or GitLab. Parameters: token (str): The personal access token to check. selected_repo (str): Repository in format "owner/repo" for GitHub Actions token validation. + base_domain (str): The base domain for GitHub Enterprise (default: "github.com"). Returns: Platform: "GitHub" if the token is valid for GitHub, "GitLab" if the token is valid for GitLab, "Invalid" if the token is not recognized by either. """ + # Determine GitHub API base URL based on domain + if base_domain == 'github.com': + github_api_base = 'https://api.github.com' + else: + github_api_base = f'https://{base_domain}/api/v3' + # Try GitHub Actions token format (Bearer) with repo endpoint if repo is provided if selected_repo: - github_repo_url = f'https://api.github.com/repos/{selected_repo}' + github_repo_url = f'{github_api_base}/repos/{selected_repo}' github_bearer_headers = { 'Authorization': f'Bearer {token}', 'Accept': 'application/vnd.github+json', @@ -51,7 +60,7 @@ def identify_token(token: str, selected_repo: str | None = None) -> Platform: logger.error(f'Error connecting to GitHub API (selected_repo check): {e}') # Try GitHub PAT format (token) - github_url = 'https://api.github.com/user' + github_url = f'{github_api_base}/user' github_headers = {'Authorization': f'token {token}'} try: @@ -61,7 +70,6 @@ def identify_token(token: str, selected_repo: str | None = None) -> Platform: except httpx.HTTPError as e: logger.error(f'Error connecting to GitHub API: {e}') - # Try GitLab token gitlab_url = 'https://gitlab.com/api/v4/user' gitlab_headers = {'Authorization': f'Bearer {token}'} diff --git a/tests/unit/resolver/github/test_send_pull_request.py b/tests/unit/resolver/github/test_send_pull_request.py index 329002f44a..2540fd47eb 100644 --- a/tests/unit/resolver/github/test_send_pull_request.py +++ b/tests/unit/resolver/github/test_send_pull_request.py @@ -1,6 +1,6 @@ import os import tempfile -from unittest.mock import MagicMock, call, patch +from unittest.mock import ANY, MagicMock, call, patch import pytest @@ -884,6 +884,7 @@ def test_process_single_pr_update( patch_dir=f'{mock_output_dir}/patches/pr_1', additional_message='[Test success 1]', llm_config=mock_llm_config, + base_domain='github.com', ) @@ -965,6 +966,7 @@ def test_process_single_issue( target_branch=None, reviewer=None, pr_title=None, + base_domain='github.com', ) @@ -1126,6 +1128,9 @@ def test_process_all_successful_issues( None, False, None, + None, + None, + 'github.com', ), call( 'output_dir', @@ -1138,6 +1143,9 @@ def test_process_all_successful_issues( None, False, None, + None, + None, + 'github.com', ), ] ) @@ -1260,7 +1268,7 @@ def test_main( # Run main function main() - mock_identify_token.assert_called_with('mock_token') + mock_identify_token.assert_called_with('mock_token', None, ANY) llm_config = LLMConfig( model=mock_args.llm_model, @@ -1282,6 +1290,7 @@ def test_main( mock_args.target_branch, mock_args.reviewer, mock_args.pr_title, + ANY, ) # Other assertions @@ -1301,6 +1310,7 @@ def test_main( 'draft', llm_config, None, + ANY, ) # Test for invalid issue number diff --git a/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py b/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py index c49e6ab608..2f88984be2 100644 --- a/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py +++ b/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py @@ -1,6 +1,6 @@ import os import tempfile -from unittest.mock import MagicMock, call, patch +from unittest.mock import ANY, MagicMock, call, patch from urllib.parse import quote import pytest @@ -785,6 +785,7 @@ def test_process_single_pr_update( patch_dir=f'{mock_output_dir}/patches/pr_1', additional_message='[Test success 1]', llm_config=mock_llm_config, + base_domain='gitlab.com', ) @@ -866,6 +867,7 @@ def test_process_single_issue( target_branch=None, reviewer=None, pr_title=None, + base_domain='gitlab.com', ) @@ -1027,6 +1029,9 @@ def test_process_all_successful_issues( None, False, None, + None, + None, + 'gitlab.com', ), call( 'output_dir', @@ -1039,6 +1044,9 @@ def test_process_all_successful_issues( None, False, None, + None, + None, + 'gitlab.com', ), ] ) @@ -1162,7 +1170,7 @@ def test_main( # Run main function main() - mock_identify_token.assert_called_with('mock_token') + mock_identify_token.assert_called_with('mock_token', None, ANY) llm_config = LLMConfig( model=mock_args.llm_model, @@ -1184,6 +1192,7 @@ def test_main( mock_args.target_branch, mock_args.reviewer, mock_args.pr_title, + ANY, ) # Other assertions @@ -1203,6 +1212,7 @@ def test_main( 'draft', llm_config, None, + ANY, ) # Test for invalid issue number