Change names to prepare for moving to keycloak User ID (#7178)

This commit is contained in:
chuckbutkus
2025-03-10 19:23:57 -04:00
committed by GitHub
parent d6e601ea2e
commit 2cb5b91300
8 changed files with 74 additions and 64 deletions

View File

@@ -19,43 +19,37 @@ from openhands.utils.import_utils import get_impl
class GitHubService:
BASE_URL = 'https://api.github.com'
token: SecretStr = SecretStr('')
github_token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,
user_id: str | None = None,
idp_token: SecretStr | None = None,
token: SecretStr | None = None,
external_auth_token: SecretStr | None = None,
github_token: SecretStr | None = None,
external_token_manager: bool = False,
):
self.user_id = user_id
self.external_token_manager = external_token_manager
if token:
self.token = token
if github_token:
self.github_token = github_token
async def _get_github_headers(self) -> dict:
"""
Retrieve the GH Token from settings store to construct the headers
"""
if self.user_id and not self.token:
self.token = await self.get_latest_token()
"""Retrieve the GH Token from settings store to construct the headers."""
if self.user_id and not self.github_token:
self.github_token = await self.get_latest_token()
return {
'Authorization': f'Bearer {self.token.get_secret_value()}',
'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}',
'Accept': 'application/vnd.github.v3+json',
}
def _has_token_expired(self, status_code: int) -> bool:
return status_code == 401
async def get_latest_token(self) -> SecretStr:
return self.token
async def get_latest_provider_token(self) -> SecretStr:
return self.token
async def get_latest_token(self) -> SecretStr | None:
return self.github_token
async def _fetch_data(
self, url: str, params: dict | None = None
@@ -187,11 +181,11 @@ class GitHubService:
raise GHUnknownException('Unknown error')
async def get_suggested_tasks(self) -> list[SuggestedTask]:
"""
Get suggested tasks for the authenticated user across all repositories.
"""Get suggested tasks for the authenticated user across all repositories.
Returns:
- PRs authored by the user
- Issues assigned to the user
- PRs authored by the user.
- Issues assigned to the user.
"""
# Get user info to use in queries
user = await self.get_user()

View File

@@ -224,7 +224,7 @@ class Runtime(FileEditRuntimeMixin):
gh_client = GithubServiceImpl(
user_id=self.github_user_id, external_token_manager=True
)
token = await gh_client.get_latest_provider_token()
token = await gh_client.get_latest_token()
if token:
export_cmd = CmdRunAction(
f"export GITHUB_TOKEN='{token.get_secret_value()}'"

View File

@@ -2,13 +2,17 @@ from fastapi import Request
from pydantic import SecretStr
def get_access_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'access_token', None)
def get_user_id(request: Request) -> str | None:
return getattr(request.state, 'user_id', None)
def get_github_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'github_token', None)
def get_user_id(request: Request) -> str | None:
def get_github_user_id(request: Request) -> str | None:
return getattr(request.state, 'github_user_id', None)
def get_idp_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'idp_token', None)

View File

@@ -12,7 +12,7 @@ from starlette.requests import Request as StarletteRequest
from starlette.types import ASGIApp
from openhands.server import shared
from openhands.server.auth import get_user_id
from openhands.server.auth import get_github_user_id
from openhands.server.types import SessionMiddlewareInterface
@@ -189,7 +189,7 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
async def __call__(self, request: Request, call_next: Callable):
settings_store = await shared.SettingsStoreImpl.get_instance(
shared.config, get_user_id(request)
shared.config, get_github_user_id(request)
)
settings = await settings_store.load()

View File

@@ -10,7 +10,7 @@ from openhands.integrations.github.github_types import (
GitHubUser,
SuggestedTask,
)
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
app = APIRouter(prefix='/api/github')
@@ -21,12 +21,14 @@ async def get_github_repositories(
per_page: int = 10,
sort: str = 'pushed',
installation_id: int | None = None,
github_user_id: str | None = Depends(get_user_id),
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
)
try:
repos: list[GitHubRepository] = await client.get_repositories(
@@ -49,12 +51,14 @@ async def get_github_repositories(
@app.get('/user', response_model=GitHubUser)
async def get_github_user(
github_user_id: str | None = Depends(get_user_id),
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
)
try:
user: GitHubUser = await client.get_user()
@@ -75,12 +79,14 @@ async def get_github_user(
@app.get('/installations', response_model=list[int])
async def get_github_installation_ids(
github_user_id: str | None = Depends(get_user_id),
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
)
try:
installations_ids: list[int] = await client.get_installation_ids()
@@ -105,12 +111,14 @@ async def search_github_repositories(
per_page: int = 5,
sort: str = 'stars',
order: str = 'desc',
github_user_id: str | None = Depends(get_user_id),
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
)
try:
repos: list[GitHubRepository] = await client.search_repositories(
@@ -133,9 +141,9 @@ async def search_github_repositories(
@app.get('/suggested-tasks', response_model=list[SuggestedTask])
async def get_suggested_tasks(
github_user_id: str | None = Depends(get_user_id),
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
access_token: SecretStr | None = Depends(get_access_token),
):
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
@@ -144,7 +152,9 @@ async def get_suggested_tasks(
- Issues assigned to the user.
"""
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
)
try:
tasks: list[SuggestedTask] = await client.get_suggested_tasks()

