mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[refactor]: Refactored the initialization of issue_handler within IssueResolver (#8417)
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
parent
da637a0dad
commit
a17c57d82e
80
openhands/resolver/issue_handler_factory.py
Normal file
80
openhands/resolver/issue_handler_factory.py
Normal file
@ -0,0 +1,80 @@
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
|
||||
|
||||
class IssueHandlerFactory:
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str,
|
||||
platform: ProviderType,
|
||||
base_domain: str,
|
||||
issue_type: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.platform = platform
|
||||
self.base_domain = base_domain
|
||||
self.issue_type = issue_type
|
||||
self.llm_config = llm_config
|
||||
|
||||
def create(self) -> ServiceContextIssue | ServiceContextPR:
|
||||
if self.issue_type == 'issue':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextIssue(
|
||||
GithubIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextIssue(
|
||||
GitlabIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
elif self.issue_type == 'pr':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextPR(
|
||||
GithubPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextPR(
|
||||
GitlabPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid issue type: {self.issue_type}')
|
||||
@ -28,13 +28,12 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.interfaces.issue import Issue
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
|
||||
from openhands.resolver.resolver_output import ResolverOutput
|
||||
from openhands.resolver.utils import (
|
||||
codeact_user_response,
|
||||
@ -162,8 +161,6 @@ class IssueResolver:
|
||||
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.platform = platform
|
||||
self.runtime_container_image = runtime_container_image
|
||||
self.base_container_image = base_container_image
|
||||
@ -175,9 +172,20 @@ class IssueResolver:
|
||||
self.repo_instruction = repo_instruction
|
||||
self.issue_number = args.issue_number
|
||||
self.comment_id = args.comment_id
|
||||
self.base_domain = base_domain
|
||||
self.platform = platform
|
||||
|
||||
factory = IssueHandlerFactory(
|
||||
owner=self.owner,
|
||||
repo=self.repo,
|
||||
token=token,
|
||||
username=username,
|
||||
platform=self.platform,
|
||||
base_domain=base_domain,
|
||||
issue_type=self.issue_type,
|
||||
llm_config=self.llm_config,
|
||||
)
|
||||
self.issue_handler = factory.create()
|
||||
|
||||
def initialize_runtime(
|
||||
self,
|
||||
runtime: Runtime,
|
||||
@ -445,58 +453,6 @@ class IssueResolver:
|
||||
)
|
||||
return output
|
||||
|
||||
def issue_handler_factory(self) -> ServiceContextIssue | ServiceContextPR:
|
||||
# Determine default base_domain based on platform
|
||||
|
||||
if self.issue_type == 'issue':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextIssue(
|
||||
GithubIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextIssue(
|
||||
GitlabIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
elif self.issue_type == 'pr':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextPR(
|
||||
GithubPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextPR(
|
||||
GitlabPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid issue type: {self.issue_type}')
|
||||
|
||||
async def resolve_issue(
|
||||
self,
|
||||
reset_logger: bool = False,
|
||||
@ -507,10 +463,8 @@ class IssueResolver:
|
||||
reset_logger: Whether to reset the logger for multiprocessing.
|
||||
"""
|
||||
|
||||
issue_handler = self.issue_handler_factory()
|
||||
|
||||
# Load dataset
|
||||
issues: list[Issue] = issue_handler.get_converted_issues(
|
||||
issues: list[Issue] = self.issue_handler.get_converted_issues(
|
||||
issue_numbers=[self.issue_number], comment_id=self.comment_id
|
||||
)
|
||||
|
||||
@ -556,7 +510,7 @@ class IssueResolver:
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
issue_handler.get_clone_url(),
|
||||
self.issue_handler.get_clone_url(),
|
||||
f'{self.output_dir}/repo',
|
||||
]
|
||||
).decode('utf-8')
|
||||
@ -635,7 +589,7 @@ class IssueResolver:
|
||||
output = await self.process_issue(
|
||||
issue,
|
||||
base_commit,
|
||||
issue_handler,
|
||||
self.issue_handler,
|
||||
reset_logger,
|
||||
)
|
||||
output_fp.write(output.model_dump_json() + '\n')
|
||||
|
||||
@ -142,8 +142,8 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_github_toke
|
||||
# Create a resolver instance with mocked token identification
|
||||
resolver = IssueResolver(default_mock_args)
|
||||
|
||||
# Mock the issue_handler_factory method
|
||||
resolver.issue_handler_factory = MagicMock(return_value=mock_handler)
|
||||
# Mock the issue handler
|
||||
resolver.issue_handler = mock_handler
|
||||
|
||||
# Test that the correct exception is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@ -153,8 +153,6 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_github_toke
|
||||
assert 'No issues found for issue number 5432' in str(exc_info.value)
|
||||
assert 'test-owner/test-repo' in str(exc_info.value)
|
||||
|
||||
# Verify that the handler was correctly configured and called
|
||||
resolver.issue_handler_factory.assert_called_once()
|
||||
mock_handler.get_converted_issues.assert_called_once_with(
|
||||
issue_numbers=[5432], comment_id=None
|
||||
)
|
||||
@ -447,7 +445,8 @@ async def test_process_issue(
|
||||
resolver = IssueResolver(default_mock_args)
|
||||
resolver.prompt_template = mock_prompt_template
|
||||
|
||||
# Mock the handler
|
||||
# Mock the handler with LLM config
|
||||
llm_config = LLMConfig(model='test', api_key='test')
|
||||
handler_instance = MagicMock()
|
||||
handler_instance.guess_success.return_value = (
|
||||
test_case['expected_success'],
|
||||
@ -456,6 +455,7 @@ async def test_process_issue(
|
||||
)
|
||||
handler_instance.get_instruction.return_value = ('Test instruction', [])
|
||||
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
handler_instance.llm = LLM(llm_config)
|
||||
|
||||
# Mock the runtime and its methods
|
||||
mock_runtime = MagicMock()
|
||||
|
||||
@ -163,8 +163,8 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_gitlab_toke
|
||||
# Create a resolver instance with mocked token identification
|
||||
resolver = IssueResolver(default_mock_args)
|
||||
|
||||
# Mock the issue_handler_factory method
|
||||
resolver.issue_handler_factory = MagicMock(return_value=mock_handler)
|
||||
# Mock the issue handler
|
||||
resolver.issue_handler = mock_handler
|
||||
|
||||
# Test that the correct exception is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@ -174,8 +174,6 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_gitlab_toke
|
||||
assert 'No issues found for issue number 5432' in str(exc_info.value)
|
||||
assert 'test-owner/test-repo' in str(exc_info.value)
|
||||
|
||||
# Verify that the handler was correctly configured and called
|
||||
resolver.issue_handler_factory.assert_called_once()
|
||||
mock_handler.get_converted_issues.assert_called_once_with(
|
||||
issue_numbers=[5432], comment_id=None
|
||||
)
|
||||
@ -483,7 +481,8 @@ async def test_process_issue(
|
||||
resolver = IssueResolver(default_mock_args)
|
||||
resolver.prompt_template = mock_prompt_template
|
||||
|
||||
# Mock the handler
|
||||
# Mock the handler with LLM config
|
||||
llm_config = LLMConfig(model='test', api_key='test')
|
||||
handler_instance = MagicMock()
|
||||
handler_instance.guess_success.return_value = (
|
||||
test_case['expected_success'],
|
||||
@ -492,6 +491,7 @@ async def test_process_issue(
|
||||
)
|
||||
handler_instance.get_instruction.return_value = ('Test instruction', [])
|
||||
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
handler_instance.llm = LLM(llm_config)
|
||||
|
||||
# Create mock runtime and mock run_controller
|
||||
mock_runtime = MagicMock()
|
||||
|
||||
77
tests/unit/resolver/test_issue_handler_factory.py
Normal file
77
tests/unit/resolver/test_issue_handler_factory.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing import Type
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_config():
|
||||
return LLMConfig(
|
||||
model='test-model',
|
||||
api_key=SecretStr('test-key'),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory_params(llm_config):
|
||||
return {
|
||||
'owner': 'test-owner',
|
||||
'repo': 'test-repo',
|
||||
'token': 'test-token',
|
||||
'username': 'test-user',
|
||||
'base_domain': 'github.com',
|
||||
'llm_config': llm_config,
|
||||
}
|
||||
|
||||
|
||||
test_cases = [
|
||||
# platform, issue_type, expected_context_type, expected_handler_type
|
||||
(ProviderType.GITHUB, 'issue', ServiceContextIssue, GithubIssueHandler),
|
||||
(ProviderType.GITHUB, 'pr', ServiceContextPR, GithubPRHandler),
|
||||
(ProviderType.GITLAB, 'issue', ServiceContextIssue, GitlabIssueHandler),
|
||||
(ProviderType.GITLAB, 'pr', ServiceContextPR, GitlabPRHandler),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'platform,issue_type,expected_context_type,expected_handler_type',
|
||||
test_cases
|
||||
)
|
||||
def test_handler_creation(
|
||||
factory_params,
|
||||
platform: ProviderType,
|
||||
issue_type: str,
|
||||
expected_context_type: Type,
|
||||
expected_handler_type: Type,
|
||||
):
|
||||
factory = IssueHandlerFactory(
|
||||
**factory_params,
|
||||
platform=platform,
|
||||
issue_type=issue_type
|
||||
)
|
||||
|
||||
handler = factory.create()
|
||||
|
||||
assert isinstance(handler, expected_context_type)
|
||||
assert isinstance(handler._strategy, expected_handler_type)
|
||||
|
||||
def test_invalid_issue_type(factory_params):
|
||||
factory = IssueHandlerFactory(
|
||||
**factory_params,
|
||||
platform=ProviderType.GITHUB,
|
||||
issue_type='invalid'
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match='Invalid issue type: invalid'):
|
||||
factory.create()
|
||||
Loading…
x
Reference in New Issue
Block a user