diff --git a/enterprise/server/routes/org_models.py b/enterprise/server/routes/org_models.py index c9825f8ddb..ed2b47e016 100644 --- a/enterprise/server/routes/org_models.py +++ b/enterprise/server/routes/org_models.py @@ -214,6 +214,7 @@ class OrgPage(BaseModel): items: list[OrgResponse] next_page_id: str | None = None + current_org_id: str | None = None class OrgUpdate(BaseModel): diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py index 258f633d2d..ddef0d6b00 100644 --- a/enterprise/server/routes/orgs.py +++ b/enterprise/server/routes/orgs.py @@ -32,6 +32,7 @@ from server.routes.org_models import ( ) from server.services.org_member_service import OrgMemberService from storage.org_service import OrgService +from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger from openhands.server.user_auth import get_user_id @@ -78,6 +79,12 @@ async def list_user_orgs( ) try: + # Fetch user to get current_org_id + user = await UserStore.get_user_by_id_async(user_id) + current_org_id = ( + str(user.current_org_id) if user and user.current_org_id else None + ) + # Fetch organizations from service layer orgs, next_page_id = OrgService.get_user_orgs_paginated( user_id=user_id, @@ -99,7 +106,11 @@ async def list_user_orgs( }, ) - return OrgPage(items=org_responses, next_page_id=next_page_id) + return OrgPage( + items=org_responses, + next_page_id=next_page_id, + current_org_id=current_org_id, + ) except Exception as e: logger.exception( diff --git a/enterprise/tests/unit/server/routes/test_orgs.py b/enterprise/tests/unit/server/routes/test_orgs.py index 9faa2f9693..c375d0bdfe 100644 --- a/enterprise/tests/unit/server/routes/test_orgs.py +++ b/enterprise/tests/unit/server/routes/test_orgs.py @@ -513,10 +513,18 @@ async def test_list_user_orgs_success(mock_app_list): org_version=5, default_llm_model='claude-opus-4-5-20251101', ) + mock_user = MagicMock() + mock_user.current_org_id = org_id - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([mock_org], None), + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([mock_org], None), + ), ): client = TestClient(mock_app_list) @@ -536,6 +544,54 @@ async def test_list_user_orgs_success(mock_app_list): assert response_data['items'][0]['credits'] is None +@pytest.mark.asyncio +async def test_list_user_orgs_returns_current_org_id(mock_app_list): + """ + GIVEN: User has a current organization set + WHEN: GET /api/organizations is called + THEN: Response includes current_org_id matching the user's current org + """ + # Arrange + current_org_id = uuid.uuid4() + other_org_id = uuid.uuid4() + + current_org = Org( + id=current_org_id, + name='Current Organization', + contact_name='John Doe', + contact_email='john@example.com', + ) + other_org = Org( + id=other_org_id, + name='Other Organization', + contact_name='Jane Doe', + contact_email='jane@example.com', + ) + mock_user = MagicMock() + mock_user.current_org_id = current_org_id + + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([current_org, other_org], None), + ), + ): + client = TestClient(mock_app_list) + + # Act + response = client.get('/api/organizations') + + # Assert + assert response.status_code == status.HTTP_200_OK + response_data = response.json() + assert 'current_org_id' in response_data + assert response_data['current_org_id'] == str(current_org_id) + + @pytest.mark.asyncio async def test_list_user_orgs_with_pagination(mock_app_list): """ @@ -556,10 +612,18 @@ async def test_list_user_orgs_with_pagination(mock_app_list): contact_name='Jane Doe', contact_email='jane@example.com', ) + mock_user = MagicMock() + mock_user.current_org_id = org1.id - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([org1, org2], '2'), + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([org1, org2], '2'), + ), ): client = TestClient(mock_app_list) @@ -583,9 +647,18 @@ async def test_list_user_orgs_empty(mock_app_list): THEN: Empty list is returned with 200 status """ # Arrange - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([], None), + mock_user = MagicMock() + mock_user.current_org_id = uuid.uuid4() + + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([], None), + ), ): client = TestClient(mock_app_list) @@ -641,9 +714,18 @@ async def test_list_user_orgs_service_error(mock_app_list): THEN: 500 Internal Server Error is returned """ # Arrange - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - side_effect=Exception('Database error'), + mock_user = MagicMock() + mock_user.current_org_id = uuid.uuid4() + + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + side_effect=Exception('Database error'), + ), ): client = TestClient(mock_app_list) @@ -698,10 +780,18 @@ async def test_list_user_orgs_personal_org_identified(mock_app_list): contact_name='John Doe', contact_email='john@example.com', ) + mock_user = MagicMock() + mock_user.current_org_id = personal_org_id - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([personal_org], None), + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([personal_org], None), + ), ): client = TestClient(mock_app_list) @@ -729,10 +819,18 @@ async def test_list_user_orgs_team_org_identified(mock_app_list): contact_name='John Doe', contact_email='john@example.com', ) + mock_user = MagicMock() + mock_user.current_org_id = team_org.id - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([team_org], None), + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([team_org], None), + ), ): client = TestClient(mock_app_list) @@ -770,10 +868,18 @@ async def test_list_user_orgs_mixed_personal_and_team(mock_app_list): contact_name='Jane Doe', contact_email='jane@example.com', ) + mock_user = MagicMock() + mock_user.current_org_id = personal_org_id - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([personal_org, team_org], None), + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([personal_org, team_org], None), + ), ): client = TestClient(mock_app_list) @@ -834,10 +940,18 @@ async def test_list_user_orgs_all_fields_present(mock_app_list): enable_solvability_analysis=True, v1_enabled=True, ) + mock_user = MagicMock() + mock_user.current_org_id = org_id - with patch( - 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([mock_org], None), + with ( + patch( + 'server.routes.orgs.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'server.routes.orgs.OrgService.get_user_orgs_paginated', + return_value=([mock_org], None), + ), ): client = TestClient(mock_app_list)