diff --git a/openhands/integrations/bitbucket/__init__.py b/openhands/integrations/bitbucket/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/openhands/integrations/bitbucket/service/base.py b/openhands/integrations/bitbucket/service/base.py index be135c5ca5..d7c9b4adf7 100644 --- a/openhands/integrations/bitbucket/service/base.py +++ b/openhands/integrations/bitbucket/service/base.py @@ -4,6 +4,7 @@ from typing import Any import httpx from pydantic import SecretStr +from openhands.integrations.protocols.http_client import HTTPClient from openhands.integrations.service_types import ( BaseGitService, OwnerType, @@ -15,14 +16,12 @@ from openhands.integrations.service_types import ( ) -class BitBucketMixinBase(BaseGitService): +class BitBucketMixinBase(BaseGitService, HTTPClient): """ Base mixin for BitBucket service containing common functionality """ BASE_URL = 'https://api.bitbucket.org/2.0' - token: SecretStr = SecretStr('') - refresh = False def _extract_owner_and_repo(self, repository: str) -> tuple[str, str]: """Extract owner and repo from repository string. @@ -49,7 +48,7 @@ class BitBucketMixinBase(BaseGitService): def _has_token_expired(self, status_code: int) -> bool: return status_code == 401 - async def _get_bitbucket_headers(self) -> dict[str, str]: + async def _get_headers(self) -> dict[str, str]: """Get headers for Bitbucket API requests.""" token_value = self.token.get_secret_value() @@ -85,13 +84,13 @@ class BitBucketMixinBase(BaseGitService): """ try: async with httpx.AsyncClient() as client: - bitbucket_headers = await self._get_bitbucket_headers() + bitbucket_headers = await self._get_headers() response = await self.execute_request( client, url, bitbucket_headers, params, method ) if self.refresh and self._has_token_expired(response.status_code): await self.get_latest_token() - bitbucket_headers = await self._get_bitbucket_headers() + bitbucket_headers = await self._get_headers() response = await self.execute_request( client=client, url=url, diff --git a/openhands/integrations/github/github_service.py b/openhands/integrations/github/github_service.py index 01c28b0fe2..4d0eb6080f 100644 --- a/openhands/integrations/github/github_service.py +++ b/openhands/integrations/github/github_service.py @@ -43,8 +43,6 @@ class GitHubService( BASE_URL = 'https://api.github.com' GRAPHQL_URL = 'https://api.github.com/graphql' - token: SecretStr = SecretStr('') - refresh = False def __init__( self, diff --git a/openhands/integrations/github/service/__init__.py b/openhands/integrations/github/service/__init__.py index ea0bef7d64..3509f4056b 100644 --- a/openhands/integrations/github/service/__init__.py +++ b/openhands/integrations/github/service/__init__.py @@ -1,5 +1,6 @@ # openhands/integrations/github/service/__init__.py +from .base import GitHubMixinBase from .branches_prs import GitHubBranchesMixin from .features import GitHubFeaturesMixin from .prs import GitHubPRsMixin @@ -7,6 +8,7 @@ from .repos import GitHubReposMixin from .resolver import GitHubResolverMixin __all__ = [ + 'GitHubMixinBase', 'GitHubBranchesMixin', 'GitHubFeaturesMixin', 'GitHubPRsMixin', diff --git a/openhands/integrations/github/service/base.py b/openhands/integrations/github/service/base.py index 16d4f0dcce..556c647390 100644 --- a/openhands/integrations/github/service/base.py +++ b/openhands/integrations/github/service/base.py @@ -4,6 +4,7 @@ from typing import Any, cast import httpx from pydantic import SecretStr +from openhands.integrations.protocols.http_client import HTTPClient from openhands.integrations.service_types import ( BaseGitService, RequestMethod, @@ -12,19 +13,15 @@ from openhands.integrations.service_types import ( ) -class GitHubMixinBase(BaseGitService): +class GitHubMixinBase(BaseGitService, HTTPClient): """ Declares common attributes and method signatures used across mixins. """ BASE_URL: str GRAPHQL_URL: str - token: SecretStr - refresh: bool - external_auth_id: str | None - base_domain: str | None - async def _get_github_headers(self) -> dict: + async def _get_headers(self) -> dict: """Retrieve the GH Token from settings store to construct the headers.""" if not self.token: latest_token = await self.get_latest_token() @@ -47,7 +44,7 @@ class GitHubMixinBase(BaseGitService): ) -> tuple[Any, dict]: # type: ignore[override] try: async with httpx.AsyncClient() as client: - github_headers = await self._get_github_headers() + github_headers = await self._get_headers() # Make initial request response = await self.execute_request( @@ -61,7 +58,7 @@ class GitHubMixinBase(BaseGitService): # Handle token refresh if needed if self.refresh and self._has_token_expired(response.status_code): await self.get_latest_token() - github_headers = await self._get_github_headers() + github_headers = await self._get_headers() response = await self.execute_request( client=client, url=url, @@ -87,7 +84,7 @@ class GitHubMixinBase(BaseGitService): ) -> dict[str, Any]: try: async with httpx.AsyncClient() as client: - github_headers = await self._get_github_headers() + github_headers = await self._get_headers() response = await client.post( self.GRAPHQL_URL, diff --git a/openhands/integrations/gitlab/gitlab_service.py b/openhands/integrations/gitlab/gitlab_service.py index 24bb32139b..7c672b7f19 100644 --- a/openhands/integrations/gitlab/gitlab_service.py +++ b/openhands/integrations/gitlab/gitlab_service.py @@ -41,8 +41,6 @@ class GitLabService( BASE_URL = 'https://gitlab.com/api/v4' GRAPHQL_URL = 'https://gitlab.com/api/graphql' - token: SecretStr = SecretStr('') - refresh = False def __init__( self, diff --git a/openhands/integrations/gitlab/service/__init__.py b/openhands/integrations/gitlab/service/__init__.py index 94b01c2c32..b90c66ed40 100644 --- a/openhands/integrations/gitlab/service/__init__.py +++ b/openhands/integrations/gitlab/service/__init__.py @@ -1,5 +1,6 @@ # openhands/integrations/gitlab/service/__init__.py +from .base import GitLabMixinBase from .branches import GitLabBranchesMixin from .features import GitLabFeaturesMixin from .prs import GitLabPRsMixin @@ -7,6 +8,7 @@ from .repos import GitLabReposMixin from .resolver import GitLabResolverMixin __all__ = [ + 'GitLabMixinBase', 'GitLabBranchesMixin', 'GitLabFeaturesMixin', 'GitLabPRsMixin', diff --git a/openhands/integrations/gitlab/service/base.py b/openhands/integrations/gitlab/service/base.py index edbb05baae..239d972720 100644 --- a/openhands/integrations/gitlab/service/base.py +++ b/openhands/integrations/gitlab/service/base.py @@ -3,6 +3,7 @@ from typing import Any import httpx from pydantic import SecretStr +from openhands.integrations.protocols.http_client import HTTPClient from openhands.integrations.service_types import ( BaseGitService, RequestMethod, @@ -11,19 +12,15 @@ from openhands.integrations.service_types import ( ) -class GitLabMixinBase(BaseGitService): +class GitLabMixinBase(BaseGitService, HTTPClient): """ Declares common attributes and method signatures used across mixins. """ BASE_URL: str GRAPHQL_URL: str - token: SecretStr - refresh: bool - external_auth_id: str | None - base_domain: str | None - async def _get_gitlab_headers(self) -> dict[str, Any]: + async def _get_headers(self) -> dict[str, Any]: """Retrieve the GitLab Token to construct the headers""" if not self.token: latest_token = await self.get_latest_token() @@ -45,7 +42,7 @@ class GitLabMixinBase(BaseGitService): ) -> tuple[Any, dict]: # type: ignore[override] try: async with httpx.AsyncClient() as client: - gitlab_headers = await self._get_gitlab_headers() + gitlab_headers = await self._get_headers() # Make initial request response = await self.execute_request( @@ -59,7 +56,7 @@ class GitLabMixinBase(BaseGitService): # Handle token refresh if needed if self.refresh and self._has_token_expired(response.status_code): await self.get_latest_token() - gitlab_headers = await self._get_gitlab_headers() + gitlab_headers = await self._get_headers() response = await self.execute_request( client=client, url=url, @@ -103,7 +100,7 @@ class GitLabMixinBase(BaseGitService): variables = {} try: async with httpx.AsyncClient() as client: - gitlab_headers = await self._get_gitlab_headers() + gitlab_headers = await self._get_headers() # Add content type header for GraphQL gitlab_headers['Content-Type'] = 'application/json' @@ -118,7 +115,7 @@ class GitLabMixinBase(BaseGitService): if self.refresh and self._has_token_expired(response.status_code): await self.get_latest_token() - gitlab_headers = await self._get_gitlab_headers() + gitlab_headers = await self._get_headers() gitlab_headers['Content-Type'] = 'application/json' response = await client.post( self.GRAPHQL_URL, headers=gitlab_headers, json=payload diff --git a/openhands/integrations/protocols/http_client.py b/openhands/integrations/protocols/http_client.py new file mode 100644 index 0000000000..21ec1857e0 --- /dev/null +++ b/openhands/integrations/protocols/http_client.py @@ -0,0 +1,99 @@ +"""HTTP Client Protocol for Git Service Integrations.""" + +from abc import ABC, abstractmethod +from typing import Any + +from httpx import AsyncClient, HTTPError, HTTPStatusError +from pydantic import SecretStr + +from openhands.core.logger import openhands_logger as logger +from openhands.integrations.service_types import ( + AuthenticationError, + RateLimitError, + RequestMethod, + ResourceNotFoundError, + UnknownException, +) + + +class HTTPClient(ABC): + """Abstract base class defining the HTTP client interface for Git service integrations. + + This class abstracts the common HTTP client functionality needed by all + Git service providers (GitHub, GitLab, BitBucket) while keeping inheritance in place. + """ + + # Default attributes (subclasses may override) + token: SecretStr = SecretStr('') + refresh: bool = False + external_auth_id: str | None = None + external_auth_token: SecretStr | None = None + external_token_manager: bool = False + base_domain: str | None = None + + # Provider identification must be implemented by subclasses + @property + @abstractmethod + def provider(self) -> str: ... + + # Abstract methods that concrete classes must implement + @abstractmethod + async def get_latest_token(self) -> SecretStr | None: + """Get the latest working token for the service.""" + ... + + @abstractmethod + async def _get_headers(self) -> dict[str, Any]: + """Get HTTP headers for API requests.""" + ... + + @abstractmethod + async def _make_request( + self, + url: str, + params: dict | None = None, + method: RequestMethod = RequestMethod.GET, + ) -> tuple[Any, dict]: + """Make an HTTP request to the Git service API.""" + ... + + def _has_token_expired(self, status_code: int) -> bool: + """Check if the token has expired based on HTTP status code.""" + return status_code == 401 + + async def execute_request( + self, + client: AsyncClient, + url: str, + headers: dict, + params: dict | None, + method: RequestMethod = RequestMethod.GET, + ): + """Execute an HTTP request using the provided client.""" + if method == RequestMethod.POST: + return await client.post(url, headers=headers, json=params) + return await client.get(url, headers=headers, params=params) + + def handle_http_status_error( + self, e: HTTPStatusError + ) -> ( + AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException + ): + """Handle HTTP status errors and convert them to appropriate exceptions.""" + if e.response.status_code == 401: + return AuthenticationError(f'Invalid {self.provider} token') + elif e.response.status_code == 404: + return ResourceNotFoundError( + f'Resource not found on {self.provider} API: {e}' + ) + elif e.response.status_code == 429: + logger.warning(f'Rate limit exceeded on {self.provider} API: {e}') + return RateLimitError(f'{self.provider} API rate limit exceeded') + + logger.warning(f'Status error on {self.provider} API: {e}') + return UnknownException(f'Unknown error: {e}') + + def handle_http_error(self, e: HTTPError) -> UnknownException: + """Handle general HTTP errors.""" + logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}') + return UnknownException(f'HTTP error {type(e).__name__} : {e}') diff --git a/openhands/integrations/service_types.py b/openhands/integrations/service_types.py index 613ee69270..9c22c660a8 100644 --- a/openhands/integrations/service_types.py +++ b/openhands/integrations/service_types.py @@ -4,7 +4,6 @@ from enum import Enum from pathlib import Path from typing import Any, Protocol -from httpx import AsyncClient, HTTPError, HTTPStatusError from jinja2 import Environment, FileSystemLoader from pydantic import BaseModel, SecretStr @@ -242,40 +241,6 @@ class BaseGitService(ABC): """Extract file path from directory item.""" ... - async def execute_request( - self, - client: AsyncClient, - url: str, - headers: dict, - params: dict | None, - method: RequestMethod = RequestMethod.GET, - ): - if method == RequestMethod.POST: - return await client.post(url, headers=headers, json=params) - return await client.get(url, headers=headers, params=params) - - def handle_http_status_error( - self, e: HTTPStatusError - ) -> ( - AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException - ): - if e.response.status_code == 401: - return AuthenticationError(f'Invalid {self.provider} token') - elif e.response.status_code == 404: - return ResourceNotFoundError( - f'Resource not found on {self.provider} API: {e}' - ) - elif e.response.status_code == 429: - logger.warning(f'Rate limit exceeded on {self.provider} API: {e}') - return RateLimitError('GitHub API rate limit exceeded') - - logger.warning(f'Status error on {self.provider} API: {e}') - return UnknownException(f'Unknown error: {e}') - - def handle_http_error(self, e: HTTPError) -> UnknownException: - logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}') - return UnknownException(f'HTTP error {type(e).__name__} : {e}') - def _determine_microagents_path(self, repository_name: str) -> str: """Determine the microagents directory path based on repository name.""" actual_repo_name = repository_name.split('/')[-1] @@ -462,9 +427,6 @@ class BaseGitService(ABC): return comment_body[:max_comment_length] + '...' return comment_body - def _has_token_expired(self, status_code: int) -> bool: - return status_code == 401 - class InstallationsService(Protocol): async def get_installations(self) -> list[str]: diff --git a/tests/unit/integrations/github/test_github_service.py b/tests/unit/integrations/github/test_github_service.py index 0248e5f396..9e059e5e2c 100644 --- a/tests/unit/integrations/github/test_github_service.py +++ b/tests/unit/integrations/github/test_github_service.py @@ -24,7 +24,7 @@ async def test_github_service_token_handling(): assert service.token.get_secret_value() == 'test-token' # Test headers contain the token correctly - headers = await service._get_github_headers() + headers = await service._get_headers() assert headers['Authorization'] == 'Bearer test-token' assert headers['Accept'] == 'application/vnd.github.v3+json' diff --git a/tests/unit/integrations/protocols/test_http_client.py b/tests/unit/integrations/protocols/test_http_client.py new file mode 100644 index 0000000000..1210d34344 --- /dev/null +++ b/tests/unit/integrations/protocols/test_http_client.py @@ -0,0 +1,309 @@ +"""Unit tests for HTTPClient abstract base class (ABC).""" + +from typing import Any +from unittest.mock import AsyncMock, Mock + +import httpx +import pytest +from pydantic import SecretStr + +from openhands.integrations.protocols.http_client import HTTPClient +from openhands.integrations.service_types import ( + AuthenticationError, + RateLimitError, + RequestMethod, + ResourceNotFoundError, + UnknownException, +) + + +class TestableHTTPClient(HTTPClient): + """Testable concrete implementation of HTTPClient for unit testing.""" + + def __init__(self, provider_name: str = 'test-provider'): + self.token = SecretStr('test-token') + self.refresh = False + self.external_auth_id = None + self.external_auth_token = None + self.external_token_manager = False + self.base_domain = None + self._provider_name = provider_name + + @property + def provider(self) -> str: + return self._provider_name + + @provider.setter + def provider(self, value: str) -> None: + self._provider_name = value + + async def get_latest_token(self) -> SecretStr | None: + return self.token + + async def _get_headers(self) -> dict[str, Any]: + return {'Authorization': f'Bearer {self.token.get_secret_value()}'} + + async def _make_request( + self, + url: str, + params: dict | None = None, + method: RequestMethod = RequestMethod.GET, + ): + # Mock implementation for testing + return {'test': 'data'}, {} + + +@pytest.mark.asyncio +class TestHTTPClient: + """Test cases for HTTPClient ABC.""" + + def setup_method(self): + """Set up test fixtures.""" + self.client = TestableHTTPClient() + + def test_default_attributes(self): + """Test default attribute values.""" + assert isinstance(self.client.token, SecretStr) + assert self.client.refresh is False + assert self.client.external_auth_id is None + assert self.client.external_auth_token is None + assert self.client.external_token_manager is False + assert self.client.base_domain is None + + def test_provider_property(self): + """Test provider property.""" + assert self.client.provider == 'test-provider' + + def test_has_token_expired_default_implementation(self): + """Test default _has_token_expired implementation.""" + # The TestableHTTPClient inherits the default implementation from the protocol + client = TestableHTTPClient() + + assert client._has_token_expired(401) is True + assert client._has_token_expired(200) is False + assert client._has_token_expired(404) is False + assert client._has_token_expired(500) is False + + async def test_execute_request_get(self): + """Test execute_request with GET method.""" + client = TestableHTTPClient() + + mock_client = AsyncMock() + mock_response = AsyncMock() + mock_client.get.return_value = mock_response + + url = 'https://api.example.com/user' + headers = {'Authorization': 'Bearer token'} + params = {'per_page': 10} + + result = await client.execute_request( + mock_client, url, headers, params, RequestMethod.GET + ) + + assert result == mock_response + mock_client.get.assert_called_once_with(url, headers=headers, params=params) + + async def test_execute_request_post(self): + """Test execute_request with POST method.""" + client = TestableHTTPClient() + + mock_client = AsyncMock() + mock_response = AsyncMock() + mock_client.post.return_value = mock_response + + url = 'https://api.example.com/issues' + headers = {'Authorization': 'Bearer token'} + params = {'title': 'Test Issue'} + + result = await client.execute_request( + mock_client, url, headers, params, RequestMethod.POST + ) + + assert result == mock_response + mock_client.post.assert_called_once_with(url, headers=headers, json=params) + + def test_handle_http_status_error_401(self): + """Test handling of 401 HTTP status error.""" + client = TestableHTTPClient('github') + + mock_response = Mock() + mock_response.status_code = 401 + + error = httpx.HTTPStatusError( + message='401 Unauthorized', request=Mock(), response=mock_response + ) + + result = client.handle_http_status_error(error) + assert isinstance(result, AuthenticationError) + assert 'Invalid github token' in str(result) + + def test_handle_http_status_error_404(self): + """Test handling of 404 HTTP status error.""" + client = TestableHTTPClient() + client.provider = 'gitlab' + + mock_response = Mock() + mock_response.status_code = 404 + + error = httpx.HTTPStatusError( + message='404 Not Found', request=Mock(), response=mock_response + ) + + result = client.handle_http_status_error(error) + assert isinstance(result, ResourceNotFoundError) + assert 'Resource not found on gitlab API' in str(result) + + def test_handle_http_status_error_429(self): + """Test handling of 429 HTTP status error.""" + client = TestableHTTPClient() + client.provider = 'bitbucket' + + mock_response = Mock() + mock_response.status_code = 429 + + error = httpx.HTTPStatusError( + message='429 Too Many Requests', request=Mock(), response=mock_response + ) + + result = client.handle_http_status_error(error) + assert isinstance(result, RateLimitError) + assert 'bitbucket API rate limit exceeded' in str(result) + + def test_handle_http_status_error_other(self): + """Test handling of other HTTP status errors.""" + client = TestableHTTPClient() + client.provider = 'test-provider' + + mock_response = Mock() + mock_response.status_code = 500 + + error = httpx.HTTPStatusError( + message='500 Internal Server Error', request=Mock(), response=mock_response + ) + + result = client.handle_http_status_error(error) + assert isinstance(result, UnknownException) + assert 'Unknown error' in str(result) + + def test_handle_http_error(self): + """Test handling of general HTTP errors.""" + client = TestableHTTPClient() + client.provider = 'test-provider' + + error = httpx.ConnectError('Connection failed') + + result = client.handle_http_error(error) + assert isinstance(result, UnknownException) + assert 'HTTP error ConnectError' in str(result) + + def test_handle_http_error_with_different_error_types(self): + """Test handling of different HTTP error types.""" + client = TestableHTTPClient() + client.provider = 'test-provider' + + # Test with different error types + errors = [ + httpx.ConnectError('Connection failed'), + httpx.TimeoutException('Request timed out'), + httpx.ReadTimeout('Read timeout'), + httpx.WriteTimeout('Write timeout'), + ] + + for error in errors: + result = client.handle_http_error(error) + assert isinstance(result, UnknownException) + assert f'HTTP error {type(error).__name__}' in str(result) + + def test_runtime_checkable(self): + """Test that HTTPClient is runtime checkable.""" + from openhands.integrations.protocols.http_client import HTTPClient + + # Test that our testable client implements the protocol + assert isinstance(self.client, HTTPClient) + + # Test that a class without the required methods doesn't implement the protocol + class IncompleteClient: + pass + + incomplete = IncompleteClient() + assert not isinstance(incomplete, HTTPClient) + + def test_protocol_attributes_exist(self): + """Test that protocol defines expected attributes.""" + client = TestableHTTPClient() + + # Test default attribute values from protocol + assert hasattr(client, 'token') + assert hasattr(client, 'refresh') + assert hasattr(client, 'external_auth_id') + assert hasattr(client, 'external_auth_token') + assert hasattr(client, 'external_token_manager') + assert hasattr(client, 'base_domain') + + # Test TestableHTTPClient values + assert client.token == SecretStr('test-token') + assert client.refresh is False + assert client.external_auth_id is None + assert client.external_auth_token is None + assert client.external_token_manager is False + assert client.base_domain is None + + def test_protocol_methods_exist(self): + """Test that protocol defines expected methods.""" + client = TestableHTTPClient() + + # Test that methods exist + assert hasattr(client, 'get_latest_token') + assert hasattr(client, '_get_headers') + assert hasattr(client, '_make_request') + assert hasattr(client, '_has_token_expired') + assert hasattr(client, 'execute_request') + assert hasattr(client, 'handle_http_status_error') + assert hasattr(client, 'handle_http_error') + assert hasattr(client, 'provider') + + def test_protocol_concrete_methods_work(self): + """Test that concrete protocol methods work correctly.""" + client = TestableHTTPClient() + + # These methods should work since TestableHTTPClient implements them + assert client.provider == 'test-provider' + + # Test that the default implementations from the protocol are available + assert hasattr(client, '_has_token_expired') + assert hasattr(client, 'execute_request') + assert hasattr(client, 'handle_http_status_error') + assert hasattr(client, 'handle_http_error') + + def test_provider_specific_error_messages(self): + """Test that error messages are provider-specific.""" + providers = ['github', 'gitlab', 'bitbucket'] + + for provider in providers: + client = TestableHTTPClient() + client.provider = provider + + # Test 401 error + mock_response = Mock() + mock_response.status_code = 401 + error = httpx.HTTPStatusError( + message='401 Unauthorized', request=Mock(), response=mock_response + ) + result = client.handle_http_status_error(error) + assert f'Invalid {provider} token' in str(result) + + # Test 404 error + mock_response.status_code = 404 + error = httpx.HTTPStatusError( + message='404 Not Found', request=Mock(), response=mock_response + ) + result = client.handle_http_status_error(error) + assert f'Resource not found on {provider} API' in str(result) + + # Test 429 error + mock_response.status_code = 429 + error = httpx.HTTPStatusError( + message='429 Too Many Requests', request=Mock(), response=mock_response + ) + result = client.handle_http_status_error(error) + assert f'{provider} API rate limit exceeded' in str(result)