fix(backend): organization members now see correct shared credit balance (#12942)

This commit is contained in:
Hiep Le
2026-02-20 01:34:53 +07:00
committed by GitHub
parent f3429e33ca
commit 8927ac2230
6 changed files with 264 additions and 53 deletions

View File

@@ -93,9 +93,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse
user_team_info = await LiteLlmManager.get_user_team_info(
user_id, str(user.current_org_id)
)
# Update to use calculate_credits
spend = user_team_info.get('spend', 0)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(user.current_org_id)
)
credits = max(max_budget - spend, 0)
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
@@ -249,8 +249,8 @@ async def success_callback(session_id: str, request: Request):
)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
user_team_info, billing_session.user_id, str(user.current_org_id)
)
org = session.query(Org).filter(Org.id == user.current_org_id).first()

View File

@@ -43,6 +43,34 @@ def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
class LiteLlmManager:
"""Manage LiteLLM interactions."""
@staticmethod
def get_budget_from_team_info(
user_team_info: dict | None, user_id: str, org_id: str
) -> tuple[float, float]:
"""Extract max_budget and spend from user team info.
For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget.
For team orgs, uses max_budget_in_team (populated by get_user_team_info).
Args:
user_team_info: The response from get_user_team_info
user_id: The user's ID
org_id: The organization's ID
Returns:
Tuple of (max_budget, spend)
"""
if not user_team_info:
return 0, 0
spend = user_team_info.get('spend', 0)
if user_id == org_id:
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
else:
max_budget = user_team_info.get('max_budget_in_team') or 0
return max_budget, spend
@staticmethod
async def create_entries(
org_id: str,
@@ -71,8 +99,34 @@ class LiteLlmManager:
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
# New users start with $0 budget - they must purchase credits
await LiteLlmManager._create_team(client, keycloak_user_id, org_id, 0)
# Check if team already exists and get its budget
# New users joining existing orgs should inherit the team's budget
team_budget = 0.0
try:
existing_team = await LiteLlmManager._get_team(client, org_id)
if existing_team:
team_info = existing_team.get('team_info', {})
team_budget = team_info.get('max_budget', 0.0) or 0.0
logger.info(
'LiteLlmManager:create_entries:existing_team_budget',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'team_budget': team_budget,
},
)
except httpx.HTTPStatusError as e:
# Team doesn't exist yet (404) - this is expected for first user
if e.response.status_code != 404:
raise
logger.info(
'LiteLlmManager:create_entries:no_existing_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, team_budget
)
if create_user:
await LiteLlmManager._create_user(
@@ -80,7 +134,7 @@ class LiteLlmManager:
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, 0
client, keycloak_user_id, org_id, team_budget
)
key = await LiteLlmManager._generate_key(
@@ -892,21 +946,31 @@ class LiteLlmManager:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
team_response = await LiteLlmManager._get_team(client, team_id)
if not team_response:
return None
# Filter team_memberships based on team_id and keycloak_user_id
user_membership = next(
(
membership
for membership in team_info.get('team_memberships', [])
for membership in team_response.get('team_memberships', [])
if membership.get('user_id') == keycloak_user_id
and membership.get('team_id') == team_id
),
None,
)
if not user_membership:
return None
# For team orgs (user_id != team_id), include team-level budget info
# The team's max_budget and spend are shared across all members
if keycloak_user_id != team_id:
team_info = team_response.get('team_info', {})
user_membership['max_budget_in_team'] = team_info.get('max_budget')
user_membership['spend'] = team_info.get('spend', 0)
return user_membership
@staticmethod

View File

@@ -656,10 +656,9 @@ class OrgService:
)
return None
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(org_id)
)
spend = user_team_info.get('spend', 0)
credits = max(max_budget - spend, 0)
logger.debug(

View File

@@ -101,7 +101,7 @@ async def test_get_credits_success():
json={
'user_info': {
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
}
},
request=MagicMock(),
@@ -121,7 +121,7 @@ async def test_get_credits_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
},
),
):
@@ -313,7 +313,7 @@ async def test_success_callback_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
},
),
patch(
@@ -430,7 +430,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 0,
'litellm_budget_table': {'max_budget': 0},
'max_budget_in_team': 0,
},
),
patch(

View File

@@ -142,44 +142,192 @@ class TestLiteLlmManager:
@pytest.mark.asyncio
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
"""Test create_entries in cloud deployment mode."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_404_response.is_success = False
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.post.return_value = mock_response
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
result = await LiteLlmManager.create_entries(
'test-org-id',
'test-user-id',
mock_settings,
create_user=False,
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_404_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Not Found', request=MagicMock(), response=mock_404_response
)
)
mock_client.post.return_value = mock_response
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert (
result.llm_api_key.get_secret_value() == 'test-api-key'
)
assert result.llm_base_url == 'http://test.com'
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Verify API calls were made
assert (
mock_client.post.call_count == 3
) # create_team, create_user, add_user_to_team, generate_key
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert result.llm_api_key.get_secret_value() == 'test-api-key'
assert result.llm_base_url == 'http://test.com'
# Verify API calls were made (get_team + 3 posts)
assert mock_client.get.call_count == 1 # get_team
assert (
mock_client.post.call_count == 3
) # create_team, add_user_to_team, generate_key
@pytest.mark.asyncio
async def test_create_entries_inherits_existing_team_budget(
    self, mock_settings, mock_response
):
    """create_entries reuses the budget of a team that already exists.

    The mocked GET answers 200 with a team whose max_budget is 30.0; the test
    then checks that the first two POST payloads (team/new and
    team/member_add) both carry 30.0.
    """
    # Successful team lookup reporting an existing budget of 30.0.
    existing_team_reply = MagicMock(status_code=200, is_success=True)
    existing_team_reply.raise_for_status = MagicMock()
    existing_team_reply.json.return_value = {
        'team_info': {'max_budget': 30.0, 'spend': 5.0},
        'team_memberships': [],
    }

    http_client = AsyncMock()
    http_client.get.return_value = existing_team_reply
    http_client.post.return_value = mock_response

    # Stand-in for httpx.AsyncClient used as an async context manager.
    client_factory = MagicMock()
    client_factory.return_value.__aenter__.return_value = http_client

    token_manager_cls = MagicMock()
    token_manager_cls.return_value.get_user_info_from_user_id = AsyncMock(
        return_value={'email': 'test@example.com'}
    )

    with (
        patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
        patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
        patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
        patch('storage.lite_llm_manager.TokenManager', token_manager_cls),
        patch('httpx.AsyncClient', client_factory),
    ):
        result = await LiteLlmManager.create_entries(
            'test-org-id', 'test-user-id', mock_settings, create_user=False
        )

    assert result is not None

    # Exactly one GET, aimed at the team-info endpoint for this org.
    http_client.get.assert_called_once()
    requested_url = http_client.get.call_args.args[0]
    assert 'team/info' in requested_url
    assert 'test-org-id' in requested_url

    posts = http_client.post.call_args_list

    # First POST: team creation with the inherited budget (30.0).
    team_new_call = posts[0]
    assert 'team/new' in team_new_call.args[0]
    assert team_new_call.kwargs['json']['max_budget'] == 30.0

    # Second POST: membership creation with the inherited budget (30.0).
    member_add_call = posts[1]
    assert 'team/member_add' in member_add_call.args[0]
    assert member_add_call.kwargs['json']['max_budget_in_team'] == 30.0
@pytest.mark.asyncio
async def test_create_entries_new_org_uses_zero_budget(
    self, mock_settings, mock_response
):
    """create_entries falls back to a zero budget when the team lookup 404s.

    The mocked GET response raises an HTTPStatusError with status 404 from
    raise_for_status; the test then checks that the first two POST payloads
    (team/new and team/member_add) both carry 0.0.
    """
    # Team lookup reply whose raise_for_status reports "Not Found".
    not_found_reply = MagicMock(status_code=404, is_success=False)
    not_found_reply.raise_for_status.side_effect = httpx.HTTPStatusError(
        message='Not Found', request=MagicMock(), response=not_found_reply
    )

    http_client = AsyncMock()
    http_client.get.return_value = not_found_reply
    http_client.post.return_value = mock_response

    # Stand-in for httpx.AsyncClient used as an async context manager.
    client_factory = MagicMock()
    client_factory.return_value.__aenter__.return_value = http_client

    token_manager_cls = MagicMock()
    token_manager_cls.return_value.get_user_info_from_user_id = AsyncMock(
        return_value={'email': 'test@example.com'}
    )

    with (
        patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
        patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
        patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
        patch('storage.lite_llm_manager.TokenManager', token_manager_cls),
        patch('httpx.AsyncClient', client_factory),
    ):
        result = await LiteLlmManager.create_entries(
            'test-org-id', 'test-user-id', mock_settings, create_user=False
        )

    assert result is not None

    posts = http_client.post.call_args_list

    # First POST: team creation with budget=0.
    team_new_call = posts[0]
    assert 'team/new' in team_new_call.args[0]
    assert team_new_call.kwargs['json']['max_budget'] == 0.0

    # Second POST: membership creation with budget=0.
    member_add_call = posts[1]
    assert 'team/member_add' in member_add_call.args[0]
    assert member_add_call.kwargs['json']['max_budget_in_team'] == 0.0
@pytest.mark.asyncio
async def test_create_entries_propagates_non_404_errors(self, mock_settings):
    """create_entries lets a non-404 failure of the team lookup escape.

    The mocked GET response raises an HTTPStatusError with status 500 from
    raise_for_status; the test expects that same error type to surface from
    create_entries with the 500 response attached.
    """
    # Team lookup reply whose raise_for_status reports a server error.
    server_error_reply = MagicMock(status_code=500, is_success=False)
    server_error_reply.raise_for_status.side_effect = httpx.HTTPStatusError(
        message='Internal Server Error',
        request=MagicMock(),
        response=server_error_reply,
    )

    http_client = AsyncMock()
    http_client.get.return_value = server_error_reply

    # Stand-in for httpx.AsyncClient used as an async context manager.
    client_factory = MagicMock()
    client_factory.return_value.__aenter__.return_value = http_client

    token_manager_cls = MagicMock()
    token_manager_cls.return_value.get_user_info_from_user_id = AsyncMock(
        return_value={'email': 'test@example.com'}
    )

    with (
        patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
        patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
        patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
        patch('storage.lite_llm_manager.TokenManager', token_manager_cls),
        patch('httpx.AsyncClient', client_factory),
        # Innermost context: it exits first and captures the raised error.
        pytest.raises(httpx.HTTPStatusError) as exc_info,
    ):
        await LiteLlmManager.create_entries(
            'test-org-id', 'test-user-id', mock_settings, create_user=False
        )

    # Checked after the raises-context closes so the assertion really runs.
    assert exc_info.value.response.status_code == 500
@pytest.mark.asyncio
async def test_migrate_entries_missing_config(self, mock_user_settings):

View File

@@ -482,7 +482,7 @@ async def test_get_org_credits_success(mock_litellm_api):
spend = 25.0
mock_team_info = {
'litellm_budget_table': {'max_budget': max_budget},
'max_budget_in_team': max_budget,
'spend': spend,
}