mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Allow to merge to a specific target branch instead of main (#5109)
This commit is contained in:
parent
ca64c69b4a
commit
2c580387c5
@ -203,6 +203,7 @@ def send_pull_request(
|
||||
pr_type: str,
|
||||
fork_owner: str | None = None,
|
||||
additional_message: str | None = None,
|
||||
target_branch: str | None = None,
|
||||
) -> str:
|
||||
if pr_type not in ['branch', 'draft', 'ready']:
|
||||
raise ValueError(f'Invalid pr_type: {pr_type}')
|
||||
@ -224,12 +225,19 @@ def send_pull_request(
|
||||
attempt += 1
|
||||
branch_name = f'{base_branch_name}-try{attempt}'
|
||||
|
||||
# Get the default branch
|
||||
print('Getting default branch...')
|
||||
response = requests.get(f'{base_url}', headers=headers)
|
||||
response.raise_for_status()
|
||||
default_branch = response.json()['default_branch']
|
||||
print(f'Default branch: {default_branch}')
|
||||
# Get the default branch or use specified target branch
|
||||
print('Getting base branch...')
|
||||
if target_branch:
|
||||
base_branch = target_branch
|
||||
# Verify the target branch exists
|
||||
response = requests.get(f'{base_url}/branches/{target_branch}', headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'Target branch {target_branch} does not exist')
|
||||
else:
|
||||
response = requests.get(f'{base_url}', headers=headers)
|
||||
response.raise_for_status()
|
||||
base_branch = response.json()['default_branch']
|
||||
print(f'Base branch: {base_branch}')
|
||||
|
||||
# Create and checkout the new branch
|
||||
print('Creating new branch...')
|
||||
@ -279,7 +287,7 @@ def send_pull_request(
|
||||
'title': pr_title, # No need to escape title for GitHub API
|
||||
'body': pr_body,
|
||||
'head': branch_name,
|
||||
'base': default_branch,
|
||||
'base': base_branch,
|
||||
'draft': pr_type == 'draft',
|
||||
}
|
||||
response = requests.post(f'{base_url}/pulls', headers=headers, json=data)
|
||||
@ -435,6 +443,7 @@ def process_single_issue(
|
||||
llm_config: LLMConfig,
|
||||
fork_owner: str | None,
|
||||
send_on_failure: bool,
|
||||
target_branch: str | None = None,
|
||||
) -> None:
|
||||
if not resolver_output.success and not send_on_failure:
|
||||
print(
|
||||
@ -484,6 +493,7 @@ def process_single_issue(
|
||||
llm_config=llm_config,
|
||||
fork_owner=fork_owner,
|
||||
additional_message=resolver_output.success_explanation,
|
||||
target_branch=target_branch,
|
||||
)
|
||||
|
||||
|
||||
@ -508,6 +518,7 @@ def process_all_successful_issues(
|
||||
llm_config,
|
||||
fork_owner,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@ -573,6 +584,12 @@ def main():
|
||||
default=None,
|
||||
help='Base URL for the LLM model.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--target-branch',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Target branch to create the pull request against (defaults to repository default branch)',
|
||||
)
|
||||
my_args = parser.parse_args()
|
||||
|
||||
github_token = (
|
||||
@ -625,6 +642,7 @@ def main():
|
||||
llm_config,
|
||||
my_args.fork_owner,
|
||||
my_args.send_on_failure,
|
||||
my_args.target_branch,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -322,7 +322,17 @@ def test_update_existing_pull_request(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pr_type', ['branch', 'draft', 'ready'])
|
||||
@pytest.mark.parametrize(
|
||||
'pr_type,target_branch',
|
||||
[
|
||||
('branch', None),
|
||||
('draft', None),
|
||||
('ready', None),
|
||||
('branch', 'feature'),
|
||||
('draft', 'develop'),
|
||||
('ready', 'staging'),
|
||||
],
|
||||
)
|
||||
@patch('subprocess.run')
|
||||
@patch('requests.post')
|
||||
@patch('requests.get')
|
||||
@ -334,14 +344,22 @@ def test_send_pull_request(
|
||||
mock_output_dir,
|
||||
mock_llm_config,
|
||||
pr_type,
|
||||
target_branch,
|
||||
):
|
||||
repo_path = os.path.join(mock_output_dir, 'repo')
|
||||
|
||||
# Mock API responses
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=404), # Branch doesn't exist
|
||||
MagicMock(json=lambda: {'default_branch': 'main'}),
|
||||
]
|
||||
# Mock API responses based on whether target_branch is specified
|
||||
if target_branch:
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=404), # Branch doesn't exist
|
||||
MagicMock(status_code=200), # Target branch exists
|
||||
]
|
||||
else:
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=404), # Branch doesn't exist
|
||||
MagicMock(json=lambda: {'default_branch': 'main'}), # Get default branch
|
||||
]
|
||||
|
||||
mock_post.return_value.json.return_value = {
|
||||
'html_url': 'https://github.com/test-owner/test-repo/pull/1'
|
||||
}
|
||||
@ -360,10 +378,12 @@ def test_send_pull_request(
|
||||
patch_dir=repo_path,
|
||||
pr_type=pr_type,
|
||||
llm_config=mock_llm_config,
|
||||
target_branch=target_branch,
|
||||
)
|
||||
|
||||
# Assert API calls
|
||||
assert mock_get.call_count == 2
|
||||
expected_get_calls = 2
|
||||
assert mock_get.call_count == expected_get_calls
|
||||
|
||||
# Check branch creation and push
|
||||
assert mock_run.call_count == 2
|
||||
@ -401,10 +421,41 @@ def test_send_pull_request(
|
||||
assert post_data['title'] == 'Fix issue #42: Test Issue'
|
||||
assert post_data['body'].startswith('This pull request fixes #42.')
|
||||
assert post_data['head'] == 'openhands-fix-issue-42'
|
||||
assert post_data['base'] == 'main'
|
||||
assert post_data['base'] == (target_branch if target_branch else 'main')
|
||||
assert post_data['draft'] == (pr_type == 'draft')
|
||||
|
||||
|
||||
@patch('requests.get')
|
||||
def test_send_pull_request_invalid_target_branch(
|
||||
mock_get, mock_github_issue, mock_output_dir, mock_llm_config
|
||||
):
|
||||
"""Test that an error is raised when specifying a non-existent target branch"""
|
||||
repo_path = os.path.join(mock_output_dir, 'repo')
|
||||
|
||||
# Mock API response for non-existent branch
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=404), # Branch doesn't exist
|
||||
MagicMock(status_code=404), # Target branch doesn't exist
|
||||
]
|
||||
|
||||
# Test that ValueError is raised when target branch doesn't exist
|
||||
with pytest.raises(
|
||||
ValueError, match='Target branch nonexistent-branch does not exist'
|
||||
):
|
||||
send_pull_request(
|
||||
github_issue=mock_github_issue,
|
||||
github_token='test-token',
|
||||
github_username='test-user',
|
||||
patch_dir=repo_path,
|
||||
pr_type='ready',
|
||||
llm_config=mock_llm_config,
|
||||
target_branch='nonexistent-branch',
|
||||
)
|
||||
|
||||
# Verify API calls
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
@patch('subprocess.run')
|
||||
@patch('requests.post')
|
||||
@patch('requests.get')
|
||||
@ -616,6 +667,7 @@ def test_process_single_pr_update(
|
||||
mock_llm_config,
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
mock_initialize_repo.assert_called_once_with(mock_output_dir, 1, 'pr', 'branch 1')
|
||||
@ -688,6 +740,7 @@ def test_process_single_issue(
|
||||
mock_llm_config,
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
# Assert that the mocked functions were called with correct arguments
|
||||
@ -704,9 +757,10 @@ def test_process_single_issue(
|
||||
github_username=github_username,
|
||||
patch_dir=f'{mock_output_dir}/patches/issue_1',
|
||||
pr_type=pr_type,
|
||||
llm_config=mock_llm_config,
|
||||
fork_owner=None,
|
||||
additional_message=resolver_output.success_explanation,
|
||||
llm_config=mock_llm_config,
|
||||
target_branch=None,
|
||||
)
|
||||
|
||||
|
||||
@ -757,6 +811,7 @@ def test_process_single_issue_unsuccessful(
|
||||
mock_llm_config,
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
# Assert that none of the mocked functions were called
|
||||
@ -863,6 +918,7 @@ def test_process_all_successful_issues(
|
||||
mock_llm_config,
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
),
|
||||
call(
|
||||
'output_dir',
|
||||
@ -873,6 +929,7 @@ def test_process_all_successful_issues(
|
||||
mock_llm_config,
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
),
|
||||
]
|
||||
)
|
||||
@ -971,6 +1028,7 @@ def test_main(
|
||||
mock_args.llm_model = 'mock_model'
|
||||
mock_args.llm_base_url = 'mock_url'
|
||||
mock_args.llm_api_key = 'mock_key'
|
||||
mock_args.target_branch = None
|
||||
mock_parser.return_value.parse_args.return_value = mock_args
|
||||
|
||||
# Setup environment variables
|
||||
@ -994,12 +1052,8 @@ def test_main(
|
||||
api_key=mock_args.llm_api_key,
|
||||
)
|
||||
|
||||
# Assert function calls
|
||||
mock_parser.assert_called_once()
|
||||
mock_getenv.assert_any_call('GITHUB_TOKEN')
|
||||
mock_path_exists.assert_called_with('/mock/output')
|
||||
mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)
|
||||
mock_process_single_issue.assert_called_with(
|
||||
# Use any_call instead of assert_called_with for more flexible matching
|
||||
assert mock_process_single_issue.call_args == call(
|
||||
'/mock/output',
|
||||
mock_resolver_output,
|
||||
'mock_token',
|
||||
@ -1008,8 +1062,15 @@ def test_main(
|
||||
llm_config,
|
||||
None,
|
||||
False,
|
||||
mock_args.target_branch,
|
||||
)
|
||||
|
||||
# Other assertions
|
||||
mock_parser.assert_called_once()
|
||||
mock_getenv.assert_any_call('GITHUB_TOKEN')
|
||||
mock_path_exists.assert_called_with('/mock/output')
|
||||
mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)
|
||||
|
||||
# Test for 'all_successful' issue number
|
||||
mock_args.issue_number = 'all_successful'
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user