mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Refactor authentication error handling with global FastAPI exception handler (#10403)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
d9bc5824a0
commit
bb6cf5a816
@ -10,10 +10,13 @@ with warnings.catch_warnings():
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Request,
|
||||
)
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands import __version__
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.server.routes.conversation import app as conversation_api_router
|
||||
from openhands.server.routes.feedback import app as feedback_api_router
|
||||
from openhands.server.routes.files import app as files_api_router
|
||||
@ -61,6 +64,14 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(AuthenticationError)
|
||||
async def authentication_error_handler(request: Request, exc: AuthenticationError):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=str(exc),
|
||||
)
|
||||
|
||||
|
||||
app.include_router(public_api_router)
|
||||
app.include_router(files_api_router)
|
||||
app.include_router(security_api_router)
|
||||
|
||||
@ -58,10 +58,7 @@ async def get_user_installations(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='Git provider token required. (such as GitHub).',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
raise AuthenticationError('Git provider token required. (such as GitHub).')
|
||||
|
||||
|
||||
@app.get('/repositories', response_model=list[Repository])
|
||||
@ -92,15 +89,6 @@ async def get_user_repositories(
|
||||
installation_id,
|
||||
)
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}'
|
||||
)
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
@ -110,10 +98,7 @@ async def get_user_repositories(
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}'
|
||||
)
|
||||
return JSONResponse(
|
||||
content='Git provider token required. (such as GitHub).',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
raise AuthenticationError('Git provider token required. (such as GitHub).')
|
||||
|
||||
|
||||
@app.get('/info', response_model=User)
|
||||
@ -131,15 +116,6 @@ async def get_user(
|
||||
user: User = await client.get_user()
|
||||
return user
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}'
|
||||
)
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
@ -149,10 +125,7 @@ async def get_user(
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}'
|
||||
)
|
||||
return JSONResponse(
|
||||
content='Git provider token required. (such as GitHub).',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
raise AuthenticationError('Git provider token required. (such as GitHub).')
|
||||
|
||||
|
||||
@app.get('/search/repositories', response_model=list[Repository])
|
||||
@ -178,12 +151,6 @@ async def search_repositories(
|
||||
)
|
||||
return repos
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
@ -193,10 +160,7 @@ async def search_repositories(
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}'
|
||||
)
|
||||
return JSONResponse(
|
||||
content='Git provider token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
raise AuthenticationError('Git provider token required.')
|
||||
|
||||
|
||||
@app.get('/suggested-tasks', response_model=list[SuggestedTask])
|
||||
@ -219,23 +183,13 @@ async def get_suggested_tasks(
|
||||
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
|
||||
return tasks
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
logger.info(f'Returning 401 Unauthorized - No providers set for user_id: {user_id}')
|
||||
|
||||
return JSONResponse(
|
||||
content='No providers set.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
raise AuthenticationError('No providers set.')
|
||||
|
||||
|
||||
@app.get('/repository/branches', response_model=list[Branch])
|
||||
@ -261,12 +215,6 @@ async def get_repository_branches(
|
||||
branches: list[Branch] = await client.get_branches(repository)
|
||||
return branches
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
@ -276,11 +224,7 @@ async def get_repository_branches(
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Git provider token required for user_id: {user_id}'
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='Git provider token required. (such as GitHub).',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
raise AuthenticationError('Git provider token required. (such as GitHub).')
|
||||
|
||||
|
||||
def _extract_repo_name(repository_name: str) -> str:
|
||||
@ -339,14 +283,8 @@ async def get_repository_microagents(
|
||||
logger.info(f'Found {len(microagents)} microagents in {repository_name}')
|
||||
return microagents
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}'
|
||||
)
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
except AuthenticationError:
|
||||
raise
|
||||
|
||||
except RuntimeError as e:
|
||||
return JSONResponse(
|
||||
@ -412,14 +350,8 @@ async def get_repository_microagent_content(
|
||||
|
||||
return response
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.info(
|
||||
f'Returning 401 Unauthorized - Authentication error for user_id: {user_id}, error: {str(e)}'
|
||||
)
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
except AuthenticationError:
|
||||
raise
|
||||
|
||||
except RuntimeError as e:
|
||||
return JSONResponse(
|
||||
|
||||
@ -28,7 +28,6 @@ from openhands.integrations.provider import (
|
||||
ProviderHandler,
|
||||
)
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
CreateMicroagent,
|
||||
ProviderType,
|
||||
SuggestedTask,
|
||||
@ -210,16 +209,6 @@ async def new_conversation(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': str(e),
|
||||
'msg_id': RuntimeStatus.GIT_PROVIDER_AUTHENTICATION_ERROR.value,
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/conversations')
|
||||
async def search_conversations(
|
||||
|
||||
@ -991,16 +991,8 @@ async def test_new_conversation_with_provider_authentication_error(
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 400
|
||||
assert json.loads(response.body.decode('utf-8')) == {
|
||||
'status': 'error',
|
||||
'message': 'auth error',
|
||||
'msg_id': RuntimeStatus.GIT_PROVIDER_AUTHENTICATION_ERROR.value,
|
||||
}
|
||||
with pytest.raises(AuthenticationError):
|
||||
await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify that verify_repo_provider was called with the repository
|
||||
provider_handler_mock.verify_repo_provider.assert_called_once_with(
|
||||
|
||||
@ -4,7 +4,9 @@ from urllib.parse import quote
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
from httpcore import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
@ -28,6 +30,13 @@ def test_client():
|
||||
app = FastAPI()
|
||||
app.include_router(git_app)
|
||||
|
||||
@app.exception_handler(AuthenticationError)
|
||||
async def authentication_error_handler(request: Request, exc: AuthenticationError):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=str(exc),
|
||||
)
|
||||
|
||||
# Override the FastAPI dependencies directly
|
||||
def mock_get_provider_tokens():
|
||||
return MappingProxyType(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user