(fix) CodeActAgent/LLM: react on should_exit flag (user cancellation) (#3968)
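
The change makes the LLM retry loop cancellation-aware: a should_exit() flag is
checked before every retry sleep, and the new OperationCancelled exception both
aborts the tenacity loop and is excluded from the retryable exception types.
Below is a minimal, self-contained sketch of that pattern, not the project code
itself; should_exit here is a stand-in for
openhands.runtime.utils.shutdown_listener.should_exit.

    from tenacity import (
        retry,
        retry_if_exception_type,
        retry_if_not_exception_type,
        stop_after_attempt,
        wait_fixed,
    )


    class OperationCancelled(Exception):
        """Raised when the user cancels the operation."""


    cancelled = False  # toggled by a signal handler in the real system


    def should_exit() -> bool:
        return cancelled


    def log_retry_attempt(retry_state):
        # Runs before each retry sleep; raising here exits the @retry loop.
        if should_exit():
            raise OperationCancelled('Operation cancelled.')


    @retry(
        before_sleep=log_retry_attempt,
        stop=stop_after_attempt(3),
        reraise=True,
        # Retry transient errors, but never a cancellation.
        retry=(
            retry_if_exception_type(ConnectionError)
            & retry_if_not_exception_type(OperationCancelled)
        ),
        wait=wait_fixed(0.1),
    )
    def flaky_call():
        raise ConnectionError('transient failure')

With cancelled left False, flaky_call() raises ConnectionError after three
attempts; flipping cancelled to True makes the first before_sleep callback
raise OperationCancelled instead.
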
@@ -5,6 +5,7 @@ from agenthub.codeact_agent.action_parser import CodeActResponseParser
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State
 from openhands.core.config import AgentConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import ImageContent, Message, TextContent
 from openhands.events.action import (

@@ -211,8 +212,11 @@ class CodeActAgent(Agent):
                 'anthropic-beta': 'prompt-caching-2024-07-31',
             }

+        # TODO: move exception handling to agent_controller
         try:
             response = self.llm.completion(**params)
+        except OperationCancelled as e:
+            raise e
         except Exception as e:
             logger.error(f'{e}')
             error_message = '{}: {}'.format(type(e).__name__, str(e).split('\n')[0])

@@ -77,3 +77,10 @@ class UserCancelledError(Exception):
 class MicroAgentValidationError(Exception):
     def __init__(self, message='Micro agent validation failed'):
         super().__init__(message)
+
+
+class OperationCancelled(Exception):
+    """Exception raised when an operation is cancelled (e.g. by a keyboard interrupt)."""
+
+    def __init__(self, message='Operation was cancelled'):
+        super().__init__(message)
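
For context, a hedged sketch of how a caller can funnel a keyboard interrupt
into the new exception type (this mirrors test_completion_keyboard_interrupt
further down; run_completion is a hypothetical helper, not from the diff):

    from openhands.core.exceptions import OperationCancelled

    try:
        run_completion()  # hypothetical helper wrapping llm.completion(...)
    except KeyboardInterrupt:
        # Normalize Ctrl-C into the single cancellation type callers handle.
        raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')
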
@@ -24,15 +24,21 @@ from litellm.types.utils import CostPerToken
 from tenacity import (
     retry,
     retry_if_exception_type,
+    retry_if_not_exception_type,
     stop_after_attempt,
     wait_exponential,
 )

-from openhands.core.exceptions import LLMResponseError, UserCancelledError
+from openhands.core.exceptions import (
+    LLMResponseError,
+    OperationCancelled,
+    UserCancelledError,
+)
 from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.core.metrics import Metrics
+from openhands.runtime.utils.shutdown_listener import should_exit

 __all__ = ['LLM']

@@ -169,13 +175,18 @@ class LLM:

         completion_unwrapped = self._completion

-        def attempt_on_error(retry_state):
-            """Custom attempt function for litellm completion."""
+        def log_retry_attempt(retry_state):
+            """With before_sleep, this is called before `custom_completion_wait` and
+            ONLY if the retry is triggered by an exception."""
+            if should_exit():
+                raise OperationCancelled(
+                    'Operation cancelled.'
+                )  # exits the @retry loop
+            exception = retry_state.outcome.exception()
             logger.error(
-                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
+                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                 exc_info=False,
             )
             return None

         def custom_completion_wait(retry_state):
             """Custom wait function for litellm completion."""
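
The docstring's claim that before_sleep fires only when another attempt will
follow can be verified with a standalone tenacity example (illustrative names,
not from the diff):

    from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

    calls = []

    def hook(retry_state):
        calls.append(retry_state.attempt_number)

    @retry(
        before_sleep=hook,
        stop=stop_after_attempt(3),
        reraise=True,
        retry=retry_if_exception_type(ValueError),
        wait=wait_fixed(0),
    )
    def always_fails():
        raise ValueError('boom')

    try:
        always_fails()
    except ValueError:
        pass

    assert calls == [1, 2]  # never called after the final, third attempt
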
@@ -211,10 +222,13 @@ class LLM:
             return exponential_wait(retry_state)

         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         def wrapper(*args, **kwargs):

@@ -278,10 +292,13 @@ class LLM:
         async_completion_unwrapped = self._async_completion

         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_completion_wrapper(*args, **kwargs):

@@ -351,10 +368,13 @@ class LLM:
             pass

         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_acompletion_stream_wrapper(*args, **kwargs):

@@ -448,6 +468,9 @@ class LLM:
         return str(element)

+    async def _call_acompletion(self, *args, **kwargs):
+        """This is a wrapper for the litellm acompletion function which
+        makes it mockable for testing.
+        """
+        return await litellm.acompletion(*args, **kwargs)
+
     @property
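
Because litellm.acompletion is now routed through _call_acompletion, async
tests can patch that single method; a hedged sketch (patch target inferred,
not taken from this commit's tests):

    from unittest.mock import AsyncMock, patch

    from openhands.llm.llm import LLM

    with patch.object(
        LLM, '_call_acompletion', new_callable=AsyncMock
    ) as mock_acompletion:
        mock_acompletion.return_value = {
            'choices': [{'message': {'content': 'mocked async response'}}]
        }
        # ... await the async completion wrapper inside an event loop ...
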
@@ -528,10 +551,15 @@ class LLM:
         output_tokens = usage.get('completion_tokens')

         if input_tokens:
-            stats += 'Input tokens: ' + str(input_tokens) + '\n'
+            stats += 'Input tokens: ' + str(input_tokens)

         if output_tokens:
-            stats += 'Output tokens: ' + str(output_tokens) + '\n'
+            stats += (
+                (' | ' if input_tokens else '')
+                + 'Output tokens: '
+                + str(output_tokens)
+                + '\n'
+            )

         model_extra = usage.get('model_extra', {})
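
A standalone sketch of the reworked formatting, with two worked cases showing
that the ' | ' separator appears only when both counts are present:

    def format_stats(input_tokens, output_tokens):
        # Mirrors the updated logic in the hunk above.
        stats = ''
        if input_tokens:
            stats += 'Input tokens: ' + str(input_tokens)
        if output_tokens:
            stats += (
                (' | ' if input_tokens else '')
                + 'Output tokens: '
                + str(output_tokens)
                + '\n'
            )
        return stats

    assert format_stats(10, 5) == 'Input tokens: 10 | Output tokens: 5\n'
    assert format_stats(None, 5) == 'Output tokens: 5\n'
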
@@ -1,15 +1,38 @@
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import pytest
+from litellm.exceptions import (
+    APIConnectionError,
+    ContentPolicyViolationError,
+    InternalServerError,
+    OpenAIError,
+    RateLimitError,
+)

 from openhands.core.config import LLMConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.metrics import Metrics
 from openhands.llm.llm import LLM


+@pytest.fixture(autouse=True)
+def mock_logger(monkeypatch):
+    # suppress logging of completion data to file
+    mock_logger = MagicMock()
+    monkeypatch.setattr('openhands.llm.llm.llm_prompt_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.llm.llm_response_logger', mock_logger)
+    return mock_logger
+
+
 @pytest.fixture
 def default_config():
-    return LLMConfig(model='gpt-4o', api_key='test_key')
+    return LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )


 def test_llm_init_with_default_config(default_config):

@@ -64,7 +87,7 @@ def test_llm_init_with_metrics():


 def test_llm_reset():
-    llm = LLM(LLMConfig(model='gpt-3.5-turbo', api_key='test_key'))
+    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
     initial_metrics = llm.metrics
     llm.reset()
     assert llm.metrics is not initial_metrics

@@ -73,7 +96,7 @@ def test_llm_reset():

 @patch('openhands.llm.llm.litellm.get_model_info')
 def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
-    default_config.model = 'openrouter:gpt-3.5-turbo'
+    default_config.model = 'openrouter:gpt-4o-mini'
     mock_get_model_info.return_value = {
         'max_input_tokens': 7000,
         'max_output_tokens': 1500,

@@ -81,4 +104,197 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
     llm = LLM(default_config)
     assert llm.config.max_input_tokens == 7000
     assert llm.config.max_output_tokens == 1500
-    mock_get_model_info.assert_called_once_with('openrouter:gpt-3.5-turbo')
+    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')
+
+
+# Tests involving completion and retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_with_mocked_logger(
+    mock_litellm_completion, default_config, mock_logger
+):
+    mock_litellm_completion.return_value = {
+        'choices': [{'message': {'content': 'Test response'}}]
+    }
+
+    llm = LLM(config=default_config)
+    response = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert response['choices'][0]['message']['content'] == 'Test response'
+    assert mock_litellm_completion.call_count == 1
+
+    mock_logger.debug.assert_called()
+
+
+@pytest.mark.parametrize(
+    'exception_class,extra_args,expected_retries',
+    [
+        (
+            APIConnectionError,
+            {'llm_provider': 'test_provider', 'model': 'test_model'},
+            2,
+        ),
+        (
+            ContentPolicyViolationError,
+            {'model': 'test_model', 'llm_provider': 'test_provider'},
+            2,
+        ),
+        (
+            InternalServerError,
+            {'llm_provider': 'test_provider', 'model': 'test_model'},
+            2,
+        ),
+        (OpenAIError, {}, 2),
+        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
+    ],
+)
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_retries(
+    mock_litellm_completion,
+    default_config,
+    exception_class,
+    extra_args,
+    expected_retries,
+):
+    mock_litellm_completion.side_effect = [
+        exception_class('Test error message', **extra_args),
+        {'choices': [{'message': {'content': 'Retry successful'}}]},
+    ]
+
+    llm = LLM(config=default_config)
+    response = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert response['choices'][0]['message']['content'] == 'Retry successful'
+    assert mock_litellm_completion.call_count == expected_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
+    with patch('time.sleep') as mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+
+        llm = LLM(config=default_config)
+        response = llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+        assert response['choices'][0]['message']['content'] == 'Retry successful'
+        assert mock_litellm_completion.call_count == 2
+
+        mock_sleep.assert_called_once()
+        wait_time = mock_sleep.call_args[0][0]
+        assert (
+            60 <= wait_time <= 240
+        ), f'Expected wait time between 60 and 240 seconds, but got {wait_time}'
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_exhausts_retries(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = APIConnectionError(
+        'Persistent error', llm_provider='test_provider', model='test_model'
+    )
+
+    llm = LLM(config=default_config)
+    with pytest.raises(APIConnectionError):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == llm.config.num_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_operation_cancelled(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
+    def side_effect(*args, **kwargs):
+        raise KeyboardInterrupt('Simulated KeyboardInterrupt')
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        try:
+            llm.completion(
+                messages=[{'role': 'user', 'content': 'Hello!'}],
+                stream=False,
+            )
+        except KeyboardInterrupt:
+            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
+    global _should_exit
+
+    def side_effect(*args, **kwargs):
+        global _should_exit
+        _should_exit = True
+        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    result = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert mock_litellm_completion.call_count == 1
+    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
+    assert _should_exit
+
+    _should_exit = False
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
+    mock_response = {
+        'choices': [{'message': {'content': 'This is a mocked response.'}}]
+    }
+    mock_litellm_completion.return_value = mock_response
+
+    test_llm = LLM(config=default_config)
+    response = test_llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+        drop_params=True,
+    )
+
+    # Assertions
+    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
+    mock_litellm_completion.assert_called_once()
+
+    # Check if the correct arguments were passed to litellm_completion
+    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
+    assert call_args['model'] == default_config.model
+    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
+    assert not call_args['stream']
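
Note that the last test drives a module-level _should_exit flag that only
simulates the shutdown signal. To make should_exit() return True inside the
retry loop itself, a test could patch the symbol as imported into
openhands.llm.llm; a hedged sketch:

    from unittest.mock import patch

    with patch('openhands.llm.llm.should_exit', return_value=True):
        # The next retry attempt raises OperationCancelled from
        # log_retry_attempt instead of sleeping and retrying.
        ...
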