mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Guard User Creation with Redis based Lock (#12381)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
@@ -34,6 +35,13 @@ from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
_RETRY_LOAD_DELAY_SECONDS = 2
|
||||
# Redis key prefix for user creation locks
|
||||
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasSettingsStore(SettingsStore):
|
||||
@@ -131,6 +139,32 @@ class SaasSettingsStore(SettingsStore):
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
|
||||
def _get_redis_client(self):
|
||||
"""Get the Redis client from the Socket.IO manager."""
|
||||
from openhands.server.shared import sio
|
||||
|
||||
return getattr(sio.manager, 'redis', None)
|
||||
|
||||
async def _acquire_user_creation_lock(self) -> bool:
|
||||
"""Attempt to acquire a distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
|
||||
Returns False if another process holds the lock.
|
||||
"""
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{self.user_id}'
|
||||
lock_acquired = await redis_client.set(
|
||||
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
async def create_default_settings(self, user_settings: UserSettings | None):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:start',
|
||||
@@ -140,6 +174,16 @@ class SaasSettingsStore(SettingsStore):
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
# Prevent duplicate settings creation using distributed lock
|
||||
if not await self._acquire_user_creation_lock():
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
return await self.load()
|
||||
|
||||
# Only users that have specified a payment method get default settings
|
||||
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
|
||||
self.user_id
|
||||
|
||||
@@ -65,6 +65,15 @@ def mock_stripe():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client():
|
||||
"""Mock Redis client for testing create_default_settings locking."""
|
||||
mock_redis = AsyncMock()
|
||||
# By default, allow proceeding with create (lock acquired successfully)
|
||||
mock_redis.set = AsyncMock(return_value=True)
|
||||
return mock_redis
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_user():
|
||||
with patch(
|
||||
@@ -200,7 +209,12 @@ async def test_store_and_load_keycloak_user(settings_store):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_returns_default_when_not_found(
|
||||
settings_store, mock_litellm_api, mock_stripe, mock_github_user, session_maker
|
||||
settings_store,
|
||||
mock_litellm_api,
|
||||
mock_stripe,
|
||||
mock_github_user,
|
||||
session_maker,
|
||||
mock_redis_client,
|
||||
):
|
||||
file_store = MagicMock()
|
||||
file_store.read.side_effect = FileNotFoundError()
|
||||
@@ -211,6 +225,9 @@ async def test_load_returns_default_when_not_found(
|
||||
MagicMock(return_value=file_store),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch.object(
|
||||
settings_store, '_get_redis_client', return_value=mock_redis_client
|
||||
),
|
||||
):
|
||||
loaded_settings = await settings_store.load()
|
||||
assert loaded_settings is not None
|
||||
@@ -263,7 +280,7 @@ async def test_create_default_settings_no_user_id():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_payment_enabled(
|
||||
settings_store, mock_stripe
|
||||
settings_store, mock_stripe, mock_redis_client
|
||||
):
|
||||
# Mock stripe_service.has_payment_method to return False
|
||||
with (
|
||||
@@ -275,6 +292,9 @@ async def test_create_default_settings_require_payment_enabled(
|
||||
patch(
|
||||
'integrations.stripe_service.session_maker', settings_store.session_maker
|
||||
),
|
||||
patch.object(
|
||||
settings_store, '_get_redis_client', return_value=mock_redis_client
|
||||
),
|
||||
):
|
||||
settings = await settings_store.create_default_settings(None)
|
||||
assert settings is None
|
||||
@@ -282,7 +302,12 @@ async def test_create_default_settings_require_payment_enabled(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_payment_disabled(
|
||||
settings_store, mock_stripe, mock_github_user, mock_litellm_api, session_maker
|
||||
settings_store,
|
||||
mock_stripe,
|
||||
mock_github_user,
|
||||
mock_litellm_api,
|
||||
session_maker,
|
||||
mock_redis_client,
|
||||
):
|
||||
# Even without payment method, should get default settings when REQUIRE_PAYMENT is False
|
||||
file_store = MagicMock()
|
||||
@@ -298,12 +323,60 @@ async def test_create_default_settings_require_payment_disabled(
|
||||
MagicMock(return_value=file_store),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch.object(
|
||||
settings_store, '_get_redis_client', return_value=mock_redis_client
|
||||
),
|
||||
):
|
||||
settings = await settings_store.create_default_settings(None)
|
||||
assert settings is not None
|
||||
assert settings.language == 'en'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_waits_when_lock_held(
|
||||
settings_store, mock_stripe, mock_github_user, mock_litellm_api, session_maker
|
||||
):
|
||||
"""Test that create_default_settings waits and retries when another process holds the lock."""
|
||||
file_store = MagicMock()
|
||||
file_store.read.side_effect = FileNotFoundError()
|
||||
|
||||
# Create a mock Redis client that fails to acquire lock on first attempt, succeeds on second
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(side_effect=[False, True])
|
||||
|
||||
# Track if sleep was called
|
||||
sleep_called = False
|
||||
|
||||
async def mock_sleep(delay):
|
||||
nonlocal sleep_called
|
||||
sleep_called = True
|
||||
# Don't actually sleep - just verify it was called with correct delay
|
||||
from storage.saas_settings_store import _RETRY_LOAD_DELAY_SECONDS
|
||||
|
||||
assert delay == _RETRY_LOAD_DELAY_SECONDS
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.REQUIRE_PAYMENT', False),
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch(
|
||||
'storage.saas_settings_store.get_file_store',
|
||||
MagicMock(return_value=file_store),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch.object(settings_store, '_get_redis_client', return_value=mock_redis),
|
||||
patch('storage.saas_settings_store.asyncio.sleep', mock_sleep),
|
||||
):
|
||||
settings = await settings_store.create_default_settings(None)
|
||||
# Should have called sleep while waiting for lock
|
||||
assert sleep_called
|
||||
# Should eventually succeed and return settings
|
||||
assert settings is not None
|
||||
assert settings.language == 'en'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_lite_llm_settings_no_api_config(settings_store):
|
||||
with (
|
||||
|
||||
Reference in New Issue
Block a user