mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
refactor: introduce HTTPClient protocol for git service integrations (#10731)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
21f3ef540f
commit
3e87c08631
@ -4,6 +4,7 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
OwnerType,
|
||||
@ -15,14 +16,12 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
|
||||
|
||||
class BitBucketMixinBase(BaseGitService):
|
||||
class BitBucketMixinBase(BaseGitService, HTTPClient):
|
||||
"""
|
||||
Base mixin for BitBucket service containing common functionality
|
||||
"""
|
||||
|
||||
BASE_URL = 'https://api.bitbucket.org/2.0'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def _extract_owner_and_repo(self, repository: str) -> tuple[str, str]:
|
||||
"""Extract owner and repo from repository string.
|
||||
@ -49,7 +48,7 @@ class BitBucketMixinBase(BaseGitService):
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
async def _get_bitbucket_headers(self) -> dict[str, str]:
|
||||
async def _get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for Bitbucket API requests."""
|
||||
token_value = self.token.get_secret_value()
|
||||
|
||||
@ -85,13 +84,13 @@ class BitBucketMixinBase(BaseGitService):
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
bitbucket_headers = await self._get_bitbucket_headers()
|
||||
bitbucket_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client, url, bitbucket_headers, params, method
|
||||
)
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
bitbucket_headers = await self._get_bitbucket_headers()
|
||||
bitbucket_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
|
||||
@ -43,8 +43,6 @@ class GitHubService(
|
||||
|
||||
BASE_URL = 'https://api.github.com'
|
||||
GRAPHQL_URL = 'https://api.github.com/graphql'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# openhands/integrations/github/service/__init__.py
|
||||
|
||||
from .base import GitHubMixinBase
|
||||
from .branches_prs import GitHubBranchesMixin
|
||||
from .features import GitHubFeaturesMixin
|
||||
from .prs import GitHubPRsMixin
|
||||
@ -7,6 +8,7 @@ from .repos import GitHubReposMixin
|
||||
from .resolver import GitHubResolverMixin
|
||||
|
||||
__all__ = [
|
||||
'GitHubMixinBase',
|
||||
'GitHubBranchesMixin',
|
||||
'GitHubFeaturesMixin',
|
||||
'GitHubPRsMixin',
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, cast
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
RequestMethod,
|
||||
@ -12,19 +13,15 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
|
||||
|
||||
class GitHubMixinBase(BaseGitService):
|
||||
class GitHubMixinBase(BaseGitService, HTTPClient):
|
||||
"""
|
||||
Declares common attributes and method signatures used across mixins.
|
||||
"""
|
||||
|
||||
BASE_URL: str
|
||||
GRAPHQL_URL: str
|
||||
token: SecretStr
|
||||
refresh: bool
|
||||
external_auth_id: str | None
|
||||
base_domain: str | None
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
async def _get_headers(self) -> dict:
|
||||
"""Retrieve the GH Token from settings store to construct the headers."""
|
||||
if not self.token:
|
||||
latest_token = await self.get_latest_token()
|
||||
@ -47,7 +44,7 @@ class GitHubMixinBase(BaseGitService):
|
||||
) -> tuple[Any, dict]: # type: ignore[override]
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
github_headers = await self._get_github_headers()
|
||||
github_headers = await self._get_headers()
|
||||
|
||||
# Make initial request
|
||||
response = await self.execute_request(
|
||||
@ -61,7 +58,7 @@ class GitHubMixinBase(BaseGitService):
|
||||
# Handle token refresh if needed
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
github_headers = await self._get_github_headers()
|
||||
github_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
@ -87,7 +84,7 @@ class GitHubMixinBase(BaseGitService):
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
github_headers = await self._get_github_headers()
|
||||
github_headers = await self._get_headers()
|
||||
|
||||
response = await client.post(
|
||||
self.GRAPHQL_URL,
|
||||
|
||||
@ -41,8 +41,6 @@ class GitLabService(
|
||||
|
||||
BASE_URL = 'https://gitlab.com/api/v4'
|
||||
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# openhands/integrations/gitlab/service/__init__.py
|
||||
|
||||
from .base import GitLabMixinBase
|
||||
from .branches import GitLabBranchesMixin
|
||||
from .features import GitLabFeaturesMixin
|
||||
from .prs import GitLabPRsMixin
|
||||
@ -7,6 +8,7 @@ from .repos import GitLabReposMixin
|
||||
from .resolver import GitLabResolverMixin
|
||||
|
||||
__all__ = [
|
||||
'GitLabMixinBase',
|
||||
'GitLabBranchesMixin',
|
||||
'GitLabFeaturesMixin',
|
||||
'GitLabPRsMixin',
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
RequestMethod,
|
||||
@ -11,19 +12,15 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
|
||||
|
||||
class GitLabMixinBase(BaseGitService):
|
||||
class GitLabMixinBase(BaseGitService, HTTPClient):
|
||||
"""
|
||||
Declares common attributes and method signatures used across mixins.
|
||||
"""
|
||||
|
||||
BASE_URL: str
|
||||
GRAPHQL_URL: str
|
||||
token: SecretStr
|
||||
refresh: bool
|
||||
external_auth_id: str | None
|
||||
base_domain: str | None
|
||||
|
||||
async def _get_gitlab_headers(self) -> dict[str, Any]:
|
||||
async def _get_headers(self) -> dict[str, Any]:
|
||||
"""Retrieve the GitLab Token to construct the headers"""
|
||||
if not self.token:
|
||||
latest_token = await self.get_latest_token()
|
||||
@ -45,7 +42,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
) -> tuple[Any, dict]: # type: ignore[override]
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
|
||||
# Make initial request
|
||||
response = await self.execute_request(
|
||||
@ -59,7 +56,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
# Handle token refresh if needed
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
@ -103,7 +100,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
variables = {}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
# Add content type header for GraphQL
|
||||
gitlab_headers['Content-Type'] = 'application/json'
|
||||
|
||||
@ -118,7 +115,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
gitlab_headers['Content-Type'] = 'application/json'
|
||||
response = await client.post(
|
||||
self.GRAPHQL_URL, headers=gitlab_headers, json=payload
|
||||
|
||||
99
openhands/integrations/protocols/http_client.py
Normal file
99
openhands/integrations/protocols/http_client.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""HTTP Client Protocol for Git Service Integrations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from httpx import AsyncClient, HTTPError, HTTPStatusError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
RateLimitError,
|
||||
RequestMethod,
|
||||
ResourceNotFoundError,
|
||||
UnknownException,
|
||||
)
|
||||
|
||||
|
||||
class HTTPClient(ABC):
|
||||
"""Abstract base class defining the HTTP client interface for Git service integrations.
|
||||
|
||||
This class abstracts the common HTTP client functionality needed by all
|
||||
Git service providers (GitHub, GitLab, BitBucket) while keeping inheritance in place.
|
||||
"""
|
||||
|
||||
# Default attributes (subclasses may override)
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh: bool = False
|
||||
external_auth_id: str | None = None
|
||||
external_auth_token: SecretStr | None = None
|
||||
external_token_manager: bool = False
|
||||
base_domain: str | None = None
|
||||
|
||||
# Provider identification must be implemented by subclasses
|
||||
@property
|
||||
@abstractmethod
|
||||
def provider(self) -> str: ...
|
||||
|
||||
# Abstract methods that concrete classes must implement
|
||||
@abstractmethod
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
"""Get the latest working token for the service."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _get_headers(self) -> dict[str, Any]:
|
||||
"""Get HTTP headers for API requests."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _make_request(
|
||||
self,
|
||||
url: str,
|
||||
params: dict | None = None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
) -> tuple[Any, dict]:
|
||||
"""Make an HTTP request to the Git service API."""
|
||||
...
|
||||
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
"""Check if the token has expired based on HTTP status code."""
|
||||
return status_code == 401
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
url: str,
|
||||
headers: dict,
|
||||
params: dict | None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
):
|
||||
"""Execute an HTTP request using the provided client."""
|
||||
if method == RequestMethod.POST:
|
||||
return await client.post(url, headers=headers, json=params)
|
||||
return await client.get(url, headers=headers, params=params)
|
||||
|
||||
def handle_http_status_error(
|
||||
self, e: HTTPStatusError
|
||||
) -> (
|
||||
AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException
|
||||
):
|
||||
"""Handle HTTP status errors and convert them to appropriate exceptions."""
|
||||
if e.response.status_code == 401:
|
||||
return AuthenticationError(f'Invalid {self.provider} token')
|
||||
elif e.response.status_code == 404:
|
||||
return ResourceNotFoundError(
|
||||
f'Resource not found on {self.provider} API: {e}'
|
||||
)
|
||||
elif e.response.status_code == 429:
|
||||
logger.warning(f'Rate limit exceeded on {self.provider} API: {e}')
|
||||
return RateLimitError(f'{self.provider} API rate limit exceeded')
|
||||
|
||||
logger.warning(f'Status error on {self.provider} API: {e}')
|
||||
return UnknownException(f'Unknown error: {e}')
|
||||
|
||||
def handle_http_error(self, e: HTTPError) -> UnknownException:
|
||||
"""Handle general HTTP errors."""
|
||||
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
|
||||
return UnknownException(f'HTTP error {type(e).__name__} : {e}')
|
||||
@ -4,7 +4,6 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
from httpx import AsyncClient, HTTPError, HTTPStatusError
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
@ -242,40 +241,6 @@ class BaseGitService(ABC):
|
||||
"""Extract file path from directory item."""
|
||||
...
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
url: str,
|
||||
headers: dict,
|
||||
params: dict | None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
):
|
||||
if method == RequestMethod.POST:
|
||||
return await client.post(url, headers=headers, json=params)
|
||||
return await client.get(url, headers=headers, params=params)
|
||||
|
||||
def handle_http_status_error(
|
||||
self, e: HTTPStatusError
|
||||
) -> (
|
||||
AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException
|
||||
):
|
||||
if e.response.status_code == 401:
|
||||
return AuthenticationError(f'Invalid {self.provider} token')
|
||||
elif e.response.status_code == 404:
|
||||
return ResourceNotFoundError(
|
||||
f'Resource not found on {self.provider} API: {e}'
|
||||
)
|
||||
elif e.response.status_code == 429:
|
||||
logger.warning(f'Rate limit exceeded on {self.provider} API: {e}')
|
||||
return RateLimitError('GitHub API rate limit exceeded')
|
||||
|
||||
logger.warning(f'Status error on {self.provider} API: {e}')
|
||||
return UnknownException(f'Unknown error: {e}')
|
||||
|
||||
def handle_http_error(self, e: HTTPError) -> UnknownException:
|
||||
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
|
||||
return UnknownException(f'HTTP error {type(e).__name__} : {e}')
|
||||
|
||||
def _determine_microagents_path(self, repository_name: str) -> str:
|
||||
"""Determine the microagents directory path based on repository name."""
|
||||
actual_repo_name = repository_name.split('/')[-1]
|
||||
@ -462,9 +427,6 @@ class BaseGitService(ABC):
|
||||
return comment_body[:max_comment_length] + '...'
|
||||
return comment_body
|
||||
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
|
||||
class InstallationsService(Protocol):
|
||||
async def get_installations(self) -> list[str]:
|
||||
|
||||
@ -24,7 +24,7 @@ async def test_github_service_token_handling():
|
||||
assert service.token.get_secret_value() == 'test-token'
|
||||
|
||||
# Test headers contain the token correctly
|
||||
headers = await service._get_github_headers()
|
||||
headers = await service._get_headers()
|
||||
assert headers['Authorization'] == 'Bearer test-token'
|
||||
assert headers['Accept'] == 'application/vnd.github.v3+json'
|
||||
|
||||
|
||||
309
tests/unit/integrations/protocols/test_http_client.py
Normal file
309
tests/unit/integrations/protocols/test_http_client.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""Unit tests for HTTPClient abstract base class (ABC)."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
RateLimitError,
|
||||
RequestMethod,
|
||||
ResourceNotFoundError,
|
||||
UnknownException,
|
||||
)
|
||||
|
||||
|
||||
class TestableHTTPClient(HTTPClient):
|
||||
"""Testable concrete implementation of HTTPClient for unit testing."""
|
||||
|
||||
def __init__(self, provider_name: str = 'test-provider'):
|
||||
self.token = SecretStr('test-token')
|
||||
self.refresh = False
|
||||
self.external_auth_id = None
|
||||
self.external_auth_token = None
|
||||
self.external_token_manager = False
|
||||
self.base_domain = None
|
||||
self._provider_name = provider_name
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self._provider_name
|
||||
|
||||
@provider.setter
|
||||
def provider(self, value: str) -> None:
|
||||
self._provider_name = value
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
return self.token
|
||||
|
||||
async def _get_headers(self) -> dict[str, Any]:
|
||||
return {'Authorization': f'Bearer {self.token.get_secret_value()}'}
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
url: str,
|
||||
params: dict | None = None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
):
|
||||
# Mock implementation for testing
|
||||
return {'test': 'data'}, {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestHTTPClient:
|
||||
"""Test cases for HTTPClient ABC."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.client = TestableHTTPClient()
|
||||
|
||||
def test_default_attributes(self):
|
||||
"""Test default attribute values."""
|
||||
assert isinstance(self.client.token, SecretStr)
|
||||
assert self.client.refresh is False
|
||||
assert self.client.external_auth_id is None
|
||||
assert self.client.external_auth_token is None
|
||||
assert self.client.external_token_manager is False
|
||||
assert self.client.base_domain is None
|
||||
|
||||
def test_provider_property(self):
|
||||
"""Test provider property."""
|
||||
assert self.client.provider == 'test-provider'
|
||||
|
||||
def test_has_token_expired_default_implementation(self):
|
||||
"""Test default _has_token_expired implementation."""
|
||||
# The TestableHTTPClient inherits the default implementation from the protocol
|
||||
client = TestableHTTPClient()
|
||||
|
||||
assert client._has_token_expired(401) is True
|
||||
assert client._has_token_expired(200) is False
|
||||
assert client._has_token_expired(404) is False
|
||||
assert client._has_token_expired(500) is False
|
||||
|
||||
async def test_execute_request_get(self):
|
||||
"""Test execute_request with GET method."""
|
||||
client = TestableHTTPClient()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
url = 'https://api.example.com/user'
|
||||
headers = {'Authorization': 'Bearer token'}
|
||||
params = {'per_page': 10}
|
||||
|
||||
result = await client.execute_request(
|
||||
mock_client, url, headers, params, RequestMethod.GET
|
||||
)
|
||||
|
||||
assert result == mock_response
|
||||
mock_client.get.assert_called_once_with(url, headers=headers, params=params)
|
||||
|
||||
async def test_execute_request_post(self):
|
||||
"""Test execute_request with POST method."""
|
||||
client = TestableHTTPClient()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
url = 'https://api.example.com/issues'
|
||||
headers = {'Authorization': 'Bearer token'}
|
||||
params = {'title': 'Test Issue'}
|
||||
|
||||
result = await client.execute_request(
|
||||
mock_client, url, headers, params, RequestMethod.POST
|
||||
)
|
||||
|
||||
assert result == mock_response
|
||||
mock_client.post.assert_called_once_with(url, headers=headers, json=params)
|
||||
|
||||
def test_handle_http_status_error_401(self):
|
||||
"""Test handling of 401 HTTP status error."""
|
||||
client = TestableHTTPClient('github')
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
message='401 Unauthorized', request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
result = client.handle_http_status_error(error)
|
||||
assert isinstance(result, AuthenticationError)
|
||||
assert 'Invalid github token' in str(result)
|
||||
|
||||
def test_handle_http_status_error_404(self):
|
||||
"""Test handling of 404 HTTP status error."""
|
||||
client = TestableHTTPClient()
|
||||
client.provider = 'gitlab'
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
message='404 Not Found', request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
result = client.handle_http_status_error(error)
|
||||
assert isinstance(result, ResourceNotFoundError)
|
||||
assert 'Resource not found on gitlab API' in str(result)
|
||||
|
||||
def test_handle_http_status_error_429(self):
|
||||
"""Test handling of 429 HTTP status error."""
|
||||
client = TestableHTTPClient()
|
||||
client.provider = 'bitbucket'
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 429
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
message='429 Too Many Requests', request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
result = client.handle_http_status_error(error)
|
||||
assert isinstance(result, RateLimitError)
|
||||
assert 'bitbucket API rate limit exceeded' in str(result)
|
||||
|
||||
def test_handle_http_status_error_other(self):
|
||||
"""Test handling of other HTTP status errors."""
|
||||
client = TestableHTTPClient()
|
||||
client.provider = 'test-provider'
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
message='500 Internal Server Error', request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
result = client.handle_http_status_error(error)
|
||||
assert isinstance(result, UnknownException)
|
||||
assert 'Unknown error' in str(result)
|
||||
|
||||
def test_handle_http_error(self):
|
||||
"""Test handling of general HTTP errors."""
|
||||
client = TestableHTTPClient()
|
||||
client.provider = 'test-provider'
|
||||
|
||||
error = httpx.ConnectError('Connection failed')
|
||||
|
||||
result = client.handle_http_error(error)
|
||||
assert isinstance(result, UnknownException)
|
||||
assert 'HTTP error ConnectError' in str(result)
|
||||
|
||||
def test_handle_http_error_with_different_error_types(self):
|
||||
"""Test handling of different HTTP error types."""
|
||||
client = TestableHTTPClient()
|
||||
client.provider = 'test-provider'
|
||||
|
||||
# Test with different error types
|
||||
errors = [
|
||||
httpx.ConnectError('Connection failed'),
|
||||
httpx.TimeoutException('Request timed out'),
|
||||
httpx.ReadTimeout('Read timeout'),
|
||||
httpx.WriteTimeout('Write timeout'),
|
||||
]
|
||||
|
||||
for error in errors:
|
||||
result = client.handle_http_error(error)
|
||||
assert isinstance(result, UnknownException)
|
||||
assert f'HTTP error {type(error).__name__}' in str(result)
|
||||
|
||||
def test_runtime_checkable(self):
|
||||
"""Test that HTTPClient is runtime checkable."""
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
|
||||
# Test that our testable client implements the protocol
|
||||
assert isinstance(self.client, HTTPClient)
|
||||
|
||||
# Test that a class without the required methods doesn't implement the protocol
|
||||
class IncompleteClient:
|
||||
pass
|
||||
|
||||
incomplete = IncompleteClient()
|
||||
assert not isinstance(incomplete, HTTPClient)
|
||||
|
||||
def test_protocol_attributes_exist(self):
|
||||
"""Test that protocol defines expected attributes."""
|
||||
client = TestableHTTPClient()
|
||||
|
||||
# Test default attribute values from protocol
|
||||
assert hasattr(client, 'token')
|
||||
assert hasattr(client, 'refresh')
|
||||
assert hasattr(client, 'external_auth_id')
|
||||
assert hasattr(client, 'external_auth_token')
|
||||
assert hasattr(client, 'external_token_manager')
|
||||
assert hasattr(client, 'base_domain')
|
||||
|
||||
# Test TestableHTTPClient values
|
||||
assert client.token == SecretStr('test-token')
|
||||
assert client.refresh is False
|
||||
assert client.external_auth_id is None
|
||||
assert client.external_auth_token is None
|
||||
assert client.external_token_manager is False
|
||||
assert client.base_domain is None
|
||||
|
||||
def test_protocol_methods_exist(self):
|
||||
"""Test that protocol defines expected methods."""
|
||||
client = TestableHTTPClient()
|
||||
|
||||
# Test that methods exist
|
||||
assert hasattr(client, 'get_latest_token')
|
||||
assert hasattr(client, '_get_headers')
|
||||
assert hasattr(client, '_make_request')
|
||||
assert hasattr(client, '_has_token_expired')
|
||||
assert hasattr(client, 'execute_request')
|
||||
assert hasattr(client, 'handle_http_status_error')
|
||||
assert hasattr(client, 'handle_http_error')
|
||||
assert hasattr(client, 'provider')
|
||||
|
||||
def test_protocol_concrete_methods_work(self):
|
||||
"""Test that concrete protocol methods work correctly."""
|
||||
client = TestableHTTPClient()
|
||||
|
||||
# These methods should work since TestableHTTPClient implements them
|
||||
assert client.provider == 'test-provider'
|
||||
|
||||
# Test that the default implementations from the protocol are available
|
||||
assert hasattr(client, '_has_token_expired')
|
||||
assert hasattr(client, 'execute_request')
|
||||
assert hasattr(client, 'handle_http_status_error')
|
||||
assert hasattr(client, 'handle_http_error')
|
||||
|
||||
def test_provider_specific_error_messages(self):
|
||||
"""Test that error messages are provider-specific."""
|
||||
providers = ['github', 'gitlab', 'bitbucket']
|
||||
|
||||
for provider in providers:
|
||||
client = TestableHTTPClient()
|
||||
client.provider = provider
|
||||
|
||||
# Test 401 error
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
error = httpx.HTTPStatusError(
|
||||
message='401 Unauthorized', request=Mock(), response=mock_response
|
||||
)
|
||||
result = client.handle_http_status_error(error)
|
||||
assert f'Invalid {provider} token' in str(result)
|
||||
|
||||
# Test 404 error
|
||||
mock_response.status_code = 404
|
||||
error = httpx.HTTPStatusError(
|
||||
message='404 Not Found', request=Mock(), response=mock_response
|
||||
)
|
||||
result = client.handle_http_status_error(error)
|
||||
assert f'Resource not found on {provider} API' in str(result)
|
||||
|
||||
# Test 429 error
|
||||
mock_response.status_code = 429
|
||||
error = httpx.HTTPStatusError(
|
||||
message='429 Too Many Requests', request=Mock(), response=mock_response
|
||||
)
|
||||
result = client.handle_http_status_error(error)
|
||||
assert f'{provider} API rate limit exceeded' in str(result)
|
||||
Loading…
x
Reference in New Issue
Block a user