refactor: introduce HTTPClient protocol for git service integrations (#10731)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-09-01 01:44:31 -04:00 committed by GitHub
parent 21f3ef540f
commit 3e87c08631
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 431 additions and 68 deletions

View File

@ -4,6 +4,7 @@ from typing import Any
import httpx
from pydantic import SecretStr
from openhands.integrations.protocols.http_client import HTTPClient
from openhands.integrations.service_types import (
BaseGitService,
OwnerType,
@ -15,14 +16,12 @@ from openhands.integrations.service_types import (
)
class BitBucketMixinBase(BaseGitService):
class BitBucketMixinBase(BaseGitService, HTTPClient):
"""
Base mixin for BitBucket service containing common functionality
"""
BASE_URL = 'https://api.bitbucket.org/2.0'
token: SecretStr = SecretStr('')
refresh = False
def _extract_owner_and_repo(self, repository: str) -> tuple[str, str]:
"""Extract owner and repo from repository string.
@ -49,7 +48,7 @@ class BitBucketMixinBase(BaseGitService):
def _has_token_expired(self, status_code: int) -> bool:
return status_code == 401
async def _get_bitbucket_headers(self) -> dict[str, str]:
async def _get_headers(self) -> dict[str, str]:
"""Get headers for Bitbucket API requests."""
token_value = self.token.get_secret_value()
@ -85,13 +84,13 @@ class BitBucketMixinBase(BaseGitService):
"""
try:
async with httpx.AsyncClient() as client:
bitbucket_headers = await self._get_bitbucket_headers()
bitbucket_headers = await self._get_headers()
response = await self.execute_request(
client, url, bitbucket_headers, params, method
)
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
bitbucket_headers = await self._get_bitbucket_headers()
bitbucket_headers = await self._get_headers()
response = await self.execute_request(
client=client,
url=url,

View File

@ -43,8 +43,6 @@ class GitHubService(
BASE_URL = 'https://api.github.com'
GRAPHQL_URL = 'https://api.github.com/graphql'
token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,

View File

@ -1,5 +1,6 @@
# openhands/integrations/github/service/__init__.py
from .base import GitHubMixinBase
from .branches_prs import GitHubBranchesMixin
from .features import GitHubFeaturesMixin
from .prs import GitHubPRsMixin
@ -7,6 +8,7 @@ from .repos import GitHubReposMixin
from .resolver import GitHubResolverMixin
__all__ = [
'GitHubMixinBase',
'GitHubBranchesMixin',
'GitHubFeaturesMixin',
'GitHubPRsMixin',

View File

@ -4,6 +4,7 @@ from typing import Any, cast
import httpx
from pydantic import SecretStr
from openhands.integrations.protocols.http_client import HTTPClient
from openhands.integrations.service_types import (
BaseGitService,
RequestMethod,
@ -12,19 +13,15 @@ from openhands.integrations.service_types import (
)
class GitHubMixinBase(BaseGitService):
class GitHubMixinBase(BaseGitService, HTTPClient):
"""
Declares common attributes and method signatures used across mixins.
"""
BASE_URL: str
GRAPHQL_URL: str
token: SecretStr
refresh: bool
external_auth_id: str | None
base_domain: str | None
async def _get_github_headers(self) -> dict:
async def _get_headers(self) -> dict:
"""Retrieve the GH Token from settings store to construct the headers."""
if not self.token:
latest_token = await self.get_latest_token()
@ -47,7 +44,7 @@ class GitHubMixinBase(BaseGitService):
) -> tuple[Any, dict]: # type: ignore[override]
try:
async with httpx.AsyncClient() as client:
github_headers = await self._get_github_headers()
github_headers = await self._get_headers()
# Make initial request
response = await self.execute_request(
@ -61,7 +58,7 @@ class GitHubMixinBase(BaseGitService):
# Handle token refresh if needed
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
github_headers = await self._get_github_headers()
github_headers = await self._get_headers()
response = await self.execute_request(
client=client,
url=url,
@ -87,7 +84,7 @@ class GitHubMixinBase(BaseGitService):
) -> dict[str, Any]:
try:
async with httpx.AsyncClient() as client:
github_headers = await self._get_github_headers()
github_headers = await self._get_headers()
response = await client.post(
self.GRAPHQL_URL,

View File

@ -41,8 +41,6 @@ class GitLabService(
BASE_URL = 'https://gitlab.com/api/v4'
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,

View File

@ -1,5 +1,6 @@
# openhands/integrations/gitlab/service/__init__.py
from .base import GitLabMixinBase
from .branches import GitLabBranchesMixin
from .features import GitLabFeaturesMixin
from .prs import GitLabPRsMixin
@ -7,6 +8,7 @@ from .repos import GitLabReposMixin
from .resolver import GitLabResolverMixin
__all__ = [
'GitLabMixinBase',
'GitLabBranchesMixin',
'GitLabFeaturesMixin',
'GitLabPRsMixin',

View File

@ -3,6 +3,7 @@ from typing import Any
import httpx
from pydantic import SecretStr
from openhands.integrations.protocols.http_client import HTTPClient
from openhands.integrations.service_types import (
BaseGitService,
RequestMethod,
@ -11,19 +12,15 @@ from openhands.integrations.service_types import (
)
class GitLabMixinBase(BaseGitService):
class GitLabMixinBase(BaseGitService, HTTPClient):
"""
Declares common attributes and method signatures used across mixins.
"""
BASE_URL: str
GRAPHQL_URL: str
token: SecretStr
refresh: bool
external_auth_id: str | None
base_domain: str | None
async def _get_gitlab_headers(self) -> dict[str, Any]:
async def _get_headers(self) -> dict[str, Any]:
"""Retrieve the GitLab Token to construct the headers"""
if not self.token:
latest_token = await self.get_latest_token()
@ -45,7 +42,7 @@ class GitLabMixinBase(BaseGitService):
) -> tuple[Any, dict]: # type: ignore[override]
try:
async with httpx.AsyncClient() as client:
gitlab_headers = await self._get_gitlab_headers()
gitlab_headers = await self._get_headers()
# Make initial request
response = await self.execute_request(
@ -59,7 +56,7 @@ class GitLabMixinBase(BaseGitService):
# Handle token refresh if needed
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
gitlab_headers = await self._get_gitlab_headers()
gitlab_headers = await self._get_headers()
response = await self.execute_request(
client=client,
url=url,
@ -103,7 +100,7 @@ class GitLabMixinBase(BaseGitService):
variables = {}
try:
async with httpx.AsyncClient() as client:
gitlab_headers = await self._get_gitlab_headers()
gitlab_headers = await self._get_headers()
# Add content type header for GraphQL
gitlab_headers['Content-Type'] = 'application/json'
@ -118,7 +115,7 @@ class GitLabMixinBase(BaseGitService):
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
gitlab_headers = await self._get_gitlab_headers()
gitlab_headers = await self._get_headers()
gitlab_headers['Content-Type'] = 'application/json'
response = await client.post(
self.GRAPHQL_URL, headers=gitlab_headers, json=payload

View File

@ -0,0 +1,99 @@
"""HTTP Client Protocol for Git Service Integrations."""
from abc import ABC, abstractmethod
from typing import Any
from httpx import AsyncClient, HTTPError, HTTPStatusError
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import (
AuthenticationError,
RateLimitError,
RequestMethod,
ResourceNotFoundError,
UnknownException,
)
class HTTPClient(ABC):
    """Common HTTP plumbing shared by the Git provider integrations.

    Concrete services supply ``provider``, ``get_latest_token``,
    ``_get_headers`` and ``_make_request``. In return this base class
    contributes GET/POST dispatch, the token-expiry check, and the
    translation of httpx errors into the service exception types.
    """

    # Class-level defaults; subclasses and instances may override them.
    token: SecretStr = SecretStr('')
    refresh: bool = False
    external_auth_id: str | None = None
    external_auth_token: SecretStr | None = None
    external_token_manager: bool = False
    base_domain: str | None = None

    @property
    @abstractmethod
    def provider(self) -> str:
        """Provider name interpolated into log and error messages."""
        ...

    @abstractmethod
    async def get_latest_token(self) -> SecretStr | None:
        """Return the most recent usable token for the service, if any."""
        ...

    @abstractmethod
    async def _get_headers(self) -> dict[str, Any]:
        """Build the HTTP headers sent with API requests."""
        ...

    @abstractmethod
    async def _make_request(
        self,
        url: str,
        params: dict | None = None,
        method: RequestMethod = RequestMethod.GET,
    ) -> tuple[Any, dict]:
        """Perform a request against the Git service API."""
        ...

    def _has_token_expired(self, status_code: int) -> bool:
        """Report whether ``status_code`` is 401, the expired-token signal."""
        unauthorized = 401
        return status_code == unauthorized

    async def execute_request(
        self,
        client: AsyncClient,
        url: str,
        headers: dict,
        params: dict | None,
        method: RequestMethod = RequestMethod.GET,
    ):
        """Send the request through ``client`` and return its response.

        POST sends ``params`` as the JSON body; every other method value is
        issued as a GET with ``params`` as the query string.
        """
        if method != RequestMethod.POST:
            return await client.get(url, headers=headers, params=params)
        return await client.post(url, headers=headers, json=params)

    def handle_http_status_error(
        self, e: HTTPStatusError
    ) -> (
        AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException
    ):
        """Map an HTTP status error onto a service exception.

        The exception is returned, not raised. 401, 404 and 429 get dedicated
        types; any other status becomes ``UnknownException``. A warning is
        logged for 429 and for the fallback case only.
        """
        status = e.response.status_code

        if status == 401:
            return AuthenticationError(f'Invalid {self.provider} token')

        if status == 404:
            return ResourceNotFoundError(
                f'Resource not found on {self.provider} API: {e}'
            )

        if status == 429:
            logger.warning(f'Rate limit exceeded on {self.provider} API: {e}')
            return RateLimitError(f'{self.provider} API rate limit exceeded')

        logger.warning(f'Status error on {self.provider} API: {e}')
        return UnknownException(f'Unknown error: {e}')

    def handle_http_error(self, e: HTTPError) -> UnknownException:
        """Log a generic HTTP failure and return it as ``UnknownException``."""
        logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
        wrapped = UnknownException(f'HTTP error {type(e).__name__} : {e}')
        return wrapped

View File

@ -4,7 +4,6 @@ from enum import Enum
from pathlib import Path
from typing import Any, Protocol
from httpx import AsyncClient, HTTPError, HTTPStatusError
from jinja2 import Environment, FileSystemLoader
from pydantic import BaseModel, SecretStr
@ -242,40 +241,6 @@ class BaseGitService(ABC):
"""Extract file path from directory item."""
...
async def execute_request(
self,
client: AsyncClient,
url: str,
headers: dict,
params: dict | None,
method: RequestMethod = RequestMethod.GET,
):
if method == RequestMethod.POST:
return await client.post(url, headers=headers, json=params)
return await client.get(url, headers=headers, params=params)
def handle_http_status_error(
self, e: HTTPStatusError
) -> (
AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException
):
if e.response.status_code == 401:
return AuthenticationError(f'Invalid {self.provider} token')
elif e.response.status_code == 404:
return ResourceNotFoundError(
f'Resource not found on {self.provider} API: {e}'
)
elif e.response.status_code == 429:
logger.warning(f'Rate limit exceeded on {self.provider} API: {e}')
return RateLimitError('GitHub API rate limit exceeded')
logger.warning(f'Status error on {self.provider} API: {e}')
return UnknownException(f'Unknown error: {e}')
def handle_http_error(self, e: HTTPError) -> UnknownException:
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
return UnknownException(f'HTTP error {type(e).__name__} : {e}')
def _determine_microagents_path(self, repository_name: str) -> str:
"""Determine the microagents directory path based on repository name."""
actual_repo_name = repository_name.split('/')[-1]
@ -462,9 +427,6 @@ class BaseGitService(ABC):
return comment_body[:max_comment_length] + '...'
return comment_body
def _has_token_expired(self, status_code: int) -> bool:
return status_code == 401
class InstallationsService(Protocol):
async def get_installations(self) -> list[str]:

View File

@ -24,7 +24,7 @@ async def test_github_service_token_handling():
assert service.token.get_secret_value() == 'test-token'
# Test headers contain the token correctly
headers = await service._get_github_headers()
headers = await service._get_headers()
assert headers['Authorization'] == 'Bearer test-token'
assert headers['Accept'] == 'application/vnd.github.v3+json'

View File

@ -0,0 +1,309 @@
"""Unit tests for HTTPClient abstract base class (ABC)."""
from typing import Any
from unittest.mock import AsyncMock, Mock
import httpx
import pytest
from pydantic import SecretStr
from openhands.integrations.protocols.http_client import HTTPClient
from openhands.integrations.service_types import (
AuthenticationError,
RateLimitError,
RequestMethod,
ResourceNotFoundError,
UnknownException,
)
class TestableHTTPClient(HTTPClient):
    """Smallest concrete HTTPClient, used as the subject of the unit tests.

    Every abstract member is filled in with a canned, network-free
    implementation, and ``provider`` is made settable so tests can switch
    the name that appears in error messages.
    """

    def __init__(self, provider_name: str = 'test-provider'):
        self._provider_name = provider_name
        self.token = SecretStr('test-token')
        self.refresh = False
        self.base_domain = None
        self.external_auth_id = None
        self.external_auth_token = None
        self.external_token_manager = False

    @property
    def provider(self) -> str:
        return self._provider_name

    @provider.setter
    def provider(self, value: str) -> None:
        self._provider_name = value

    async def get_latest_token(self) -> SecretStr | None:
        # No refresh logic in the test double: hand back the stored token.
        return self.token

    async def _get_headers(self) -> dict[str, Any]:
        headers: dict[str, Any] = {}
        headers['Authorization'] = f'Bearer {self.token.get_secret_value()}'
        return headers

    async def _make_request(
        self,
        url: str,
        params: dict | None = None,
        method: RequestMethod = RequestMethod.GET,
    ):
        # Canned (payload, response-headers) pair; arguments are ignored.
        payload = {'test': 'data'}
        response_headers: dict = {}
        return payload, response_headers
class TestHTTPClient:
    """Test cases for the HTTPClient abstract base class.

    The asyncio marker is applied to the coroutine tests individually rather
    than to the whole class: marking synchronous tests with it makes
    pytest-asyncio emit warnings.
    """

    def setup_method(self):
        """Create a fresh client before every test."""
        self.client = TestableHTTPClient()

    @staticmethod
    def _status_error(status_code: int, message: str) -> httpx.HTTPStatusError:
        """Build an HTTPStatusError whose mocked response has ``status_code``."""
        response = Mock()
        response.status_code = status_code
        return httpx.HTTPStatusError(
            message=message, request=Mock(), response=response
        )

    def test_default_attributes(self):
        """Test default attribute values."""
        assert isinstance(self.client.token, SecretStr)
        assert self.client.refresh is False
        assert self.client.external_auth_id is None
        assert self.client.external_auth_token is None
        assert self.client.external_token_manager is False
        assert self.client.base_domain is None

    def test_provider_property(self):
        """Test provider property."""
        assert self.client.provider == 'test-provider'

    def test_has_token_expired_default_implementation(self):
        """Test default _has_token_expired implementation."""
        # TestableHTTPClient does not override it, so this exercises the
        # implementation inherited from the HTTPClient base class.
        client = TestableHTTPClient()

        assert client._has_token_expired(401) is True
        assert client._has_token_expired(200) is False
        assert client._has_token_expired(404) is False
        assert client._has_token_expired(500) is False

    @pytest.mark.asyncio
    async def test_execute_request_get(self):
        """Test execute_request with GET method."""
        client = TestableHTTPClient()
        mock_client = AsyncMock()
        # The response itself is never awaited, so a plain Mock models it.
        mock_response = Mock()
        mock_client.get.return_value = mock_response

        url = 'https://api.example.com/user'
        headers = {'Authorization': 'Bearer token'}
        params = {'per_page': 10}

        result = await client.execute_request(
            mock_client, url, headers, params, RequestMethod.GET
        )

        assert result is mock_response
        mock_client.get.assert_called_once_with(url, headers=headers, params=params)
        mock_client.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_execute_request_post(self):
        """Test execute_request with POST method."""
        client = TestableHTTPClient()
        mock_client = AsyncMock()
        mock_response = Mock()
        mock_client.post.return_value = mock_response

        url = 'https://api.example.com/issues'
        headers = {'Authorization': 'Bearer token'}
        params = {'title': 'Test Issue'}

        result = await client.execute_request(
            mock_client, url, headers, params, RequestMethod.POST
        )

        assert result is mock_response
        mock_client.post.assert_called_once_with(url, headers=headers, json=params)
        mock_client.get.assert_not_called()

    def test_handle_http_status_error_401(self):
        """Test handling of 401 HTTP status error."""
        client = TestableHTTPClient('github')
        error = self._status_error(401, '401 Unauthorized')

        result = client.handle_http_status_error(error)

        assert isinstance(result, AuthenticationError)
        assert 'Invalid github token' in str(result)

    def test_handle_http_status_error_404(self):
        """Test handling of 404 HTTP status error."""
        client = TestableHTTPClient()
        client.provider = 'gitlab'
        error = self._status_error(404, '404 Not Found')

        result = client.handle_http_status_error(error)

        assert isinstance(result, ResourceNotFoundError)
        assert 'Resource not found on gitlab API' in str(result)

    def test_handle_http_status_error_429(self):
        """Test handling of 429 HTTP status error."""
        client = TestableHTTPClient()
        client.provider = 'bitbucket'
        error = self._status_error(429, '429 Too Many Requests')

        result = client.handle_http_status_error(error)

        assert isinstance(result, RateLimitError)
        assert 'bitbucket API rate limit exceeded' in str(result)

    def test_handle_http_status_error_other(self):
        """Test handling of other HTTP status errors."""
        client = TestableHTTPClient()
        client.provider = 'test-provider'
        error = self._status_error(500, '500 Internal Server Error')

        result = client.handle_http_status_error(error)

        assert isinstance(result, UnknownException)
        assert 'Unknown error' in str(result)

    def test_handle_http_error(self):
        """Test handling of general HTTP errors."""
        client = TestableHTTPClient()
        client.provider = 'test-provider'
        error = httpx.ConnectError('Connection failed')

        result = client.handle_http_error(error)

        assert isinstance(result, UnknownException)
        assert 'HTTP error ConnectError' in str(result)

    def test_handle_http_error_with_different_error_types(self):
        """Test handling of different HTTP error types."""
        client = TestableHTTPClient()
        client.provider = 'test-provider'

        errors = [
            httpx.ConnectError('Connection failed'),
            httpx.TimeoutException('Request timed out'),
            httpx.ReadTimeout('Read timeout'),
            httpx.WriteTimeout('Write timeout'),
        ]

        for error in errors:
            result = client.handle_http_error(error)
            assert isinstance(result, UnknownException)
            assert f'HTTP error {type(error).__name__}' in str(result)

    def test_runtime_checkable(self):
        """Test isinstance checks against the HTTPClient base class.

        HTTPClient is an ABC, so isinstance follows nominal inheritance: a
        subclass instance passes and an unrelated class does not.
        """
        assert isinstance(self.client, HTTPClient)

        class IncompleteClient:
            pass

        incomplete = IncompleteClient()
        assert not isinstance(incomplete, HTTPClient)

    def test_protocol_attributes_exist(self):
        """Test that the base class exposes the expected attributes."""
        client = TestableHTTPClient()

        # Attributes declared with defaults on HTTPClient
        assert hasattr(client, 'token')
        assert hasattr(client, 'refresh')
        assert hasattr(client, 'external_auth_id')
        assert hasattr(client, 'external_auth_token')
        assert hasattr(client, 'external_token_manager')
        assert hasattr(client, 'base_domain')

        # Values assigned by TestableHTTPClient.__init__
        assert client.token == SecretStr('test-token')
        assert client.refresh is False
        assert client.external_auth_id is None
        assert client.external_auth_token is None
        assert client.external_token_manager is False
        assert client.base_domain is None

    def test_protocol_methods_exist(self):
        """Test that the base class exposes the expected methods."""
        client = TestableHTTPClient()

        assert hasattr(client, 'get_latest_token')
        assert hasattr(client, '_get_headers')
        assert hasattr(client, '_make_request')
        assert hasattr(client, '_has_token_expired')
        assert hasattr(client, 'execute_request')
        assert hasattr(client, 'handle_http_status_error')
        assert hasattr(client, 'handle_http_error')
        assert hasattr(client, 'provider')

    def test_protocol_concrete_methods_work(self):
        """Test that the concrete base-class methods are available."""
        client = TestableHTTPClient()

        # provider is implemented by TestableHTTPClient itself
        assert client.provider == 'test-provider'

        # These are inherited unchanged from HTTPClient
        assert hasattr(client, '_has_token_expired')
        assert hasattr(client, 'execute_request')
        assert hasattr(client, 'handle_http_status_error')
        assert hasattr(client, 'handle_http_error')

    def test_provider_specific_error_messages(self):
        """Test that error messages are provider-specific."""
        providers = ['github', 'gitlab', 'bitbucket']

        for provider in providers:
            client = TestableHTTPClient()
            client.provider = provider

            # 401 -> authentication error naming the provider
            result = client.handle_http_status_error(
                self._status_error(401, '401 Unauthorized')
            )
            assert f'Invalid {provider} token' in str(result)

            # 404 -> not-found error naming the provider
            result = client.handle_http_status_error(
                self._status_error(404, '404 Not Found')
            )
            assert f'Resource not found on {provider} API' in str(result)

            # 429 -> rate-limit error naming the provider
            result = client.handle_http_status_error(
                self._status_error(429, '429 Too Many Requests')
            )
            assert f'{provider} API rate limit exceeded' in str(result)