From 2cb5b91300effb8377c2c83f5b9f322f6d103c5d Mon Sep 17 00:00:00 2001 From: chuckbutkus Date: Mon, 10 Mar 2025 19:23:57 -0400 Subject: [PATCH] Change names to prepare for moving to keycloak User ID (#7178) --- .../integrations/github/github_service.py | 36 +++++++--------- openhands/runtime/base.py | 2 +- openhands/server/auth.py | 14 ++++--- openhands/server/middleware.py | 4 +- openhands/server/routes/github.py | 42 ++++++++++++------- .../server/routes/manage_conversations.py | 18 ++++---- openhands/server/routes/settings.py | 10 +++-- tests/unit/test_github_service.py | 12 +++--- 8 files changed, 74 insertions(+), 64 deletions(-) diff --git a/openhands/integrations/github/github_service.py b/openhands/integrations/github/github_service.py index 6cc9f9cca9..bb85cfdee5 100644 --- a/openhands/integrations/github/github_service.py +++ b/openhands/integrations/github/github_service.py @@ -19,43 +19,37 @@ from openhands.utils.import_utils import get_impl class GitHubService: BASE_URL = 'https://api.github.com' - token: SecretStr = SecretStr('') + github_token: SecretStr = SecretStr('') refresh = False def __init__( self, user_id: str | None = None, - idp_token: SecretStr | None = None, - token: SecretStr | None = None, + external_auth_token: SecretStr | None = None, + github_token: SecretStr | None = None, external_token_manager: bool = False, ): self.user_id = user_id self.external_token_manager = external_token_manager - if token: - self.token = token + if github_token: + self.github_token = github_token async def _get_github_headers(self) -> dict: - """ - Retrieve the GH Token from settings store to construct the headers - """ - - if self.user_id and not self.token: - self.token = await self.get_latest_token() + """Retrieve the GH Token from settings store to construct the headers.""" + if self.user_id and not self.github_token: + self.github_token = await self.get_latest_token() return { - 'Authorization': f'Bearer {self.token.get_secret_value()}', + 'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}', 'Accept': 'application/vnd.github.v3+json', } def _has_token_expired(self, status_code: int) -> bool: return status_code == 401 - async def get_latest_token(self) -> SecretStr: - return self.token - - async def get_latest_provider_token(self) -> SecretStr: - return self.token + async def get_latest_token(self) -> SecretStr | None: + return self.github_token async def _fetch_data( self, url: str, params: dict | None = None @@ -187,11 +181,11 @@ class GitHubService: raise GHUnknownException('Unknown error') async def get_suggested_tasks(self) -> list[SuggestedTask]: - """ - Get suggested tasks for the authenticated user across all repositories. + """Get suggested tasks for the authenticated user across all repositories. + Returns: - - PRs authored by the user - - Issues assigned to the user + - PRs authored by the user. + - Issues assigned to the user. """ # Get user info to use in queries user = await self.get_user() diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 291c8f3063..aee48fdba6 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -224,7 +224,7 @@ class Runtime(FileEditRuntimeMixin): gh_client = GithubServiceImpl( user_id=self.github_user_id, external_token_manager=True ) - token = await gh_client.get_latest_provider_token() + token = await gh_client.get_latest_token() if token: export_cmd = CmdRunAction( f"export GITHUB_TOKEN='{token.get_secret_value()}'" diff --git a/openhands/server/auth.py b/openhands/server/auth.py index 470834f8d0..55ded747a0 100644 --- a/openhands/server/auth.py +++ b/openhands/server/auth.py @@ -2,13 +2,17 @@ from fastapi import Request from pydantic import SecretStr +def get_access_token(request: Request) -> SecretStr | None: + return getattr(request.state, 'access_token', None) + + +def get_user_id(request: Request) -> str | None: + return getattr(request.state, 'user_id', None) + + def get_github_token(request: Request) -> SecretStr | None: return getattr(request.state, 'github_token', None) -def get_user_id(request: Request) -> str | None: +def get_github_user_id(request: Request) -> str | None: return getattr(request.state, 'github_user_id', None) - - -def get_idp_token(request: Request) -> SecretStr | None: - return getattr(request.state, 'idp_token', None) diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 734d52004b..3efdd265ce 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -12,7 +12,7 @@ from starlette.requests import Request as StarletteRequest from starlette.types import ASGIApp from openhands.server import shared -from openhands.server.auth import get_user_id +from openhands.server.auth import get_github_user_id from openhands.server.types import SessionMiddlewareInterface @@ -189,7 +189,7 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface): async def __call__(self, request: Request, call_next: Callable): settings_store = await shared.SettingsStoreImpl.get_instance( - shared.config, get_user_id(request) + shared.config, get_github_user_id(request) ) settings = await settings_store.load() diff --git a/openhands/server/routes/github.py b/openhands/server/routes/github.py index 71f10d3749..1435987cbc 100644 --- a/openhands/server/routes/github.py +++ b/openhands/server/routes/github.py @@ -10,7 +10,7 @@ from openhands.integrations.github.github_types import ( GitHubUser, SuggestedTask, ) -from openhands.server.auth import get_github_token, get_idp_token, get_user_id +from openhands.server.auth import get_access_token, get_github_token, get_github_user_id app = APIRouter(prefix='/api/github') @@ -21,12 +21,14 @@ async def get_github_repositories( per_page: int = 10, sort: str = 'pushed', installation_id: int | None = None, - github_user_id: str | None = Depends(get_user_id), + github_user_id: str | None = Depends(get_github_user_id), github_user_token: SecretStr | None = Depends(get_github_token), - idp_token: SecretStr | None = Depends(get_idp_token), + access_token: SecretStr | None = Depends(get_access_token), ): client = GithubServiceImpl( - user_id=github_user_id, idp_token=idp_token, token=github_user_token + user_id=github_user_id, + external_auth_token=access_token, + github_token=github_user_token, ) try: repos: list[GitHubRepository] = await client.get_repositories( @@ -49,12 +51,14 @@ async def get_github_repositories( @app.get('/user', response_model=GitHubUser) async def get_github_user( - github_user_id: str | None = Depends(get_user_id), + github_user_id: str | None = Depends(get_github_user_id), github_user_token: SecretStr | None = Depends(get_github_token), - idp_token: SecretStr | None = Depends(get_idp_token), + access_token: SecretStr | None = Depends(get_access_token), ): client = GithubServiceImpl( - user_id=github_user_id, idp_token=idp_token, token=github_user_token + user_id=github_user_id, + external_auth_token=access_token, + github_token=github_user_token, ) try: user: GitHubUser = await client.get_user() @@ -75,12 +79,14 @@ async def get_github_user( @app.get('/installations', response_model=list[int]) async def get_github_installation_ids( - github_user_id: str | None = Depends(get_user_id), + github_user_id: str | None = Depends(get_github_user_id), github_user_token: SecretStr | None = Depends(get_github_token), - idp_token: SecretStr | None = Depends(get_idp_token), + access_token: SecretStr | None = Depends(get_access_token), ): client = GithubServiceImpl( - user_id=github_user_id, idp_token=idp_token, token=github_user_token + user_id=github_user_id, + external_auth_token=access_token, + github_token=github_user_token, ) try: installations_ids: list[int] = await client.get_installation_ids() @@ -105,12 +111,14 @@ async def search_github_repositories( per_page: int = 5, sort: str = 'stars', order: str = 'desc', - github_user_id: str | None = Depends(get_user_id), + github_user_id: str | None = Depends(get_github_user_id), github_user_token: SecretStr | None = Depends(get_github_token), - idp_token: SecretStr | None = Depends(get_idp_token), + access_token: SecretStr | None = Depends(get_access_token), ): client = GithubServiceImpl( - user_id=github_user_id, idp_token=idp_token, token=github_user_token + user_id=github_user_id, + external_auth_token=access_token, + github_token=github_user_token, ) try: repos: list[GitHubRepository] = await client.search_repositories( @@ -133,9 +141,9 @@ async def search_github_repositories( @app.get('/suggested-tasks', response_model=list[SuggestedTask]) async def get_suggested_tasks( - github_user_id: str | None = Depends(get_user_id), + github_user_id: str | None = Depends(get_github_user_id), github_user_token: SecretStr | None = Depends(get_github_token), - idp_token: SecretStr | None = Depends(get_idp_token), + access_token: SecretStr | None = Depends(get_access_token), ): """Get suggested tasks for the authenticated user across their most recently pushed repositories. @@ -144,7 +152,9 @@ async def get_suggested_tasks( - Issues assigned to the user. """ client = GithubServiceImpl( - user_id=github_user_id, idp_token=idp_token, token=github_user_token + user_id=github_user_id, + external_auth_token=access_token, + github_token=github_user_token, ) try: tasks: list[SuggestedTask] = await client.get_suggested_tasks() diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 77a35041db..4d991c13a4 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -9,7 +9,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.events.action.message import MessageAction from openhands.integrations.github.github_service import GithubServiceImpl from openhands.runtime import get_runtime_cls -from openhands.server.auth import get_github_token, get_idp_token, get_user_id +from openhands.server.auth import get_access_token, get_github_token, get_github_user_id from openhands.server.data_models.conversation_info import ConversationInfo from openhands.server.data_models.conversation_info_result_set import ( ConversationInfoResultSet, @@ -136,11 +136,11 @@ async def new_conversation(request: Request, data: InitSessionRequest): using the returned conversation ID. """ logger.info('Initializing new conversation') - user_id = get_user_id(request) + user_id = get_github_user_id(request) gh_client = GithubServiceImpl( user_id=user_id, - idp_token=get_idp_token(request), - token=get_github_token(request), + external_auth_token=get_access_token(request), + github_token=get_github_token(request), ) github_token = await gh_client.get_latest_token() @@ -191,7 +191,7 @@ async def search_conversations( limit: int = 20, ) -> ConversationInfoResultSet: conversation_store = await ConversationStoreImpl.get_instance( - config, get_user_id(request) + config, get_github_user_id(request) ) conversation_metadata_result_set = await conversation_store.search(page_id, limit) @@ -210,7 +210,7 @@ async def search_conversations( conversation.conversation_id for conversation in filtered_results ) running_conversations = await conversation_manager.get_running_agent_loops( - get_user_id(request), set(conversation_ids) + get_github_user_id(request), set(conversation_ids) ) result = ConversationInfoResultSet( results=await wait_all( @@ -230,7 +230,7 @@ async def get_conversation( conversation_id: str, request: Request ) -> ConversationInfo | None: conversation_store = await ConversationStoreImpl.get_instance( - config, get_user_id(request) + config, get_github_user_id(request) ) try: metadata = await conversation_store.get_metadata(conversation_id) @@ -246,7 +246,7 @@ async def update_conversation( request: Request, conversation_id: str, title: str = Body(embed=True) ) -> bool: conversation_store = await ConversationStoreImpl.get_instance( - config, get_user_id(request) + config, get_github_user_id(request) ) metadata = await conversation_store.get_metadata(conversation_id) if not metadata: @@ -262,7 +262,7 @@ async def delete_conversation( request: Request, ) -> bool: conversation_store = await ConversationStoreImpl.get_instance( - config, get_user_id(request) + config, get_github_user_id(request) ) try: await conversation_store.get_metadata(conversation_id) diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 58b25a5d7c..37f067f55a 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -4,7 +4,7 @@ from pydantic import SecretStr from openhands.core.logger import openhands_logger as logger from openhands.integrations.github.github_service import GithubServiceImpl -from openhands.server.auth import get_github_token, get_user_id +from openhands.server.auth import get_github_token, get_github_user_id from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings from openhands.server.shared import SettingsStoreImpl, config @@ -14,7 +14,7 @@ app = APIRouter(prefix='/api') @app.get('/settings', response_model=GETSettingsModel) async def load_settings(request: Request) -> GETSettingsModel | JSONResponse: try: - user_id = get_user_id(request) + user_id = get_github_user_id(request) settings_store = await SettingsStoreImpl.get_instance(config, user_id) settings = await settings_store.load() if not settings: @@ -51,7 +51,9 @@ async def store_settings( # We check if the token is valid by getting the user # If the token is invalid, this will raise an exception github = GithubServiceImpl( - user_id=None, idp_token=None, token=SecretStr(settings.github_token) + user_id=None, + external_auth_token=None, + github_token=SecretStr(settings.github_token), ) await github.get_user() @@ -66,7 +68,7 @@ async def store_settings( try: settings_store = await SettingsStoreImpl.get_instance( - config, get_user_id(request) + config, get_github_user_id(request) ) existing_settings = await settings_store.load() diff --git a/tests/unit/test_github_service.py b/tests/unit/test_github_service.py index 222be16767..627c60121f 100644 --- a/tests/unit/test_github_service.py +++ b/tests/unit/test_github_service.py @@ -12,9 +12,9 @@ from openhands.integrations.github.github_types import GhAuthenticationError async def test_github_service_token_handling(): # Test initialization with SecretStr token token = SecretStr('test-token') - service = GitHubService(user_id=None, token=token) - assert service.token == token - assert service.token.get_secret_value() == 'test-token' + service = GitHubService(user_id=None, github_token=token) + assert service.github_token == token + assert service.github_token.get_secret_value() == 'test-token' # Test headers contain the token correctly headers = await service._get_github_headers() @@ -23,14 +23,14 @@ async def test_github_service_token_handling(): # Test initialization without token service = GitHubService(user_id='test-user') - assert service.token == SecretStr('') + assert service.github_token == SecretStr('') @pytest.mark.asyncio async def test_github_service_token_refresh(): # Test that token refresh is only attempted when refresh=True token = SecretStr('test-token') - service = GitHubService(user_id=None, token=token) + service = GitHubService(user_id=None, github_token=token) assert not service.refresh # Test token expiry detection @@ -58,7 +58,7 @@ async def test_github_service_fetch_data(): mock_client.__aexit__.return_value = None with patch('httpx.AsyncClient', return_value=mock_client): - service = GitHubService(user_id=None, token=SecretStr('test-token')) + service = GitHubService(user_id=None, github_token=SecretStr('test-token')) _ = await service._fetch_data('https://api.github.com/user') # Verify the request was made with correct headers