Guard User Creation with Redis based Lock (#12381)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-01-12 15:03:42 -07:00
committed by GitHub
parent 92baebc4bd
commit 9cf7d64bfe
2 changed files with 120 additions and 3 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import binascii
import hashlib
import json
@@ -34,6 +35,13 @@ from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.http_session import httpx_verify_option
# The max possible time to wait for another process to finish creating a user before retrying
_REDIS_CREATE_TIMEOUT_SECONDS = 30
# The delay to wait for another process to finish creating a user before trying to load again
_RETRY_LOAD_DELAY_SECONDS = 2
# Redis key prefix for user creation locks
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
@dataclass
class SaasSettingsStore(SettingsStore):
@@ -131,6 +139,32 @@ class SaasSettingsStore(SettingsStore):
session.add(settings)
session.commit()
def _get_redis_client(self):
"""Get the Redis client from the Socket.IO manager."""
from openhands.server.shared import sio
return getattr(sio.manager, 'redis', None)
async def _acquire_user_creation_lock(self) -> bool:
"""Attempt to acquire a distributed lock for user creation.
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
Returns False if another process holds the lock.
"""
redis_client = self._get_redis_client()
if redis_client is None:
logger.warning(
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
extra={'user_id': self.user_id},
)
return True # Proceed without locking if Redis is unavailable
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{self.user_id}'
lock_acquired = await redis_client.set(
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
)
return bool(lock_acquired)
async def create_default_settings(self, user_settings: UserSettings | None):
logger.info(
'saas_settings_store:create_default_settings:start',
@@ -140,6 +174,16 @@ class SaasSettingsStore(SettingsStore):
if not self.user_id:
return None
# Prevent duplicate settings creation using distributed lock
if not await self._acquire_user_creation_lock():
# The user is already being created in another thread / process
logger.info(
'saas_settings_store:create_default_settings:waiting_for_lock',
extra={'user_id': self.user_id},
)
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
return await self.load()
# Only users that have specified a payment method get default settings
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
self.user_id

View File

@@ -65,6 +65,15 @@ def mock_stripe():
yield
@pytest.fixture
def mock_redis_client():
"""Mock Redis client for testing create_default_settings locking."""
mock_redis = AsyncMock()
# By default, allow proceeding with create (lock acquired successfully)
mock_redis.set = AsyncMock(return_value=True)
return mock_redis
@pytest.fixture
def mock_github_user():
with patch(
@@ -200,7 +209,12 @@ async def test_store_and_load_keycloak_user(settings_store):
@pytest.mark.asyncio
async def test_load_returns_default_when_not_found(
settings_store, mock_litellm_api, mock_stripe, mock_github_user, session_maker
settings_store,
mock_litellm_api,
mock_stripe,
mock_github_user,
session_maker,
mock_redis_client,
):
file_store = MagicMock()
file_store.read.side_effect = FileNotFoundError()
@@ -211,6 +225,9 @@ async def test_load_returns_default_when_not_found(
MagicMock(return_value=file_store),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch.object(
settings_store, '_get_redis_client', return_value=mock_redis_client
),
):
loaded_settings = await settings_store.load()
assert loaded_settings is not None
@@ -263,7 +280,7 @@ async def test_create_default_settings_no_user_id():
@pytest.mark.asyncio
async def test_create_default_settings_require_payment_enabled(
settings_store, mock_stripe
settings_store, mock_stripe, mock_redis_client
):
# Mock stripe_service.has_payment_method to return False
with (
@@ -275,6 +292,9 @@ async def test_create_default_settings_require_payment_enabled(
patch(
'integrations.stripe_service.session_maker', settings_store.session_maker
),
patch.object(
settings_store, '_get_redis_client', return_value=mock_redis_client
),
):
settings = await settings_store.create_default_settings(None)
assert settings is None
@@ -282,7 +302,12 @@ async def test_create_default_settings_require_payment_enabled(
@pytest.mark.asyncio
async def test_create_default_settings_require_payment_disabled(
settings_store, mock_stripe, mock_github_user, mock_litellm_api, session_maker
settings_store,
mock_stripe,
mock_github_user,
mock_litellm_api,
session_maker,
mock_redis_client,
):
# Even without payment method, should get default settings when REQUIRE_PAYMENT is False
file_store = MagicMock()
@@ -298,12 +323,60 @@ async def test_create_default_settings_require_payment_disabled(
MagicMock(return_value=file_store),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch.object(
settings_store, '_get_redis_client', return_value=mock_redis_client
),
):
settings = await settings_store.create_default_settings(None)
assert settings is not None
assert settings.language == 'en'
@pytest.mark.asyncio
async def test_create_default_settings_waits_when_lock_held(
settings_store, mock_stripe, mock_github_user, mock_litellm_api, session_maker
):
"""Test that create_default_settings waits and retries when another process holds the lock."""
file_store = MagicMock()
file_store.read.side_effect = FileNotFoundError()
# Create a mock Redis client that fails to acquire lock on first attempt, succeeds on second
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(side_effect=[False, True])
# Track if sleep was called
sleep_called = False
async def mock_sleep(delay):
nonlocal sleep_called
sleep_called = True
# Don't actually sleep - just verify it was called with correct delay
from storage.saas_settings_store import _RETRY_LOAD_DELAY_SECONDS
assert delay == _RETRY_LOAD_DELAY_SECONDS
with (
patch('storage.saas_settings_store.REQUIRE_PAYMENT', False),
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[])),
),
patch(
'storage.saas_settings_store.get_file_store',
MagicMock(return_value=file_store),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch.object(settings_store, '_get_redis_client', return_value=mock_redis),
patch('storage.saas_settings_store.asyncio.sleep', mock_sleep),
):
settings = await settings_store.create_default_settings(None)
# Should have called sleep while waiting for lock
assert sleep_called
# Should eventually succeed and return settings
assert settings is not None
assert settings.language == 'en'
@pytest.mark.asyncio
async def test_create_default_lite_llm_settings_no_api_config(settings_store):
with (