View File

@@ -9,7 +9,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@@ -136,11 +136,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
using the returned conversation ID.
"""
logger.info('Initializing new conversation')
user_id = get_user_id(request)
user_id = get_github_user_id(request)
gh_client = GithubServiceImpl(
user_id=user_id,
idp_token=get_idp_token(request),
token=get_github_token(request),
external_auth_token=get_access_token(request),
github_token=get_github_token(request),
)
github_token = await gh_client.get_latest_token()
@@ -191,7 +191,7 @@ async def search_conversations(
limit: int = 20,
) -> ConversationInfoResultSet:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
config, get_github_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
@@ -210,7 +210,7 @@ async def search_conversations(
conversation.conversation_id for conversation in filtered_results
)
running_conversations = await conversation_manager.get_running_agent_loops(
get_user_id(request), set(conversation_ids)
get_github_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(
@@ -230,7 +230,7 @@ async def get_conversation(
conversation_id: str, request: Request
) -> ConversationInfo | None:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
config, get_github_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
@@ -246,7 +246,7 @@ async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
config, get_github_user_id(request)
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
@@ -262,7 +262,7 @@ async def delete_conversation(
request: Request,
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
config, get_github_user_id(request)
)
try:
await conversation_store.get_metadata(conversation_id)

View File

@@ -4,7 +4,7 @@ from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.auth import get_github_token, get_github_user_id
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
from openhands.server.shared import SettingsStoreImpl, config
@@ -14,7 +14,7 @@ app = APIRouter(prefix='/api')
@app.get('/settings', response_model=GETSettingsModel)
async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
try:
user_id = get_user_id(request)
user_id = get_github_user_id(request)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
if not settings:
@@ -51,7 +51,9 @@ async def store_settings(
# We check if the token is valid by getting the user
# If the token is invalid, this will raise an exception
github = GithubServiceImpl(
user_id=None, idp_token=None, token=SecretStr(settings.github_token)
user_id=None,
external_auth_token=None,
github_token=SecretStr(settings.github_token),
)
await github.get_user()
@@ -66,7 +68,7 @@ async def store_settings(
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
config, get_github_user_id(request)
)
existing_settings = await settings_store.load()

View File

@@ -12,9 +12,9 @@ from openhands.integrations.github.github_types import GhAuthenticationError
async def test_github_service_token_handling():
# Test initialization with SecretStr token
token = SecretStr('test-token')
service = GitHubService(user_id=None, token=token)
assert service.token == token
assert service.token.get_secret_value() == 'test-token'
service = GitHubService(user_id=None, github_token=token)
assert service.github_token == token
assert service.github_token.get_secret_value() == 'test-token'
# Test headers contain the token correctly
headers = await service._get_github_headers()
@@ -23,14 +23,14 @@ async def test_github_service_token_handling():
# Test initialization without token
service = GitHubService(user_id='test-user')
assert service.token == SecretStr('')
assert service.github_token == SecretStr('')
@pytest.mark.asyncio
async def test_github_service_token_refresh():
# Test that token refresh is only attempted when refresh=True
token = SecretStr('test-token')
service = GitHubService(user_id=None, token=token)
service = GitHubService(user_id=None, github_token=token)
assert not service.refresh
# Test token expiry detection
@@ -58,7 +58,7 @@ async def test_github_service_fetch_data():
mock_client.__aexit__.return_value = None
with patch('httpx.AsyncClient', return_value=mock_client):
service = GitHubService(user_id=None, token=SecretStr('test-token'))
service = GitHubService(user_id=None, github_token=SecretStr('test-token'))
_ = await service._fetch_data('https://api.github.com/user')
# Verify the request was made with correct headers