diff --git a/openhands/resolver/send_pull_request.py b/openhands/resolver/send_pull_request.py index 384b440b70..15214d4c0d 100644 --- a/openhands/resolver/send_pull_request.py +++ b/openhands/resolver/send_pull_request.py @@ -584,6 +584,12 @@ def main() -> None: parser = argparse.ArgumentParser( description='Send a pull request to Github or Gitlab.' ) + parser.add_argument( + '--selected-repo', + type=str, + default=None, + help='repository to send pull request in form of `owner/repo`.', + ) parser.add_argument( '--token', type=str, @@ -677,7 +683,7 @@ def main() -> None: ) username = my_args.username if my_args.username else os.getenv('GIT_USERNAME') - platform = identify_token(token, None, my_args.base_domain) + platform = identify_token(token, my_args.selected_repo, my_args.base_domain) if platform == Platform.INVALID: raise ValueError('Token is invalid.') diff --git a/openhands/resolver/utils.py b/openhands/resolver/utils.py index 4552e9e951..09ff505dd6 100644 --- a/openhands/resolver/utils.py +++ b/openhands/resolver/utils.py @@ -21,7 +21,7 @@ class Platform(Enum): def identify_token( - token: str, selected_repo: str | None = None, base_domain: str = 'github.com' + token: str, selected_repo: str | None = None, base_domain: str | None = 'github.com' ) -> Platform: """ Identifies whether a token belongs to GitHub or GitLab. @@ -37,7 +37,7 @@ def identify_token( "Invalid" if the token is not recognized by either. """ # Determine GitHub API base URL based on domain - if base_domain == 'github.com': + if base_domain is None or base_domain == 'github.com': github_api_base = 'https://api.github.com' else: github_api_base = f'https://{base_domain}/api/v3' diff --git a/tests/unit/resolver/github/test_send_pull_request.py b/tests/unit/resolver/github/test_send_pull_request.py index 2540fd47eb..84010b47b8 100644 --- a/tests/unit/resolver/github/test_send_pull_request.py +++ b/tests/unit/resolver/github/test_send_pull_request.py @@ -1249,6 +1249,7 @@ def test_main( mock_args.target_branch = None mock_args.reviewer = None mock_args.pr_title = None + mock_args.selected_repo = None mock_parser.return_value.parse_args.return_value = mock_args # Setup environment variables diff --git a/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py b/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py index 2f88984be2..08d6d2bedb 100644 --- a/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py +++ b/tests/unit/resolver/gitlab/test_gitlab_send_pull_request.py @@ -1151,6 +1151,7 @@ def test_main( mock_args.target_branch = None mock_args.reviewer = None mock_args.pr_title = None + mock_args.selected_repo = None mock_parser.return_value.parse_args.return_value = mock_args # Setup environment variables