mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Fix]: Dedup token verification logic in resolver (#7967)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
300a59853b
commit
1e509a70d4
@ -98,9 +98,9 @@ class SandboxConfig(BaseModel):
|
||||
raise ValueError(f'Invalid sandbox configuration: {e}')
|
||||
|
||||
return sandbox_mapping
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_default_base_image(self) -> "SandboxConfig":
|
||||
|
||||
@model_validator(mode='after')
|
||||
def set_default_base_image(self) -> 'SandboxConfig':
|
||||
if self.base_container_image is None:
|
||||
self.base_container_image = 'nikolaik/python-nodejs:python3.12-nodejs22'
|
||||
return self
|
||||
|
||||
@ -32,6 +32,7 @@ class GitHubService(GitService):
|
||||
external_auth_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.external_token_manager = external_token_manager
|
||||
@ -39,6 +40,9 @@ class GitHubService(GitService):
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
if base_domain:
|
||||
self.BASE_URL = f'https://{base_domain}/api/v3'
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
"""Retrieve the GH Token from settings store to construct the headers."""
|
||||
if self.user_id and not self.token:
|
||||
|
||||
@ -30,6 +30,7 @@ class GitLabService(GitService):
|
||||
external_auth_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.external_token_manager = external_token_manager
|
||||
@ -37,6 +38,10 @@ class GitLabService(GitService):
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
if base_domain:
|
||||
self.BASE_URL = f'https://{base_domain}/api/v4'
|
||||
self.GRAPHQL_URL = f'https://{base_domain}/api/graphql'
|
||||
|
||||
async def _get_gitlab_headers(self) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve the GitLab Token to construct the headers
|
||||
|
||||
@ -5,7 +5,9 @@ from openhands.integrations.gitlab.gitlab_service import GitLabService
|
||||
from openhands.integrations.provider import ProviderType
|
||||
|
||||
|
||||
async def validate_provider_token(token: SecretStr) -> ProviderType | None:
|
||||
async def validate_provider_token(
|
||||
token: SecretStr, base_domain: str | None = None
|
||||
) -> ProviderType | None:
|
||||
"""
|
||||
Determine whether a token is for GitHub or GitLab by attempting to get user info
|
||||
from both services.
|
||||
@ -20,7 +22,7 @@ async def validate_provider_token(token: SecretStr) -> ProviderType | None:
|
||||
"""
|
||||
# Try GitHub first
|
||||
try:
|
||||
github_service = GitHubService(token=token)
|
||||
github_service = GitHubService(token=token, base_domain=base_domain)
|
||||
await github_service.get_user()
|
||||
return ProviderType.GITHUB
|
||||
except Exception:
|
||||
@ -28,7 +30,7 @@ async def validate_provider_token(token: SecretStr) -> ProviderType | None:
|
||||
|
||||
# Try GitLab next
|
||||
try:
|
||||
gitlab_service = GitLabService(token=token)
|
||||
gitlab_service = GitLabService(token=token, base_domain=base_domain)
|
||||
await gitlab_service.get_user()
|
||||
return ProviderType.GITLAB
|
||||
except Exception:
|
||||
|
||||
@ -42,6 +42,7 @@ from openhands.resolver.utils import (
|
||||
reset_logger_for_multiprocessing,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
# Don't make this confgurable for now, unless we have other competitive agents
|
||||
AGENT_CLASS = 'CodeActAgent'
|
||||
@ -688,7 +689,12 @@ def main() -> None:
|
||||
if not token:
|
||||
raise ValueError('Token is required.')
|
||||
|
||||
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
|
||||
platform = call_async_from_sync(
|
||||
identify_token,
|
||||
GENERAL_TIMEOUT,
|
||||
token,
|
||||
my_args.base_domain,
|
||||
)
|
||||
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
model = my_args.llm_model or os.environ['LLM_MODEL']
|
||||
|
||||
@ -22,6 +22,7 @@ from openhands.resolver.io_utils import (
|
||||
from openhands.resolver.patching import apply_diff, parse_patch
|
||||
from openhands.resolver.resolver_output import ResolverOutput
|
||||
from openhands.resolver.utils import identify_token
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
|
||||
def apply_patch(repo_dir: str, patch: str) -> None:
|
||||
@ -685,7 +686,12 @@ def main() -> None:
|
||||
)
|
||||
username = my_args.username if my_args.username else os.getenv('GIT_USERNAME')
|
||||
|
||||
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
|
||||
platform = call_async_from_sync(
|
||||
identify_token,
|
||||
GENERAL_TIMEOUT,
|
||||
token,
|
||||
my_args.base_domain,
|
||||
)
|
||||
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
llm_config = LLMConfig(
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import get_console_handler
|
||||
@ -12,69 +12,21 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
|
||||
|
||||
def identify_token(
|
||||
token: str, selected_repo: str | None = None, base_domain: str | None = 'github.com'
|
||||
) -> ProviderType:
|
||||
async def identify_token(token: str, base_domain: str | None) -> ProviderType:
|
||||
"""
|
||||
Identifies whether a token belongs to GitHub or GitLab.
|
||||
|
||||
Parameters:
|
||||
token (str): The personal access token to check.
|
||||
selected_repo (str): Repository in format "owner/repo" for GitHub Actions token validation.
|
||||
base_domain (str): The base domain for GitHub Enterprise (default: "github.com").
|
||||
|
||||
Returns:
|
||||
ProviderType: "GitHub" if the token is valid for GitHub,
|
||||
"GitLab" if the token is valid for GitLab,
|
||||
"Invalid" if the token is not recognized by either.
|
||||
base_domain (str): Custom base domain for provider (e.g GitHub Enterprise)
|
||||
"""
|
||||
# Determine GitHub API base URL based on domain
|
||||
if base_domain is None or base_domain == 'github.com':
|
||||
github_api_base = 'https://api.github.com'
|
||||
else:
|
||||
github_api_base = f'https://{base_domain}/api/v3'
|
||||
provider = await validate_provider_token(SecretStr(token), base_domain)
|
||||
if not provider:
|
||||
raise ValueError('Token is invalid.')
|
||||
|
||||
# Try GitHub Actions token format (Bearer) with repo endpoint if repo is provided
|
||||
if selected_repo:
|
||||
github_repo_url = f'{github_api_base}/repos/{selected_repo}'
|
||||
github_bearer_headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Accept': 'application/vnd.github+json',
|
||||
}
|
||||
|
||||
try:
|
||||
github_repo_response = httpx.get(
|
||||
github_repo_url, headers=github_bearer_headers, timeout=5
|
||||
)
|
||||
if github_repo_response.status_code == 200:
|
||||
return ProviderType.GITHUB
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f'Error connecting to GitHub API (selected_repo check): {e}')
|
||||
|
||||
# Try GitHub PAT format (token)
|
||||
github_url = f'{github_api_base}/user'
|
||||
github_headers = {'Authorization': f'token {token}'}
|
||||
|
||||
try:
|
||||
github_response = httpx.get(github_url, headers=github_headers, timeout=5)
|
||||
if github_response.status_code == 200:
|
||||
return ProviderType.GITHUB
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f'Error connecting to GitHub API: {e}')
|
||||
|
||||
gitlab_url = 'https://gitlab.com/api/v4/user'
|
||||
gitlab_headers = {'Authorization': f'Bearer {token}'}
|
||||
|
||||
try:
|
||||
gitlab_response = httpx.get(gitlab_url, headers=gitlab_headers, timeout=5)
|
||||
if gitlab_response.status_code == 200:
|
||||
return ProviderType.GITLAB
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f'Error connecting to GitLab API: {e}')
|
||||
|
||||
raise ValueError('Token is invalid.')
|
||||
return provider
|
||||
|
||||
|
||||
def codeact_user_response(
|
||||
|
||||
@ -1269,7 +1269,7 @@ def test_main(
|
||||
# Run main function
|
||||
main()
|
||||
|
||||
mock_identify_token.assert_called_with('mock_token', None, ANY)
|
||||
mock_identify_token.assert_called_with('mock_token', mock_args.base_domain)
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=mock_args.llm_model,
|
||||
|
||||
@ -1171,7 +1171,7 @@ def test_main(
|
||||
# Run main function
|
||||
main()
|
||||
|
||||
mock_identify_token.assert_called_with('mock_token', None, ANY)
|
||||
mock_identify_token.assert_called_with('mock_token', mock_args.base_domain)
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=mock_args.llm_model,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user