From d76e83b55ef4bb05b2bab33338716960508e5f9a Mon Sep 17 00:00:00 2001
From: tofarr
Date: Mon, 16 Dec 2024 15:59:41 -0700
Subject: [PATCH] Fix: Mocking LLM proxy in unit tests (#5639)

---
 tests/unit/test_acompletion.py | 32 +++++++++++++++++++++++++-------
 tests/unit/test_manager.py     |  4 ++--
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/tests/unit/test_acompletion.py b/tests/unit/test_acompletion.py
index cd5197ebb1..b6753759be 100644
--- a/tests/unit/test_acompletion.py
+++ b/tests/unit/test_acompletion.py
@@ -1,5 +1,7 @@
 import asyncio
-from unittest.mock import AsyncMock, patch
+from contextlib import contextmanager
+from typing import Type
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
@@ -14,8 +16,12 @@ config = load_app_config()
 
 @pytest.fixture
 def test_llm():
-    # Create a mock config for testing
-    return LLM(config=config.get_llm_config())
+    return _get_llm(LLM)
+
+
+def _get_llm(type_: Type[LLM]):
+    with _patch_http():
+        return type_(config=config.get_llm_config())
 
 
 @pytest.fixture
@@ -39,6 +45,18 @@ def mock_response():
     ]
 
 
+@contextmanager
+def _patch_http():
+    with patch('openhands.llm.llm.requests.get', MagicMock()) as mock_http:
+        mock_http.json.return_value = {
+            'data': [
+                {'model_name': 'some_model'},
+                {'model_name': 'another_model'},
+            ]
+        }
+        yield
+
+
 @pytest.mark.asyncio
 async def test_acompletion_non_streaming():
     with patch.object(AsyncLLM, '_call_acompletion') as mock_call_acompletion:
@@ -46,7 +64,7 @@ async def test_acompletion_non_streaming():
             'choices': [{'message': {'content': 'This is a test message.'}}]
         }
         mock_call_acompletion.return_value = mock_response
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)
         response = await test_llm.async_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}],
             stream=False,
@@ -60,7 +78,7 @@ async def test_acompletion_streaming(mock_response):
     with patch.object(StreamingLLM, '_call_acompletion') as mock_call_acompletion:
         mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
 
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)
         async for chunk in test_llm.async_streaming_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
         ):
@@ -109,7 +127,7 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.side_effect = mock_acompletion
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)
 
         async def cancel_after_delay():
             print(f'Starting cancel_after_delay with delay {cancel_delay}')
@@ -171,7 +189,7 @@ async def test_async_streaming_completion_with_user_cancellation(cancel_after_ch
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.return_value = mock_acompletion()
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)
 
         received_chunks = []
         with pytest.raises(UserCancelledError):
diff --git a/tests/unit/test_manager.py b/tests/unit/test_manager.py
index 9ec9e4ac31..43e8672ef1 100644
--- a/tests/unit/test_manager.py
+++ b/tests/unit/test_manager.py
@@ -60,7 +60,7 @@ async def test_session_is_running_in_cluster():
         )
     )
     with (
-        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.05),
+        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
    ):
        async with SessionManager(
            sio, AppConfig(), InMemoryFileStore()
@@ -87,7 +87,7 @@ async def test_init_new_local_session():
     is_session_running_in_cluster_mock.return_value = False
     with (
         patch('openhands.server.session.manager.Session', mock_session),
-        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
+        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
         patch(
             'openhands.server.session.manager.SessionManager._redis_subscribe',
             AsyncMock(),
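
Note: for readers who want the proxy-mocking pattern outside this test suite, here
is a minimal, self-contained sketch of the same idea. `fetch_proxy_models` and its
URL are hypothetical stand-ins for whatever openhands.llm.llm does with
requests.get; only the mocking idiom is the point. One nuance worth flagging:
requests.get(...) returns mock_get.return_value, so a standalone version attaches
the fake payload to mock_get.return_value.json (the patch above sets
mock_http.json, which is enough to suppress real HTTP but would not feed the fake
payload to code that actually reads the response).

    # Minimal sketch (not part of the patch): the same HTTP-mocking idiom,
    # applied to a hypothetical helper instead of openhands.llm.llm.
    from contextlib import contextmanager
    from unittest.mock import MagicMock, patch

    import requests


    def fetch_proxy_models(base_url: str) -> list[str]:
        # Hypothetical code under test: ask a LiteLLM-style proxy for its models.
        response = requests.get(f'{base_url}/v1/models')
        return [item['model_name'] for item in response.json()['data']]


    @contextmanager
    def patch_http():
        # Replace requests.get so no real HTTP request is made. requests.get(...)
        # returns mock_get.return_value, so the fake payload belongs on
        # mock_get.return_value.json, not on mock_get.json.
        with patch('requests.get', MagicMock()) as mock_get:
            mock_get.return_value.json.return_value = {
                'data': [
                    {'model_name': 'some_model'},
                    {'model_name': 'another_model'},
                ]
            }
            yield mock_get


    if __name__ == '__main__':
        with patch_http():
            # No network I/O happens here; the mocked payload is returned instead.
            assert fetch_proxy_models('http://localhost:4000') == [
                'some_model',
                'another_model',
            ]
        print('mocked proxy call OK')

Wrapping construction in a context manager like this keeps each test's patching
scoped and lets every LLM subclass (LLM, AsyncLLM, StreamingLLM) share one helper,
which is exactly what the _get_llm/_patch_http pair in the patch does.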