[Fix]: Replace duplicate enums for providers in resolver (#7954)

This commit is contained in:
Rohit Malhotra
2025-04-20 14:06:18 -04:00
committed by GitHub
parent 20bf48b693
commit 0637b5b912
10 changed files with 102 additions and 115 deletions

View File

@@ -13,16 +13,14 @@ from tqdm import tqdm
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
from openhands.resolver.interfaces.issue import Issue
from openhands.resolver.resolve_issue import (
issue_handler_factory,
process_issue,
)
from openhands.resolver.resolver_output import ResolverOutput
from openhands.resolver.utils import (
Platform,
identify_token,
)
from openhands.resolver.utils import identify_token
def cleanup() -> None:
@@ -55,7 +53,7 @@ async def resolve_issues(
repo: str,
token: str,
username: str,
platform: Platform,
platform: ProviderType,
max_iterations: int,
limit_issues: int | None,
num_workers: int,
@@ -347,9 +345,6 @@ def main() -> None:
raise ValueError('Token is required.')
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig(

View File

@@ -26,6 +26,7 @@ from openhands.events.observation import (
Observation,
)
from openhands.events.stream import EventStreamSubscriber
from openhands.integrations.service_types import ProviderType
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
from openhands.resolver.interfaces.issue import Issue
@@ -35,7 +36,6 @@ from openhands.resolver.interfaces.issue_definitions import (
)
from openhands.resolver.resolver_output import ResolverOutput
from openhands.resolver.utils import (
Platform,
codeact_user_response,
get_unique_uid,
identify_token,
@@ -49,7 +49,7 @@ AGENT_CLASS = 'CodeActAgent'
def initialize_runtime(
runtime: Runtime,
platform: Platform,
platform: ProviderType,
) -> None:
"""Initialize the runtime for the agent.
@@ -68,7 +68,7 @@ def initialize_runtime(
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
raise RuntimeError(f'Failed to change directory to /workspace.\n{obs}')
if platform == Platform.GITLAB and os.getenv('GITLAB_CI') == 'true':
if platform == ProviderType.GITLAB and os.getenv('GITLAB_CI') == 'true':
action = CmdRunAction(command='sudo chown -R 1001:0 /workspace/*')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
@@ -85,7 +85,7 @@ def initialize_runtime(
async def complete_runtime(
runtime: Runtime,
base_commit: str,
platform: Platform,
platform: ProviderType,
) -> dict[str, Any]:
"""Complete the runtime for the agent.
@@ -121,7 +121,7 @@ async def complete_runtime(
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
raise RuntimeError(f'Failed to set git config. Observation: {obs}')
if platform == Platform.GITLAB and os.getenv('GITLAB_CI') == 'true':
if platform == ProviderType.GITLAB and os.getenv('GITLAB_CI') == 'true':
action = CmdRunAction(command='sudo git add -A')
else:
action = CmdRunAction(command='git add -A')
@@ -162,7 +162,7 @@ async def complete_runtime(
async def process_issue(
issue: Issue,
platform: Platform,
platform: ProviderType,
base_commit: str,
max_iterations: int,
llm_config: LLMConfig,
@@ -320,15 +320,15 @@ def issue_handler_factory(
repo: str,
token: str,
llm_config: LLMConfig,
platform: Platform,
platform: ProviderType,
username: str | None = None,
base_domain: str | None = None,
) -> ServiceContextIssue | ServiceContextPR:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
if issue_type == 'issue':
if platform == Platform.GITHUB:
if platform == ProviderType.GITHUB:
return ServiceContextIssue(
GithubIssueHandler(owner, repo, token, username, base_domain),
llm_config,
@@ -339,7 +339,7 @@ def issue_handler_factory(
llm_config,
)
elif issue_type == 'pr':
if platform == Platform.GITHUB:
if platform == ProviderType.GITHUB:
return ServiceContextPR(
GithubPRHandler(owner, repo, token, username, base_domain), llm_config
)
@@ -356,7 +356,7 @@ async def resolve_issue(
repo: str,
token: str,
username: str,
platform: Platform,
platform: ProviderType,
max_iterations: int,
output_dir: str,
llm_config: LLMConfig,
@@ -391,7 +391,7 @@ async def resolve_issue(
"""
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
issue_handler = issue_handler_factory(
issue_type, owner, repo, token, llm_config, platform, username, base_domain
@@ -669,8 +669,6 @@ def main() -> None:
raise ValueError('Token is required.')
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
model = my_args.llm_model or os.environ['LLM_MODEL']

View File

@@ -9,6 +9,7 @@ from pydantic import SecretStr
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
from openhands.llm.llm import LLM
from openhands.resolver.interfaces.github import GithubIssueHandler
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler
@@ -20,10 +21,7 @@ from openhands.resolver.io_utils import (
)
from openhands.resolver.patching import apply_diff, parse_patch
from openhands.resolver.resolver_output import ResolverOutput
from openhands.resolver.utils import (
Platform,
identify_token,
)
from openhands.resolver.utils import identify_token
def apply_patch(repo_dir: str, patch: str) -> None:
@@ -227,7 +225,7 @@ def send_pull_request(
issue: Issue,
token: str,
username: str | None,
platform: Platform,
platform: ProviderType,
patch_dir: str,
pr_type: str,
fork_owner: str | None = None,
@@ -258,10 +256,10 @@ def send_pull_request(
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
if platform == ProviderType.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
None,
@@ -329,7 +327,7 @@ def send_pull_request(
# For cross repo pull request, we need to send head parameter like fork_owner:branch as per git documentation here : https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#create-a-pull-request
# head parameter usage : The name of the branch where your changes are implemented. For cross-repository pull requests in the same network, namespace head with a user like this: username:branch.
if fork_owner and platform == Platform.GITHUB:
if fork_owner and platform == ProviderType.GITHUB:
head_branch = f'{fork_owner}:{branch_name}'
else:
head_branch = branch_name
@@ -341,9 +339,13 @@ def send_pull_request(
# Prepare the PR for the GitHub API
data = {
'title': final_pr_title,
('body' if platform == Platform.GITHUB else 'description'): pr_body,
('head' if platform == Platform.GITHUB else 'source_branch'): head_branch,
('base' if platform == Platform.GITHUB else 'target_branch'): base_branch,
('body' if platform == ProviderType.GITHUB else 'description'): pr_body,
(
'head' if platform == ProviderType.GITHUB else 'source_branch'
): head_branch,
(
'base' if platform == ProviderType.GITHUB else 'target_branch'
): base_branch,
'draft': pr_type == 'draft',
}
@@ -366,7 +368,7 @@ def update_existing_pull_request(
issue: Issue,
token: str,
username: str | None,
platform: Platform,
platform: ProviderType,
patch_dir: str,
llm_config: LLMConfig,
comment_message: str | None = None,
@@ -390,10 +392,10 @@ def update_existing_pull_request(
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
if platform == ProviderType.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
llm_config,
@@ -476,7 +478,7 @@ def process_single_issue(
resolver_output: ResolverOutput,
token: str,
username: str,
platform: Platform,
platform: ProviderType,
pr_type: str,
llm_config: LLMConfig,
fork_owner: str | None,
@@ -488,7 +490,7 @@ def process_single_issue(
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
if not resolver_output.success and not send_on_failure:
logger.info(
f'Issue {resolver_output.issue.number} was not successfully resolved. Skipping PR creation.'
@@ -550,7 +552,7 @@ def process_all_successful_issues(
output_dir: str,
token: str,
username: str,
platform: Platform,
platform: ProviderType,
pr_type: str,
llm_config: LLMConfig,
fork_owner: str | None,
@@ -558,7 +560,7 @@ def process_all_successful_issues(
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
output_path = os.path.join(output_dir, 'output.jsonl')
for resolver_output in load_all_resolver_outputs(output_path):
if resolver_output.success:
@@ -684,8 +686,6 @@ def main() -> None:
username = my_args.username if my_args.username else os.getenv('GIT_USERNAME')
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig(

View File

@@ -2,7 +2,6 @@ import logging
import multiprocessing as mp
import os
import re
from enum import Enum
from typing import Callable
import httpx
@@ -12,17 +11,12 @@ from openhands.core.logger import get_console_handler
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import Action
from openhands.events.action.message import MessageAction
class Platform(Enum):
INVALID = 0
GITHUB = 1
GITLAB = 2
from openhands.integrations.service_types import ProviderType
def identify_token(
token: str, selected_repo: str | None = None, base_domain: str | None = 'github.com'
) -> Platform:
) -> ProviderType:
"""
Identifies whether a token belongs to GitHub or GitLab.
@@ -32,7 +26,7 @@ def identify_token(
base_domain (str): The base domain for GitHub Enterprise (default: "github.com").
Returns:
Platform: "GitHub" if the token is valid for GitHub,
ProviderType: "GitHub" if the token is valid for GitHub,
"GitLab" if the token is valid for GitLab,
"Invalid" if the token is not recognized by either.
"""
@@ -55,7 +49,7 @@ def identify_token(
github_repo_url, headers=github_bearer_headers, timeout=5
)
if github_repo_response.status_code == 200:
return Platform.GITHUB
return ProviderType.GITHUB
except httpx.HTTPError as e:
logger.error(f'Error connecting to GitHub API (selected_repo check): {e}')
@@ -66,7 +60,7 @@ def identify_token(
try:
github_response = httpx.get(github_url, headers=github_headers, timeout=5)
if github_response.status_code == 200:
return Platform.GITHUB
return ProviderType.GITHUB
except httpx.HTTPError as e:
logger.error(f'Error connecting to GitHub API: {e}')
@@ -76,10 +70,11 @@ def identify_token(
try:
gitlab_response = httpx.get(gitlab_url, headers=gitlab_headers, timeout=5)
if gitlab_response.status_code == 200:
return Platform.GITLAB
return ProviderType.GITLAB
except httpx.HTTPError as e:
logger.error(f'Error connecting to GitLab API: {e}')
return Platform.INVALID
raise ValueError('Token is invalid.')
def codeact_user_response(

View File

@@ -4,7 +4,7 @@ import tempfile
from openhands.resolver.interfaces.issue import Issue
from openhands.resolver.send_pull_request import make_commit
from openhands.resolver.utils import Platform
from openhands.integrations.service_types import ProviderType
def test_commit_message_with_quotes():
@@ -160,7 +160,7 @@ def test_pr_title_with_quotes(monkeypatch):
issue=issue,
token='dummy-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=temp_dir,
pr_type='ready',
)

View File

@@ -24,7 +24,7 @@ from openhands.resolver.resolve_issue import (
process_issue,
)
from openhands.resolver.resolver_output import ResolverOutput
from openhands.resolver.utils import Platform
from openhands.integrations.service_types import ProviderType
@pytest.fixture
@@ -81,7 +81,7 @@ def test_initialize_runtime():
),
]
initialize_runtime(mock_runtime, Platform.GITHUB)
initialize_runtime(mock_runtime, ProviderType.GITHUB)
assert mock_runtime.run_action.call_count == 2
mock_runtime.run_action.assert_any_call(CmdRunAction(command='cd /workspace'))
@@ -108,7 +108,7 @@ async def test_resolve_issue_no_issues_found():
repo='test-repo',
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
max_iterations=5,
output_dir='/tmp',
llm_config=LLMConfig(model='test', api_key='test'),
@@ -315,7 +315,7 @@ async def test_complete_runtime():
create_cmd_output(exit_code=0, content='git diff content', command='git apply'),
]
result = await complete_runtime(mock_runtime, 'base_commit_hash', Platform.GITHUB)
result = await complete_runtime(mock_runtime, 'base_commit_hash', ProviderType.GITHUB)
assert result == {'git_patch': 'git diff content'}
assert mock_runtime.run_action.call_count == 5
@@ -442,7 +442,7 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
# Call the function
result = await process_issue(
issue,
Platform.GITHUB,
ProviderType.GITHUB,
base_commit,
max_iterations,
llm_config,

View File

@@ -18,7 +18,7 @@ from openhands.resolver.send_pull_request import (
send_pull_request,
update_existing_pull_request,
)
from openhands.resolver.utils import Platform
from openhands.integrations.service_types import ProviderType
@pytest.fixture
@@ -289,7 +289,7 @@ def test_update_existing_pull_request(
issue,
token,
username,
Platform.GITHUB,
ProviderType.GITHUB,
patch_dir,
llm_config,
comment_message=None,
@@ -388,7 +388,7 @@ def test_send_pull_request(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type=pr_type,
target_branch=target_branch,
@@ -478,7 +478,7 @@ def test_send_pull_request_with_reviewer(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type='ready',
reviewer=reviewer,
@@ -536,7 +536,7 @@ def test_send_pull_request_target_branch_with_fork(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type='ready',
fork_owner=fork_owner,
@@ -600,7 +600,7 @@ def test_send_pull_request_target_branch_with_additional_message(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type='ready',
target_branch=target_branch,
@@ -639,7 +639,7 @@ def test_send_pull_request_invalid_target_branch(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type='ready',
target_branch='nonexistent-branch',
@@ -674,7 +674,7 @@ def test_send_pull_request_git_push_failure(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type='ready',
)
@@ -734,7 +734,7 @@ def test_send_pull_request_permission_error(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type='ready',
)
@@ -861,7 +861,7 @@ def test_process_single_pr_update(
resolver_output,
token,
username,
Platform.GITHUB,
ProviderType.GITHUB,
pr_type,
mock_llm_config,
None,
@@ -880,7 +880,7 @@ def test_process_single_pr_update(
issue=resolver_output.issue,
token=token,
username=username,
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=f'{mock_output_dir}/patches/pr_1',
additional_message='[Test success 1]',
llm_config=mock_llm_config,
@@ -904,7 +904,7 @@ def test_process_single_issue(
token = 'test_token'
username = 'test_user'
pr_type = 'draft'
platform = Platform.GITHUB
platform = ProviderType.GITHUB
resolver_output = ResolverOutput(
issue=Issue(
@@ -1013,7 +1013,7 @@ def test_process_single_issue_unsuccessful(
resolver_output,
token,
username,
Platform.GITHUB,
ProviderType.GITHUB,
pr_type,
mock_llm_config,
None,
@@ -1105,7 +1105,7 @@ def test_process_all_successful_issues(
'output_dir',
'token',
'username',
Platform.GITHUB,
ProviderType.GITHUB,
'draft',
mock_llm_config, # llm_config
None, # fork_owner
@@ -1122,7 +1122,7 @@ def test_process_all_successful_issues(
resolver_output_1,
'token',
'username',
Platform.GITHUB,
ProviderType.GITHUB,
'draft',
mock_llm_config,
None,
@@ -1137,7 +1137,7 @@ def test_process_all_successful_issues(
resolver_output_3,
'token',
'username',
Platform.GITHUB,
ProviderType.GITHUB,
'draft',
mock_llm_config,
None,
@@ -1179,7 +1179,7 @@ def test_send_pull_request_branch_naming(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=repo_path,
pr_type='branch',
)
@@ -1264,7 +1264,7 @@ def test_main(
mock_resolver_output = MagicMock()
mock_load_single_resolver_output.return_value = mock_resolver_output
mock_identify_token.return_value = Platform.GITHUB
mock_identify_token.return_value = ProviderType.GITHUB
# Run main function
main()
@@ -1283,7 +1283,7 @@ def test_main(
mock_resolver_output,
'mock_token',
'mock_username',
Platform.GITHUB,
ProviderType.GITHUB,
'draft',
llm_config,
None,
@@ -1307,7 +1307,7 @@ def test_main(
'/mock/output',
'mock_token',
'mock_username',
Platform.GITHUB,
ProviderType.GITHUB,
'draft',
llm_config,
None,
@@ -1320,8 +1320,9 @@ def test_main(
main()
# Test for invalid token
mock_identify_token.return_value = Platform.INVALID
with pytest.raises(ValueError, match='Token is invalid.'):
mock_args.issue_number = '42' # Reset to valid issue number
mock_getenv.side_effect = lambda key, default=None: None # Return None for all env vars
with pytest.raises(ValueError, match='token is not set'):
main()

View File

@@ -5,8 +5,8 @@ import tempfile
from openhands.core.logger import openhands_logger as logger
from openhands.resolver.interfaces.issue import Issue
from openhands.resolver.send_pull_request import make_commit
from openhands.resolver.utils import Platform
from openhands.integrations.service_types import ProviderType
from openhands.resolver.send_pull_request import send_pull_request
def test_commit_message_with_quotes():
# Create a temporary directory and initialize git repo
@@ -155,13 +155,12 @@ def test_pr_title_with_quotes(monkeypatch):
# Try to send a PR - this will fail if the title is incorrectly escaped
logger.info('Sending PR...')
from openhands.resolver.send_pull_request import send_pull_request
send_pull_request(
issue=issue,
token='dummy-token',
username='test-user',
platform=Platform.GITHUB,
platform=ProviderType.GITHUB,
patch_dir=temp_dir,
pr_type='ready',
)

View File

@@ -24,8 +24,7 @@ from openhands.resolver.resolve_issue import (
process_issue,
)
from openhands.resolver.resolver_output import ResolverOutput
from openhands.resolver.utils import Platform
from openhands.integrations.service_types import ProviderType
@pytest.fixture
def mock_output_dir():
@@ -93,7 +92,7 @@ def test_initialize_runtime():
),
]
initialize_runtime(mock_runtime, Platform.GITLAB)
initialize_runtime(mock_runtime, ProviderType.GITLAB)
if os.getenv('GITLAB_CI') == 'true':
assert mock_runtime.run_action.call_count == 3
@@ -128,7 +127,7 @@ async def test_resolve_issue_no_issues_found():
repo='test-repo',
token='test-token',
username='test-user',
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
max_iterations=5,
output_dir='/tmp',
llm_config=LLMConfig(model='test', api_key='test'),
@@ -355,7 +354,7 @@ async def test_complete_runtime():
create_cmd_output(exit_code=0, content='git diff content', command='git apply'),
]
result = await complete_runtime(mock_runtime, 'base_commit_hash', Platform.GITLAB)
result = await complete_runtime(mock_runtime, 'base_commit_hash', ProviderType.GITLAB)
assert result == {'git_patch': 'git diff content'}
assert mock_runtime.run_action.call_count == 5
@@ -482,7 +481,7 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
# Call the function
result = await process_issue(
issue,
Platform.GITLAB,
ProviderType.GITLAB,
base_commit,
max_iterations,
llm_config,

View File

@@ -19,8 +19,7 @@ from openhands.resolver.send_pull_request import (
send_pull_request,
update_existing_pull_request,
)
from openhands.resolver.utils import Platform
from openhands.integrations.service_types import ProviderType
@pytest.fixture
def mock_output_dir():
@@ -290,7 +289,7 @@ def test_update_existing_pull_request(
issue,
token,
username,
Platform.GITLAB,
ProviderType.GITLAB,
patch_dir,
llm_config,
comment_message=None,
@@ -392,7 +391,7 @@ def test_send_pull_request(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
patch_dir=repo_path,
pr_type=pr_type,
target_branch=target_branch,
@@ -499,7 +498,7 @@ def test_send_pull_request_with_reviewer(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
patch_dir=repo_path,
pr_type='ready',
reviewer=reviewer,
@@ -547,7 +546,7 @@ def test_send_pull_request_invalid_target_branch(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
patch_dir=repo_path,
pr_type='ready',
target_branch='nonexistent-branch',
@@ -582,7 +581,7 @@ def test_send_pull_request_git_push_failure(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
patch_dir=repo_path,
pr_type='ready',
)
@@ -642,7 +641,7 @@ def test_send_pull_request_permission_error(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
patch_dir=repo_path,
pr_type='ready',
)
@@ -762,7 +761,7 @@ def test_process_single_pr_update(
resolver_output,
token,
username,
Platform.GITLAB,
ProviderType.GITLAB,
pr_type,
mock_llm_config,
None,
@@ -781,7 +780,7 @@ def test_process_single_pr_update(
issue=resolver_output.issue,
token=token,
username=username,
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
patch_dir=f'{mock_output_dir}/patches/pr_1',
additional_message='[Test success 1]',
llm_config=mock_llm_config,
@@ -805,7 +804,7 @@ def test_process_single_issue(
token = 'test_token'
username = 'test_user'
pr_type = 'draft'
platform = Platform.GITLAB
platform = ProviderType.GITLAB
resolver_output = ResolverOutput(
issue=Issue(
@@ -914,7 +913,7 @@ def test_process_single_issue_unsuccessful(
resolver_output,
token,
username,
Platform.GITLAB,
ProviderType.GITLAB,
pr_type,
mock_llm_config,
None,
@@ -1006,7 +1005,7 @@ def test_process_all_successful_issues(
'output_dir',
'token',
'username',
Platform.GITLAB,
ProviderType.GITLAB,
'draft',
mock_llm_config, # llm_config
None, # fork_owner
@@ -1023,7 +1022,7 @@ def test_process_all_successful_issues(
resolver_output_1,
'token',
'username',
Platform.GITLAB,
ProviderType.GITLAB,
'draft',
mock_llm_config,
None,
@@ -1038,7 +1037,7 @@ def test_process_all_successful_issues(
resolver_output_3,
'token',
'username',
Platform.GITLAB,
ProviderType.GITLAB,
'draft',
mock_llm_config,
None,
@@ -1081,7 +1080,7 @@ def test_send_pull_request_branch_naming(
issue=mock_issue,
token='test-token',
username='test-user',
platform=Platform.GITLAB,
platform=ProviderType.GITLAB,
patch_dir=repo_path,
pr_type='branch',
)
@@ -1166,7 +1165,7 @@ def test_main(
mock_resolver_output = MagicMock()
mock_load_single_resolver_output.return_value = mock_resolver_output
mock_identify_token.return_value = Platform.GITLAB
mock_identify_token.return_value = ProviderType.GITLAB
# Run main function
main()
@@ -1185,7 +1184,7 @@ def test_main(
mock_resolver_output,
'mock_token',
'mock_username',
Platform.GITLAB,
ProviderType.GITLAB,
'draft',
llm_config,
None,
@@ -1209,7 +1208,7 @@ def test_main(
'/mock/output',
'mock_token',
'mock_username',
Platform.GITLAB,
ProviderType.GITLAB,
'draft',
llm_config,
None,
@@ -1222,8 +1221,9 @@ def test_main(
main()
# Test for invalid token
mock_identify_token.return_value = Platform.INVALID
with pytest.raises(ValueError, match='Token is invalid.'):
mock_args.issue_number = '42' # Reset to valid issue number
mock_getenv.side_effect = lambda key, default=None: None # Return None for all env vars
with pytest.raises(ValueError, match='token is not set'):
main()