mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Add HTTP method option to Git service fetch_data functions (#7996)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -8,9 +8,11 @@ from pydantic import SecretStr
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
BaseGitService,
|
||||
GitService,
|
||||
ProviderType,
|
||||
Repository,
|
||||
RequestMethod,
|
||||
SuggestedTask,
|
||||
TaskType,
|
||||
UnknownException,
|
||||
@@ -20,7 +22,7 @@ from openhands.server.types import AppMode
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class GitHubService(GitService):
|
||||
class GitHubService(BaseGitService, GitService):
|
||||
BASE_URL = 'https://api.github.com'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
@@ -59,18 +61,35 @@ class GitHubService(GitService):
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
async def _make_request(
|
||||
self,
|
||||
url: str,
|
||||
params: dict | None = None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
) -> tuple[Any, dict]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
github_headers = await self._get_github_headers()
|
||||
response = await client.get(url, headers=github_headers, params=params)
|
||||
|
||||
# Make initial request
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
headers=github_headers,
|
||||
params=params,
|
||||
method=method,
|
||||
)
|
||||
|
||||
# Handle token refresh if needed
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
github_headers = await self._get_github_headers()
|
||||
response = await client.get(
|
||||
url, headers=github_headers, params=params
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
headers=github_headers,
|
||||
params=params,
|
||||
method=method,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
@@ -93,7 +112,7 @@ class GitHubService(GitService):
|
||||
|
||||
async def get_user(self) -> User:
|
||||
url = f'{self.BASE_URL}/user'
|
||||
response, _ = await self._fetch_data(url)
|
||||
response, _ = await self._make_request(url)
|
||||
|
||||
return User(
|
||||
id=response.get('id'),
|
||||
@@ -124,7 +143,7 @@ class GitHubService(GitService):
|
||||
|
||||
while len(repos) < max_repos:
|
||||
page_params = {**params, 'page': str(page)}
|
||||
response, headers = await self._fetch_data(url, page_params)
|
||||
response, headers = await self._make_request(url, page_params)
|
||||
|
||||
# Extract repositories from response
|
||||
page_repos = response.get(extract_key, []) if extract_key else response
|
||||
@@ -190,7 +209,7 @@ class GitHubService(GitService):
|
||||
|
||||
async def get_installation_ids(self) -> list[int]:
|
||||
url = f'{self.BASE_URL}/user/installations'
|
||||
response, _ = await self._fetch_data(url)
|
||||
response, _ = await self._make_request(url)
|
||||
installations = response.get('installations', [])
|
||||
return [i['id'] for i in installations]
|
||||
|
||||
@@ -207,7 +226,7 @@ class GitHubService(GitService):
|
||||
'order': order,
|
||||
}
|
||||
|
||||
response, _ = await self._fetch_data(url, params)
|
||||
response, _ = await self._make_request(url, params)
|
||||
repo_items = response.get('items', [])
|
||||
|
||||
repos = [
|
||||
|
||||
@@ -7,9 +7,11 @@ from pydantic import SecretStr
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
BaseGitService,
|
||||
GitService,
|
||||
ProviderType,
|
||||
Repository,
|
||||
RequestMethod,
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
@@ -17,7 +19,7 @@ from openhands.server.types import AppMode
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class GitLabService(GitService):
|
||||
class GitLabService(BaseGitService, GitService):
|
||||
BASE_URL = 'https://gitlab.com/api/v4'
|
||||
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
|
||||
token: SecretStr = SecretStr('')
|
||||
@@ -59,18 +61,35 @@ class GitLabService(GitService):
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
async def _make_request(
|
||||
self,
|
||||
url: str,
|
||||
params: dict | None = None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
) -> tuple[Any, dict]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
response = await client.get(url, headers=gitlab_headers, params=params)
|
||||
|
||||
# Make initial request
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
headers=gitlab_headers,
|
||||
params=params,
|
||||
method=method,
|
||||
)
|
||||
|
||||
# Handle token refresh if needed
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
response = await client.get(
|
||||
url, headers=gitlab_headers, params=params
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
headers=gitlab_headers,
|
||||
params=params,
|
||||
method=method,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
@@ -149,7 +168,7 @@ class GitLabService(GitService):
|
||||
|
||||
async def get_user(self) -> User:
|
||||
url = f'{self.BASE_URL}/user'
|
||||
response, _ = await self._fetch_data(url)
|
||||
response, _ = await self._make_request(url)
|
||||
|
||||
return User(
|
||||
id=response.get('id'),
|
||||
@@ -173,7 +192,7 @@ class GitLabService(GitService):
|
||||
'visibility': 'public',
|
||||
}
|
||||
|
||||
response, _ = await self._fetch_data(url, params)
|
||||
response, _ = await self._make_request(url, params)
|
||||
repos = [
|
||||
Repository(
|
||||
id=repo.get('id'),
|
||||
@@ -209,7 +228,7 @@ class GitLabService(GitService):
|
||||
'sort': 'desc', # GitLab uses sort for direction (asc/desc)
|
||||
'membership': 1, # Use 1 instead of True
|
||||
}
|
||||
response, headers = await self._fetch_data(url, params)
|
||||
response, headers = await self._make_request(url, params)
|
||||
|
||||
if not response: # No more repositories
|
||||
break
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
from httpx import AsyncClient
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from openhands.server.types import AppMode
|
||||
@@ -57,6 +58,25 @@ class UnknownException(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class RequestMethod(Enum):
|
||||
POST = 'post'
|
||||
GET = 'get'
|
||||
|
||||
|
||||
class BaseGitService:
|
||||
async def execute_request(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
url: str,
|
||||
headers: dict,
|
||||
params: dict | None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
):
|
||||
if method == RequestMethod.POST:
|
||||
return await client.post(url, headers=headers, json=params)
|
||||
return await client.get(url, headers=headers, params=params)
|
||||
|
||||
|
||||
class GitService(Protocol):
|
||||
"""Protocol defining the interface for Git service providers"""
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ async def test_github_service_fetch_data():
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
service = GitHubService(user_id=None, token=SecretStr('test-token'))
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
_ = await service._make_request('https://api.github.com/user')
|
||||
|
||||
# Verify the request was made with correct headers
|
||||
mock_client.get.assert_called_once()
|
||||
@@ -78,4 +78,4 @@ async def test_github_service_fetch_data():
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
_ = await service._make_request('https://api.github.com/user')
|
||||
|
||||
Reference in New Issue
Block a user