diff --git a/openhands/resolver/issue_handler_factory.py b/openhands/resolver/issue_handler_factory.py new file mode 100644 index 0000000000..482e8a6ab7 --- /dev/null +++ b/openhands/resolver/issue_handler_factory.py @@ -0,0 +1,80 @@ +from openhands.core.config import LLMConfig +from openhands.integrations.provider import ProviderType +from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler +from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler +from openhands.resolver.interfaces.issue_definitions import ( + ServiceContextIssue, + ServiceContextPR, +) + + +class IssueHandlerFactory: + def __init__( + self, + owner: str, + repo: str, + token: str, + username: str, + platform: ProviderType, + base_domain: str, + issue_type: str, + llm_config: LLMConfig, + ) -> None: + self.owner = owner + self.repo = repo + self.token = token + self.username = username + self.platform = platform + self.base_domain = base_domain + self.issue_type = issue_type + self.llm_config = llm_config + + def create(self) -> ServiceContextIssue | ServiceContextPR: + if self.issue_type == 'issue': + if self.platform == ProviderType.GITHUB: + return ServiceContextIssue( + GithubIssueHandler( + self.owner, + self.repo, + self.token, + self.username, + self.base_domain, + ), + self.llm_config, + ) + else: # platform == Platform.GITLAB + return ServiceContextIssue( + GitlabIssueHandler( + self.owner, + self.repo, + self.token, + self.username, + self.base_domain, + ), + self.llm_config, + ) + elif self.issue_type == 'pr': + if self.platform == ProviderType.GITHUB: + return ServiceContextPR( + GithubPRHandler( + self.owner, + self.repo, + self.token, + self.username, + self.base_domain, + ), + self.llm_config, + ) + else: # platform == Platform.GITLAB + return ServiceContextPR( + GitlabPRHandler( + self.owner, + self.repo, + self.token, + self.username, + self.base_domain, + ), + self.llm_config, + ) + else: + raise ValueError(f'Invalid issue type: {self.issue_type}') diff --git a/openhands/resolver/resolve_issue.py b/openhands/resolver/resolve_issue.py index 89725d449e..90fae23047 100644 --- a/openhands/resolver/resolve_issue.py +++ b/openhands/resolver/resolve_issue.py @@ -28,13 +28,12 @@ from openhands.events.observation import ( ) from openhands.events.stream import EventStreamSubscriber from openhands.integrations.service_types import ProviderType -from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler -from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler from openhands.resolver.interfaces.issue import Issue from openhands.resolver.interfaces.issue_definitions import ( ServiceContextIssue, ServiceContextPR, ) +from openhands.resolver.issue_handler_factory import IssueHandlerFactory from openhands.resolver.resolver_output import ResolverOutput from openhands.resolver.utils import ( codeact_user_response, @@ -162,8 +161,6 @@ class IssueResolver: self.owner = owner self.repo = repo - self.token = token - self.username = username self.platform = platform self.runtime_container_image = runtime_container_image self.base_container_image = base_container_image @@ -175,9 +172,20 @@ class IssueResolver: self.repo_instruction = repo_instruction self.issue_number = args.issue_number self.comment_id = args.comment_id - self.base_domain = base_domain self.platform = platform + factory = IssueHandlerFactory( + owner=self.owner, + repo=self.repo, + token=token, + username=username, + platform=self.platform, + base_domain=base_domain, + issue_type=self.issue_type, + llm_config=self.llm_config, + ) + self.issue_handler = factory.create() + def initialize_runtime( self, runtime: Runtime, @@ -445,58 +453,6 @@ class IssueResolver: ) return output - def issue_handler_factory(self) -> ServiceContextIssue | ServiceContextPR: - # Determine default base_domain based on platform - - if self.issue_type == 'issue': - if self.platform == ProviderType.GITHUB: - return ServiceContextIssue( - GithubIssueHandler( - self.owner, - self.repo, - self.token, - self.username, - self.base_domain, - ), - self.llm_config, - ) - else: # platform == Platform.GITLAB - return ServiceContextIssue( - GitlabIssueHandler( - self.owner, - self.repo, - self.token, - self.username, - self.base_domain, - ), - self.llm_config, - ) - elif self.issue_type == 'pr': - if self.platform == ProviderType.GITHUB: - return ServiceContextPR( - GithubPRHandler( - self.owner, - self.repo, - self.token, - self.username, - self.base_domain, - ), - self.llm_config, - ) - else: # platform == Platform.GITLAB - return ServiceContextPR( - GitlabPRHandler( - self.owner, - self.repo, - self.token, - self.username, - self.base_domain, - ), - self.llm_config, - ) - else: - raise ValueError(f'Invalid issue type: {self.issue_type}') - async def resolve_issue( self, reset_logger: bool = False, @@ -507,10 +463,8 @@ class IssueResolver: reset_logger: Whether to reset the logger for multiprocessing. """ - issue_handler = self.issue_handler_factory() - # Load dataset - issues: list[Issue] = issue_handler.get_converted_issues( + issues: list[Issue] = self.issue_handler.get_converted_issues( issue_numbers=[self.issue_number], comment_id=self.comment_id ) @@ -556,7 +510,7 @@ class IssueResolver: [ 'git', 'clone', - issue_handler.get_clone_url(), + self.issue_handler.get_clone_url(), f'{self.output_dir}/repo', ] ).decode('utf-8') @@ -635,7 +589,7 @@ class IssueResolver: output = await self.process_issue( issue, base_commit, - issue_handler, + self.issue_handler, reset_logger, ) output_fp.write(output.model_dump_json() + '\n') diff --git a/tests/unit/resolver/github/test_resolve_issues.py b/tests/unit/resolver/github/test_resolve_issues.py index 42e21efd04..99bcc61adf 100644 --- a/tests/unit/resolver/github/test_resolve_issues.py +++ b/tests/unit/resolver/github/test_resolve_issues.py @@ -142,8 +142,8 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_github_toke # Create a resolver instance with mocked token identification resolver = IssueResolver(default_mock_args) - # Mock the issue_handler_factory method - resolver.issue_handler_factory = MagicMock(return_value=mock_handler) + # Mock the issue handler + resolver.issue_handler = mock_handler # Test that the correct exception is raised with pytest.raises(ValueError) as exc_info: @@ -153,8 +153,6 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_github_toke assert 'No issues found for issue number 5432' in str(exc_info.value) assert 'test-owner/test-repo' in str(exc_info.value) - # Verify that the handler was correctly configured and called - resolver.issue_handler_factory.assert_called_once() mock_handler.get_converted_issues.assert_called_once_with( issue_numbers=[5432], comment_id=None ) @@ -447,7 +445,8 @@ async def test_process_issue( resolver = IssueResolver(default_mock_args) resolver.prompt_template = mock_prompt_template - # Mock the handler + # Mock the handler with LLM config + llm_config = LLMConfig(model='test', api_key='test') handler_instance = MagicMock() handler_instance.guess_success.return_value = ( test_case['expected_success'], @@ -456,6 +455,7 @@ async def test_process_issue( ) handler_instance.get_instruction.return_value = ('Test instruction', []) handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue' + handler_instance.llm = LLM(llm_config) # Mock the runtime and its methods mock_runtime = MagicMock() diff --git a/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py b/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py index 82060c23a0..edfbf822a7 100644 --- a/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py +++ b/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py @@ -163,8 +163,8 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_gitlab_toke # Create a resolver instance with mocked token identification resolver = IssueResolver(default_mock_args) - # Mock the issue_handler_factory method - resolver.issue_handler_factory = MagicMock(return_value=mock_handler) + # Mock the issue handler + resolver.issue_handler = mock_handler # Test that the correct exception is raised with pytest.raises(ValueError) as exc_info: @@ -174,8 +174,6 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_gitlab_toke assert 'No issues found for issue number 5432' in str(exc_info.value) assert 'test-owner/test-repo' in str(exc_info.value) - # Verify that the handler was correctly configured and called - resolver.issue_handler_factory.assert_called_once() mock_handler.get_converted_issues.assert_called_once_with( issue_numbers=[5432], comment_id=None ) @@ -483,7 +481,8 @@ async def test_process_issue( resolver = IssueResolver(default_mock_args) resolver.prompt_template = mock_prompt_template - # Mock the handler + # Mock the handler with LLM config + llm_config = LLMConfig(model='test', api_key='test') handler_instance = MagicMock() handler_instance.guess_success.return_value = ( test_case['expected_success'], @@ -492,6 +491,7 @@ async def test_process_issue( ) handler_instance.get_instruction.return_value = ('Test instruction', []) handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue' + handler_instance.llm = LLM(llm_config) # Create mock runtime and mock run_controller mock_runtime = MagicMock() diff --git a/tests/unit/resolver/test_issue_handler_factory.py b/tests/unit/resolver/test_issue_handler_factory.py new file mode 100644 index 0000000000..56262ed9e5 --- /dev/null +++ b/tests/unit/resolver/test_issue_handler_factory.py @@ -0,0 +1,77 @@ +from typing import Type +from unittest.mock import MagicMock + +import pytest +from pydantic import SecretStr + +from openhands.core.config import LLMConfig +from openhands.integrations.provider import ProviderType +from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler +from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler +from openhands.resolver.issue_handler_factory import IssueHandlerFactory +from openhands.resolver.interfaces.issue_definitions import ( + ServiceContextIssue, + ServiceContextPR, +) + + +@pytest.fixture +def llm_config(): + return LLMConfig( + model='test-model', + api_key=SecretStr('test-key'), + ) + + +@pytest.fixture +def factory_params(llm_config): + return { + 'owner': 'test-owner', + 'repo': 'test-repo', + 'token': 'test-token', + 'username': 'test-user', + 'base_domain': 'github.com', + 'llm_config': llm_config, + } + + +test_cases = [ + # platform, issue_type, expected_context_type, expected_handler_type + (ProviderType.GITHUB, 'issue', ServiceContextIssue, GithubIssueHandler), + (ProviderType.GITHUB, 'pr', ServiceContextPR, GithubPRHandler), + (ProviderType.GITLAB, 'issue', ServiceContextIssue, GitlabIssueHandler), + (ProviderType.GITLAB, 'pr', ServiceContextPR, GitlabPRHandler), +] + + +@pytest.mark.parametrize( + 'platform,issue_type,expected_context_type,expected_handler_type', + test_cases +) +def test_handler_creation( + factory_params, + platform: ProviderType, + issue_type: str, + expected_context_type: Type, + expected_handler_type: Type, +): + factory = IssueHandlerFactory( + **factory_params, + platform=platform, + issue_type=issue_type + ) + + handler = factory.create() + + assert isinstance(handler, expected_context_type) + assert isinstance(handler._strategy, expected_handler_type) + +def test_invalid_issue_type(factory_params): + factory = IssueHandlerFactory( + **factory_params, + platform=ProviderType.GITHUB, + issue_type='invalid' + ) + + with pytest.raises(ValueError, match='Invalid issue type: invalid'): + factory.create() \ No newline at end of file