mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
fix: Set correct user context in webhook callbacks based on sandbox owner (#13340)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -6,7 +6,7 @@ import logging
|
||||
import pkgutil
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import SecretStr
|
||||
@@ -23,61 +23,87 @@ from openhands.app_server.config import (
|
||||
depends_app_conversation_info_service,
|
||||
depends_event_service,
|
||||
depends_jwt_service,
|
||||
depends_sandbox_service,
|
||||
get_event_callback_service,
|
||||
get_global_config,
|
||||
get_sandbox_service,
|
||||
)
|
||||
from openhands.app_server.errors import AuthError
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.sandbox.sandbox_service import SandboxService
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.services.jwt_service import JwtService
|
||||
from openhands.app_server.user.auth_user_context import AuthUserContext
|
||||
from openhands.app_server.user.specifiy_user_context import (
|
||||
ADMIN,
|
||||
USER_CONTEXT_ATTR,
|
||||
SpecifyUserContext,
|
||||
as_admin,
|
||||
)
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk import ConversationExecutionStatus, Event
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.server.user_auth.default_user_auth import DefaultUserAuth
|
||||
from openhands.server.user_auth.user_auth import (
|
||||
get_for_user as get_user_auth_for_user,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix='/webhooks', tags=['Webhooks'])
|
||||
sandbox_service_dependency = depends_sandbox_service()
|
||||
event_service_dependency = depends_event_service()
|
||||
app_conversation_info_service_dependency = depends_app_conversation_info_service()
|
||||
jwt_dependency = depends_jwt_service()
|
||||
app_mode = get_global_config().app_mode
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def valid_sandbox(
|
||||
user_context: UserContext = Depends(as_admin),
|
||||
request: Request,
|
||||
session_api_key: str = Depends(
|
||||
APIKeyHeader(name='X-Session-API-Key', auto_error=False)
|
||||
),
|
||||
sandbox_service: SandboxService = sandbox_service_dependency,
|
||||
) -> SandboxInfo:
|
||||
"""Use a session api key for validation, and get a sandbox. Subsequent actions
|
||||
are executed in the context of the owner of the sandbox"""
|
||||
if not session_api_key:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, detail='X-Session-API-Key header is required'
|
||||
)
|
||||
|
||||
sandbox_info = await sandbox_service.get_sandbox_by_session_api_key(session_api_key)
|
||||
if sandbox_info is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, detail='Invalid session API key'
|
||||
# Create a state which will be used internally only for this operation
|
||||
state = InjectorState()
|
||||
|
||||
# Since we need access to all sandboxes, this is executed in the context of the admin.
|
||||
setattr(state, USER_CONTEXT_ATTR, ADMIN)
|
||||
async with get_sandbox_service(state) as sandbox_service:
|
||||
sandbox_info = await sandbox_service.get_sandbox_by_session_api_key(
|
||||
session_api_key
|
||||
)
|
||||
return sandbox_info
|
||||
if sandbox_info is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, detail='Invalid session API key'
|
||||
)
|
||||
|
||||
# In SAAS Mode there is always a user, so we set the owner of the sandbox
|
||||
# as the current user (Validated by the session_api_key they provided)
|
||||
if sandbox_info.created_by_user_id:
|
||||
setattr(
|
||||
request.state,
|
||||
USER_CONTEXT_ATTR,
|
||||
SpecifyUserContext(sandbox_info.created_by_user_id),
|
||||
)
|
||||
elif app_mode == AppMode.SAAS:
|
||||
_logger.error(
|
||||
'Sandbox had no user specified', extra={'sandbox_id': sandbox_info.id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, detail='Sandbox had no user specified'
|
||||
)
|
||||
|
||||
return sandbox_info
|
||||
|
||||
|
||||
async def valid_conversation(
|
||||
conversation_id: UUID,
|
||||
sandbox_info: SandboxInfo,
|
||||
sandbox_info: SandboxInfo = Depends(valid_sandbox),
|
||||
app_conversation_info_service: AppConversationInfoService = app_conversation_info_service_dependency,
|
||||
) -> AppConversationInfo:
|
||||
app_conversation_info = (
|
||||
@@ -90,9 +116,11 @@ async def valid_conversation(
|
||||
sandbox_id=sandbox_info.id,
|
||||
created_by_user_id=sandbox_info.created_by_user_id,
|
||||
)
|
||||
|
||||
# Sanity check - Make sure that the conversation and sandbox were created by the same user
|
||||
if app_conversation_info.created_by_user_id != sandbox_info.created_by_user_id:
|
||||
# Make sure that the conversation and sandbox were created by the same user
|
||||
raise AuthError()
|
||||
|
||||
return app_conversation_info
|
||||
|
||||
|
||||
@@ -100,12 +128,10 @@ async def valid_conversation(
|
||||
async def on_conversation_update(
|
||||
conversation_info: ConversationInfo,
|
||||
sandbox_info: SandboxInfo = Depends(valid_sandbox),
|
||||
existing: AppConversationInfo = Depends(valid_conversation),
|
||||
app_conversation_info_service: AppConversationInfoService = app_conversation_info_service_dependency,
|
||||
) -> Success:
|
||||
"""Webhook callback for when a conversation starts, pauses, resumes, or deletes."""
|
||||
existing = await valid_conversation(
|
||||
conversation_info.id, sandbox_info, app_conversation_info_service
|
||||
)
|
||||
|
||||
# If the conversation is being deleted, no action is required...
|
||||
# Later we may consider deleting the conversation if it exists...
|
||||
@@ -139,15 +165,11 @@ async def on_conversation_update(
|
||||
async def on_event(
|
||||
events: list[Event],
|
||||
conversation_id: UUID,
|
||||
sandbox_info: SandboxInfo = Depends(valid_sandbox),
|
||||
app_conversation_info: AppConversationInfo = Depends(valid_conversation),
|
||||
app_conversation_info_service: AppConversationInfoService = app_conversation_info_service_dependency,
|
||||
event_service: EventService = event_service_dependency,
|
||||
) -> Success:
|
||||
"""Webhook callback for when event stream events occur."""
|
||||
app_conversation_info = await valid_conversation(
|
||||
conversation_id, sandbox_info, app_conversation_info_service
|
||||
)
|
||||
|
||||
try:
|
||||
# Save events...
|
||||
await asyncio.gather(
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
This module tests the webhook authentication and authorization logic.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import contextlib
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@@ -14,7 +15,49 @@ from openhands.app_server.event_callback.webhook_router import (
|
||||
valid_sandbox,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo, SandboxStatus
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
from openhands.app_server.user.specifiy_user_context import (
|
||||
USER_CONTEXT_ATTR,
|
||||
SpecifyUserContext,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
class MockRequestState:
|
||||
"""A mock request state that tracks attribute assignments."""
|
||||
|
||||
def __init__(self):
|
||||
self._state = {}
|
||||
self._attributes = {}
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith('_'):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
self._attributes[name] = value
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self._attributes:
|
||||
return self._attributes[name]
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
|
||||
def create_mock_request():
|
||||
"""Create a mock FastAPI Request object with proper state."""
|
||||
request = MagicMock()
|
||||
request.state = MockRequestState()
|
||||
return request
|
||||
|
||||
|
||||
def create_sandbox_service_context_manager(sandbox_service):
|
||||
"""Create an async context manager that yields the given sandbox service."""
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _context_manager(state, request=None):
|
||||
yield sandbox_service
|
||||
|
||||
return _context_manager
|
||||
|
||||
|
||||
class TestValidSandbox:
|
||||
@@ -22,14 +65,15 @@ class TestValidSandbox:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_sandbox_with_valid_api_key(self):
|
||||
"""Test that valid API key returns sandbox info."""
|
||||
"""Test that valid API key returns sandbox info and sets user_context."""
|
||||
# Arrange
|
||||
session_api_key = 'valid-api-key-123'
|
||||
user_id = 'user-123'
|
||||
expected_sandbox = SandboxInfo(
|
||||
id='sandbox-123',
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key=session_api_key,
|
||||
created_by_user_id='user-123',
|
||||
created_by_user_id=user_id,
|
||||
sandbox_spec_id='spec-123',
|
||||
)
|
||||
|
||||
@@ -38,12 +82,17 @@ class TestValidSandbox:
|
||||
return_value=expected_sandbox
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act
|
||||
result = await valid_sandbox(
|
||||
user_context=ADMIN,
|
||||
session_api_key=session_api_key,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
)
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
):
|
||||
result = await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == expected_sandbox
|
||||
@@ -51,18 +100,136 @@ class TestValidSandbox:
|
||||
session_api_key
|
||||
)
|
||||
|
||||
# Verify user_context is set correctly on request.state
|
||||
assert USER_CONTEXT_ATTR in mock_request.state._attributes
|
||||
user_context = mock_request.state._attributes[USER_CONTEXT_ATTR]
|
||||
assert isinstance(user_context, SpecifyUserContext)
|
||||
assert user_context.user_id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_sandbox_sets_user_context_to_sandbox_owner(self):
|
||||
"""Test that user_context is set to the sandbox owner's user ID."""
|
||||
# Arrange
|
||||
session_api_key = 'valid-api-key'
|
||||
sandbox_owner_id = 'sandbox-owner-user-id'
|
||||
expected_sandbox = SandboxInfo(
|
||||
id='sandbox-456',
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key=session_api_key,
|
||||
created_by_user_id=sandbox_owner_id,
|
||||
sandbox_spec_id='spec-456',
|
||||
)
|
||||
|
||||
mock_sandbox_service = AsyncMock()
|
||||
mock_sandbox_service.get_sandbox_by_session_api_key = AsyncMock(
|
||||
return_value=expected_sandbox
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
):
|
||||
await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
|
||||
# Assert - user_context should be set to the sandbox owner
|
||||
assert USER_CONTEXT_ATTR in mock_request.state._attributes
|
||||
user_context = mock_request.state._attributes[USER_CONTEXT_ATTR]
|
||||
assert isinstance(user_context, SpecifyUserContext)
|
||||
assert user_context.user_id == sandbox_owner_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_sandbox_no_user_context_when_no_user_id(self):
|
||||
"""Test that user_context is not set when sandbox has no created_by_user_id."""
|
||||
# Arrange
|
||||
session_api_key = 'valid-api-key'
|
||||
expected_sandbox = SandboxInfo(
|
||||
id='sandbox-789',
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key=session_api_key,
|
||||
created_by_user_id=None, # No user ID
|
||||
sandbox_spec_id='spec-789',
|
||||
)
|
||||
|
||||
mock_sandbox_service = AsyncMock()
|
||||
mock_sandbox_service.get_sandbox_by_session_api_key = AsyncMock(
|
||||
return_value=expected_sandbox
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
):
|
||||
result = await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
|
||||
# Assert - sandbox is returned but user_context should NOT be set
|
||||
assert result == expected_sandbox
|
||||
|
||||
# Verify user_context is NOT set on request.state
|
||||
assert USER_CONTEXT_ATTR not in mock_request.state._attributes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_sandbox_no_user_context_when_no_user_id_raises_401_in_saas_mode(
|
||||
self,
|
||||
):
|
||||
"""Test that user_context is not set when sandbox has no created_by_user_id."""
|
||||
# Arrange
|
||||
session_api_key = 'valid-api-key'
|
||||
expected_sandbox = SandboxInfo(
|
||||
id='sandbox-789',
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key=session_api_key,
|
||||
created_by_user_id=None, # No user ID
|
||||
sandbox_spec_id='spec-789',
|
||||
)
|
||||
|
||||
mock_sandbox_service = AsyncMock()
|
||||
mock_sandbox_service.get_sandbox_by_session_api_key = AsyncMock(
|
||||
return_value=expected_sandbox
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.app_mode',
|
||||
AppMode.SAAS,
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
assert excinfo.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_sandbox_without_api_key_raises_401(self):
|
||||
"""Test that missing API key raises 401 error."""
|
||||
# Arrange
|
||||
mock_sandbox_service = AsyncMock()
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await valid_sandbox(
|
||||
user_context=ADMIN,
|
||||
request=mock_request,
|
||||
session_api_key=None,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -78,13 +245,18 @@ class TestValidSandbox:
|
||||
return_value=None
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await valid_sandbox(
|
||||
user_context=ADMIN,
|
||||
session_api_key=session_api_key,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
)
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert 'Invalid session API key' in exc_info.value.detail
|
||||
@@ -95,13 +267,13 @@ class TestValidSandbox:
|
||||
# Arrange - empty string is falsy, so it gets rejected at the check
|
||||
session_api_key = ''
|
||||
mock_sandbox_service = AsyncMock()
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act & Assert - should raise 401 because empty string fails the truth check
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await valid_sandbox(
|
||||
user_context=ADMIN,
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -263,12 +435,17 @@ class TestWebhookAuthenticationIntegration:
|
||||
return_value=conversation_info
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act - Call valid_sandbox first
|
||||
sandbox_result = await valid_sandbox(
|
||||
user_context=ADMIN,
|
||||
session_api_key=session_api_key,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
)
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
):
|
||||
sandbox_result = await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
|
||||
# Then call valid_conversation
|
||||
conversation_result = await valid_conversation(
|
||||
@@ -291,13 +468,18 @@ class TestWebhookAuthenticationIntegration:
|
||||
return_value=None
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act & Assert - Should fail at valid_sandbox
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await valid_sandbox(
|
||||
user_context=ADMIN,
|
||||
session_api_key=session_api_key,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
)
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -328,12 +510,17 @@ class TestWebhookAuthenticationIntegration:
|
||||
return_value=different_user_info
|
||||
)
|
||||
|
||||
mock_request = create_mock_request()
|
||||
|
||||
# Act - valid_sandbox succeeds
|
||||
sandbox_result = await valid_sandbox(
|
||||
user_context=ADMIN,
|
||||
session_api_key=session_api_key,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
)
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.get_sandbox_service',
|
||||
create_sandbox_service_context_manager(mock_sandbox_service),
|
||||
):
|
||||
sandbox_result = await valid_sandbox(
|
||||
request=mock_request,
|
||||
session_api_key=session_api_key,
|
||||
)
|
||||
|
||||
# But valid_conversation fails
|
||||
from openhands.app_server.errors import AuthError
|
||||
|
||||
@@ -5,7 +5,7 @@ conversations are updated via the on_conversation_update webhook endpoint.
|
||||
"""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@@ -137,17 +137,13 @@ class TestOnConversationUpdateParentConversationId:
|
||||
parent_conversation_id=parent_id,
|
||||
)
|
||||
|
||||
# Mock valid_conversation to return existing conversation
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=existing_conv,
|
||||
):
|
||||
# Act
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
# Act - call on_conversation_update directly with dependencies
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
existing=existing_conv,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, Success)
|
||||
@@ -191,17 +187,13 @@ class TestOnConversationUpdateParentConversationId:
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
# Mock valid_conversation to return existing conversation
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=existing_conv,
|
||||
):
|
||||
# Act
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
# Act - call on_conversation_update directly with dependencies
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
existing=existing_conv,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, Success)
|
||||
@@ -242,17 +234,13 @@ class TestOnConversationUpdateParentConversationId:
|
||||
created_by_user_id=sandbox_info.created_by_user_id,
|
||||
)
|
||||
|
||||
# Mock valid_conversation to return stub (as it would for new conversation)
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=stub_conv,
|
||||
):
|
||||
# Act
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
# Act - call on_conversation_update directly with dependencies
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
existing=stub_conv,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, Success)
|
||||
@@ -302,17 +290,13 @@ class TestOnConversationUpdateParentConversationId:
|
||||
parent_conversation_id=parent_id,
|
||||
)
|
||||
|
||||
# Mock valid_conversation to return existing conversation
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=existing_conv,
|
||||
):
|
||||
# Act
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
# Act - call on_conversation_update directly with dependencies
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
existing=existing_conv,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, Success)
|
||||
@@ -366,9 +350,8 @@ class TestOnConversationUpdateParentConversationId:
|
||||
parent_conversation_id=parent_id,
|
||||
)
|
||||
|
||||
# Mock valid_conversation to return conversation with parent
|
||||
# In real scenario, this would be retrieved from DB after first save
|
||||
async def mock_valid_conv(*args, **kwargs):
|
||||
# Act - Update multiple times, simulating what valid_conversation would return
|
||||
for _ in range(3):
|
||||
# After first save, get from DB with parent preserved
|
||||
saved = await app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
@@ -376,21 +359,17 @@ class TestOnConversationUpdateParentConversationId:
|
||||
if saved:
|
||||
# Override created_by_user_id for auth check
|
||||
saved.created_by_user_id = 'user_123'
|
||||
return saved
|
||||
return initial_conv
|
||||
existing = saved
|
||||
else:
|
||||
existing = initial_conv
|
||||
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
side_effect=mock_valid_conv,
|
||||
):
|
||||
# Act - Update multiple times
|
||||
for _ in range(3):
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
assert isinstance(result, Success)
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
existing=existing,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
assert isinstance(result, Success)
|
||||
|
||||
# Assert
|
||||
saved_conv = await app_conversation_info_service.get_app_conversation_info(
|
||||
@@ -441,17 +420,13 @@ class TestOnConversationUpdateParentConversationId:
|
||||
# Set conversation to DELETING status
|
||||
mock_conversation_info.execution_status = ConversationExecutionStatus.DELETING
|
||||
|
||||
# Mock valid_conversation (though it won't be called for DELETING status)
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=existing_conv,
|
||||
):
|
||||
# Act
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
# Act - call on_conversation_update directly with dependencies
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
existing=existing_conv,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
|
||||
# Assert - Function returns success but doesn't update
|
||||
assert isinstance(result, Success)
|
||||
@@ -498,17 +473,13 @@ class TestOnConversationUpdateParentConversationId:
|
||||
parent_conversation_id=parent_id,
|
||||
)
|
||||
|
||||
# Mock valid_conversation to return existing conversation
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=existing_conv,
|
||||
):
|
||||
# Act
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
# Act - call on_conversation_update directly with dependencies
|
||||
result = await on_conversation_update(
|
||||
conversation_info=mock_conversation_info,
|
||||
sandbox_info=sandbox_info,
|
||||
existing=existing_conv,
|
||||
app_conversation_info_service=app_conversation_info_service,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, Success)
|
||||
|
||||
@@ -451,11 +451,9 @@ class TestOnEventStatsProcessing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_processes_stats_events(self):
|
||||
"""Test that on_event processes stats events."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from openhands.app_server.event_callback.webhook_router import on_event
|
||||
from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
|
||||
conversation_id = uuid4()
|
||||
sandbox_id = 'sandbox_123'
|
||||
@@ -482,15 +480,6 @@ class TestOnEventStatsProcessing:
|
||||
|
||||
events = [stats_event, other_event]
|
||||
|
||||
# Mock dependencies
|
||||
mock_sandbox = SandboxInfo(
|
||||
id=sandbox_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test_key',
|
||||
created_by_user_id='user_123',
|
||||
sandbox_spec_id='spec_123',
|
||||
)
|
||||
|
||||
mock_app_conversation_info = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
sandbox_id=sandbox_id,
|
||||
@@ -499,9 +488,6 @@ class TestOnEventStatsProcessing:
|
||||
|
||||
mock_event_service = AsyncMock()
|
||||
mock_app_conversation_info_service = AsyncMock()
|
||||
mock_app_conversation_info_service.get_app_conversation_info.return_value = (
|
||||
mock_app_conversation_info
|
||||
)
|
||||
|
||||
# Set up process_stats_event to call update_conversation_statistics
|
||||
async def process_stats_event_side_effect(event, conversation_id):
|
||||
@@ -519,44 +505,33 @@ class TestOnEventStatsProcessing:
|
||||
process_stats_event_side_effect
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_sandbox',
|
||||
return_value=mock_sandbox,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=mock_app_conversation_info,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
|
||||
) as mock_callbacks,
|
||||
):
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
|
||||
) as mock_callbacks:
|
||||
# Call on_event directly with dependencies
|
||||
await on_event(
|
||||
events=events,
|
||||
conversation_id=conversation_id,
|
||||
sandbox_info=mock_sandbox,
|
||||
app_conversation_info=mock_app_conversation_info,
|
||||
app_conversation_info_service=mock_app_conversation_info_service,
|
||||
event_service=mock_event_service,
|
||||
)
|
||||
|
||||
# Verify events were saved
|
||||
assert mock_event_service.save_event.call_count == 2
|
||||
# Verify events were saved
|
||||
assert mock_event_service.save_event.call_count == 2
|
||||
|
||||
# Verify stats event was processed
|
||||
mock_app_conversation_info_service.update_conversation_statistics.assert_called_once()
|
||||
# Verify stats event was processed
|
||||
mock_app_conversation_info_service.update_conversation_statistics.assert_called_once()
|
||||
|
||||
# Verify callbacks were scheduled
|
||||
mock_callbacks.assert_called_once()
|
||||
# Verify callbacks were scheduled
|
||||
mock_callbacks.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_skips_non_stats_events(self):
|
||||
"""Test that on_event skips non-stats events."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from openhands.app_server.event_callback.webhook_router import on_event
|
||||
from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
|
||||
conversation_id = uuid4()
|
||||
@@ -568,14 +543,6 @@ class TestOnEventStatsProcessing:
|
||||
MessageAction(content='test'),
|
||||
]
|
||||
|
||||
mock_sandbox = SandboxInfo(
|
||||
id=sandbox_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test_key',
|
||||
created_by_user_id='user_123',
|
||||
sandbox_spec_id='spec_123',
|
||||
)
|
||||
|
||||
mock_app_conversation_info = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
sandbox_id=sandbox_id,
|
||||
@@ -584,30 +551,18 @@ class TestOnEventStatsProcessing:
|
||||
|
||||
mock_event_service = AsyncMock()
|
||||
mock_app_conversation_info_service = AsyncMock()
|
||||
mock_app_conversation_info_service.get_app_conversation_info.return_value = (
|
||||
mock_app_conversation_info
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_sandbox',
|
||||
return_value=mock_sandbox,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=mock_app_conversation_info,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
|
||||
),
|
||||
with patch(
|
||||
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
|
||||
):
|
||||
# Call on_event directly with dependencies
|
||||
await on_event(
|
||||
events=events,
|
||||
conversation_id=conversation_id,
|
||||
sandbox_info=mock_sandbox,
|
||||
app_conversation_info=mock_app_conversation_info,
|
||||
app_conversation_info_service=mock_app_conversation_info_service,
|
||||
event_service=mock_event_service,
|
||||
)
|
||||
|
||||
# Verify stats update was NOT called
|
||||
mock_app_conversation_info_service.update_conversation_statistics.assert_not_called()
|
||||
# Verify stats update was NOT called
|
||||
mock_app_conversation_info_service.update_conversation_statistics.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user