diff --git a/openhands/server/app.py b/openhands/server/app.py index bae0e70fde..7b81a29ab9 100644 --- a/openhands/server/app.py +++ b/openhands/server/app.py @@ -10,10 +10,13 @@ with warnings.catch_warnings(): from fastapi import ( FastAPI, + Request, ) +from fastapi.responses import JSONResponse import openhands.agenthub # noqa F401 (we import this to get the agents registered) from openhands import __version__ +from openhands.integrations.service_types import AuthenticationError from openhands.server.routes.conversation import app as conversation_api_router from openhands.server.routes.feedback import app as feedback_api_router from openhands.server.routes.files import app as files_api_router @@ -61,6 +64,14 @@ app = FastAPI( ) +@app.exception_handler(AuthenticationError) +async def authentication_error_handler(request: Request, exc: AuthenticationError): + return JSONResponse( + status_code=401, + content=str(exc), + ) + + app.include_router(public_api_router) app.include_router(files_api_router) app.include_router(security_api_router) diff --git a/openhands/server/routes/git.py b/openhands/server/routes/git.py index 3bf9c00e54..024190c88c 100644 --- a/openhands/server/routes/git.py +++ b/openhands/server/routes/git.py @@ -58,10 +58,7 @@ async def get_user_installations( status_code=status.HTTP_400_BAD_REQUEST, ) - return JSONResponse( - content='Git provider token required. (such as GitHub).', - status_code=status.HTTP_401_UNAUTHORIZED, - ) + raise AuthenticationError('Git provider token required. (such as GitHub).') @app.get('/repositories', response_model=list[Repository]) @@ -92,15 +89,6 @@ async def get_user_repositories( installation_id, ) - except AuthenticationError as e: - logger.info( - f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}' - ) - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - except UnknownException as e: return JSONResponse( content=str(e), @@ -110,10 +98,7 @@ async def get_user_repositories( logger.info( f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}' ) - return JSONResponse( - content='Git provider token required. (such as GitHub).', - status_code=status.HTTP_401_UNAUTHORIZED, - ) + raise AuthenticationError('Git provider token required. (such as GitHub).') @app.get('/info', response_model=User) @@ -131,15 +116,6 @@ async def get_user( user: User = await client.get_user() return user - except AuthenticationError as e: - logger.info( - f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}' - ) - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - except UnknownException as e: return JSONResponse( content=str(e), @@ -149,10 +125,7 @@ async def get_user( logger.info( f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}' ) - return JSONResponse( - content='Git provider token required. (such as GitHub).', - status_code=status.HTTP_401_UNAUTHORIZED, - ) + raise AuthenticationError('Git provider token required. (such as GitHub).') @app.get('/search/repositories', response_model=list[Repository]) @@ -178,12 +151,6 @@ async def search_repositories( ) return repos - except AuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - except UnknownException as e: return JSONResponse( content=str(e), @@ -193,10 +160,7 @@ async def search_repositories( logger.info( f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}' ) - return JSONResponse( - content='Git provider token required.', - status_code=status.HTTP_401_UNAUTHORIZED, - ) + raise AuthenticationError('Git provider token required.') @app.get('/suggested-tasks', response_model=list[SuggestedTask]) @@ -219,23 +183,13 @@ async def get_suggested_tasks( tasks: list[SuggestedTask] = await client.get_suggested_tasks() return tasks - except AuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - except UnknownException as e: return JSONResponse( content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) logger.info(f'Returning 401 Unauthorized - No providers set for user_id: {user_id}') - - return JSONResponse( - content='No providers set.', - status_code=status.HTTP_401_UNAUTHORIZED, - ) + raise AuthenticationError('No providers set.') @app.get('/repository/branches', response_model=list[Branch]) @@ -261,12 +215,6 @@ async def get_repository_branches( branches: list[Branch] = await client.get_branches(repository) return branches - except AuthenticationError as e: - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) - except UnknownException as e: return JSONResponse( content=str(e), @@ -276,11 +224,7 @@ async def get_repository_branches( logger.info( f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}' ) - - return JSONResponse( - content='Git provider token required. (such as GitHub).', - status_code=status.HTTP_401_UNAUTHORIZED, - ) + raise AuthenticationError('Git provider token required. (such as GitHub).') def _extract_repo_name(repository_name: str) -> str: @@ -339,14 +283,8 @@ async def get_repository_microagents( logger.info(f'Found {len(microagents)} microagents in {repository_name}') return microagents - except AuthenticationError as e: - logger.info( - f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}' - ) - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) + except AuthenticationError: + raise except RuntimeError as e: return JSONResponse( @@ -412,14 +350,8 @@ async def get_repository_microagent_content( return response - except AuthenticationError as e: - logger.info( - f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}' - ) - return JSONResponse( - content=str(e), - status_code=status.HTTP_401_UNAUTHORIZED, - ) + except AuthenticationError: + raise except RuntimeError as e: return JSONResponse( diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index a4c5fad291..77d8f42449 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -28,7 +28,6 @@ from openhands.integrations.provider import ( ProviderHandler, ) from openhands.integrations.service_types import ( - AuthenticationError, CreateMicroagent, ProviderType, SuggestedTask, @@ -210,16 +209,6 @@ async def new_conversation( status_code=status.HTTP_400_BAD_REQUEST, ) - except AuthenticationError as e: - return JSONResponse( - content={ - 'status': 'error', - 'message': str(e), - 'msg_id': RuntimeStatus.GIT_PROVIDER_AUTHENTICATION_ERROR.value, - }, - status_code=status.HTTP_400_BAD_REQUEST, - ) - @app.get('/conversations') async def search_conversations( diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index 339ac717b1..424085582e 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -991,16 +991,8 @@ async def test_new_conversation_with_provider_authentication_error( ) # Call new_conversation - response = await create_new_test_conversation(test_request) - - # Verify the response - assert isinstance(response, JSONResponse) - assert response.status_code == 400 - assert json.loads(response.body.decode('utf-8')) == { - 'status': 'error', - 'message': 'auth error', - 'msg_id': RuntimeStatus.GIT_PROVIDER_AUTHENTICATION_ERROR.value, - } + with pytest.raises(AuthenticationError): + await create_new_test_conversation(test_request) # Verify that verify_repo_provider was called with the repository provider_handler_mock.verify_repo_provider.assert_called_once_with( diff --git a/tests/unit/server/routes/test_get_repository_microagents.py b/tests/unit/server/routes/test_get_repository_microagents.py index d4a87f3eed..b68931d983 100644 --- a/tests/unit/server/routes/test_get_repository_microagents.py +++ b/tests/unit/server/routes/test_get_repository_microagents.py @@ -4,7 +4,9 @@ from urllib.parse import quote import pytest from fastapi import FastAPI +from fastapi.responses import JSONResponse from fastapi.testclient import TestClient +from httpcore import Request from pydantic import SecretStr from openhands.integrations.provider import ProviderToken, ProviderType @@ -28,6 +30,13 @@ def test_client(): app = FastAPI() app.include_router(git_app) + @app.exception_handler(AuthenticationError) + async def authentication_error_handler(request: Request, exc: AuthenticationError): + return JSONResponse( + status_code=401, + content=str(exc), + ) + # Override the FastAPI dependencies directly def mock_get_provider_tokens(): return MappingProxyType(