mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Change names to prepare for moving to keycloak User ID (#7178)
This commit is contained in:
@@ -19,43 +19,37 @@ from openhands.utils.import_utils import get_impl
|
||||
|
||||
class GitHubService:
|
||||
BASE_URL = 'https://api.github.com'
|
||||
token: SecretStr = SecretStr('')
|
||||
github_token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
idp_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.external_token_manager = external_token_manager
|
||||
|
||||
if token:
|
||||
self.token = token
|
||||
if github_token:
|
||||
self.github_token = github_token
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
"""
|
||||
Retrieve the GH Token from settings store to construct the headers
|
||||
"""
|
||||
|
||||
if self.user_id and not self.token:
|
||||
self.token = await self.get_latest_token()
|
||||
"""Retrieve the GH Token from settings store to construct the headers."""
|
||||
if self.user_id and not self.github_token:
|
||||
self.github_token = await self.get_latest_token()
|
||||
|
||||
return {
|
||||
'Authorization': f'Bearer {self.token.get_secret_value()}',
|
||||
'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}',
|
||||
'Accept': 'application/vnd.github.v3+json',
|
||||
}
|
||||
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
return self.token
|
||||
|
||||
async def get_latest_provider_token(self) -> SecretStr:
|
||||
return self.token
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
return self.github_token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
@@ -187,11 +181,11 @@ class GitHubService:
|
||||
raise GHUnknownException('Unknown error')
|
||||
|
||||
async def get_suggested_tasks(self) -> list[SuggestedTask]:
|
||||
"""
|
||||
Get suggested tasks for the authenticated user across all repositories.
|
||||
"""Get suggested tasks for the authenticated user across all repositories.
|
||||
|
||||
Returns:
|
||||
- PRs authored by the user
|
||||
- Issues assigned to the user
|
||||
- PRs authored by the user.
|
||||
- Issues assigned to the user.
|
||||
"""
|
||||
# Get user info to use in queries
|
||||
user = await self.get_user()
|
||||
|
||||
@@ -224,7 +224,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=self.github_user_id, external_token_manager=True
|
||||
)
|
||||
token = await gh_client.get_latest_provider_token()
|
||||
token = await gh_client.get_latest_token()
|
||||
if token:
|
||||
export_cmd = CmdRunAction(
|
||||
f"export GITHUB_TOKEN='{token.get_secret_value()}'"
|
||||
|
||||
@@ -2,13 +2,17 @@ from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
def get_access_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'access_token', None)
|
||||
|
||||
|
||||
def get_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'user_id', None)
|
||||
|
||||
|
||||
def get_github_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'github_token', None)
|
||||
|
||||
|
||||
def get_user_id(request: Request) -> str | None:
|
||||
def get_github_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'github_user_id', None)
|
||||
|
||||
|
||||
def get_idp_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'idp_token', None)
|
||||
|
||||
@@ -12,7 +12,7 @@ from starlette.requests import Request as StarletteRequest
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from openhands.server import shared
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.auth import get_github_user_id
|
||||
from openhands.server.types import SessionMiddlewareInterface
|
||||
|
||||
|
||||
@@ -189,7 +189,7 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
settings_store = await shared.SettingsStoreImpl.get_instance(
|
||||
shared.config, get_user_id(request)
|
||||
shared.config, get_github_user_id(request)
|
||||
)
|
||||
settings = await settings_store.load()
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from openhands.integrations.github.github_types import (
|
||||
GitHubUser,
|
||||
SuggestedTask,
|
||||
)
|
||||
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
|
||||
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
|
||||
|
||||
app = APIRouter(prefix='/api/github')
|
||||
|
||||
@@ -21,12 +21,14 @@ async def get_github_repositories(
|
||||
per_page: int = 10,
|
||||
sort: str = 'pushed',
|
||||
installation_id: int | None = None,
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.get_repositories(
|
||||
@@ -49,12 +51,14 @@ async def get_github_repositories(
|
||||
|
||||
@app.get('/user', response_model=GitHubUser)
|
||||
async def get_github_user(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
)
|
||||
try:
|
||||
user: GitHubUser = await client.get_user()
|
||||
@@ -75,12 +79,14 @@ async def get_github_user(
|
||||
|
||||
@app.get('/installations', response_model=list[int])
|
||||
async def get_github_installation_ids(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
)
|
||||
try:
|
||||
installations_ids: list[int] = await client.get_installation_ids()
|
||||
@@ -105,12 +111,14 @@ async def search_github_repositories(
|
||||
per_page: int = 5,
|
||||
sort: str = 'stars',
|
||||
order: str = 'desc',
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.search_repositories(
|
||||
@@ -133,9 +141,9 @@ async def search_github_repositories(
|
||||
|
||||
@app.get('/suggested-tasks', response_model=list[SuggestedTask])
|
||||
async def get_suggested_tasks(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
|
||||
|
||||
@@ -144,7 +152,9 @@ async def get_suggested_tasks(
|
||||
- Issues assigned to the user.
|
||||
"""
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
)
|
||||
try:
|
||||
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
|
||||
|
||||
@@ -9,7 +9,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
|
||||
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -136,11 +136,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
user_id = get_user_id(request)
|
||||
user_id = get_github_user_id(request)
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=user_id,
|
||||
idp_token=get_idp_token(request),
|
||||
token=get_github_token(request),
|
||||
external_auth_token=get_access_token(request),
|
||||
github_token=get_github_token(request),
|
||||
)
|
||||
github_token = await gh_client.get_latest_token()
|
||||
|
||||
@@ -191,7 +191,7 @@ async def search_conversations(
|
||||
limit: int = 20,
|
||||
) -> ConversationInfoResultSet:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
|
||||
@@ -210,7 +210,7 @@ async def search_conversations(
|
||||
conversation.conversation_id for conversation in filtered_results
|
||||
)
|
||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||
get_user_id(request), set(conversation_ids)
|
||||
get_github_user_id(request), set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
@@ -230,7 +230,7 @@ async def get_conversation(
|
||||
conversation_id: str, request: Request
|
||||
) -> ConversationInfo | None:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
@@ -246,7 +246,7 @@ async def update_conversation(
|
||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
@@ -262,7 +262,7 @@ async def delete_conversation(
|
||||
request: Request,
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.server.auth import get_github_token, get_github_user_id
|
||||
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
|
||||
from openhands.server.shared import SettingsStoreImpl, config
|
||||
|
||||
@@ -14,7 +14,7 @@ app = APIRouter(prefix='/api')
|
||||
@app.get('/settings', response_model=GETSettingsModel)
|
||||
async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
try:
|
||||
user_id = get_user_id(request)
|
||||
user_id = get_github_user_id(request)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
@@ -51,7 +51,9 @@ async def store_settings(
|
||||
# We check if the token is valid by getting the user
|
||||
# If the token is invalid, this will raise an exception
|
||||
github = GithubServiceImpl(
|
||||
user_id=None, idp_token=None, token=SecretStr(settings.github_token)
|
||||
user_id=None,
|
||||
external_auth_token=None,
|
||||
github_token=SecretStr(settings.github_token),
|
||||
)
|
||||
await github.get_user()
|
||||
|
||||
@@ -66,7 +68,7 @@ async def store_settings(
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ from openhands.integrations.github.github_types import GhAuthenticationError
|
||||
async def test_github_service_token_handling():
|
||||
# Test initialization with SecretStr token
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
assert service.token == token
|
||||
assert service.token.get_secret_value() == 'test-token'
|
||||
service = GitHubService(user_id=None, github_token=token)
|
||||
assert service.github_token == token
|
||||
assert service.github_token.get_secret_value() == 'test-token'
|
||||
|
||||
# Test headers contain the token correctly
|
||||
headers = await service._get_github_headers()
|
||||
@@ -23,14 +23,14 @@ async def test_github_service_token_handling():
|
||||
|
||||
# Test initialization without token
|
||||
service = GitHubService(user_id='test-user')
|
||||
assert service.token == SecretStr('')
|
||||
assert service.github_token == SecretStr('')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_token_refresh():
|
||||
# Test that token refresh is only attempted when refresh=True
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
service = GitHubService(user_id=None, github_token=token)
|
||||
assert not service.refresh
|
||||
|
||||
# Test token expiry detection
|
||||
@@ -58,7 +58,7 @@ async def test_github_service_fetch_data():
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
service = GitHubService(user_id=None, token=SecretStr('test-token'))
|
||||
service = GitHubService(user_id=None, github_token=SecretStr('test-token'))
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
|
||||
# Verify the request was made with correct headers
|
||||
|
||||
Reference in New Issue
Block a user