(fix) CodeActAgent/LLM: react on should_exit flag (user cancellation) (#3968)

Author: tobitege
Date:   2024-09-20 23:49:45 +02:00 (committed by GitHub)
parent ebd93977cd
commit 01462e11d7
4 changed files with 273 additions and 18 deletions
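
The change, in essence: the tenacity-driven retry loop around LLM completions now checks a global should_exit flag between attempts and aborts with a new OperationCancelled exception instead of retrying. Below is a minimal, self-contained sketch of that pattern; names such as check_cancelled and flaky_call are invented for illustration, and this is not OpenHands code:

    from tenacity import (
        retry,
        retry_if_exception_type,
        retry_if_not_exception_type,
        stop_after_attempt,
        wait_fixed,
    )

    class OperationCancelled(Exception):
        """Raised to break out of the retry loop on user cancellation."""

    _should_exit = False  # in real code, flipped by a signal handler

    def should_exit() -> bool:
        return _should_exit

    def check_cancelled(retry_state):
        # before_sleep runs between attempts; raising here exits the @retry loop
        if should_exit():
            raise OperationCancelled('Operation cancelled.')

    @retry(
        before_sleep=check_cancelled,
        stop=stop_after_attempt(3),
        reraise=True,
        # retry transient errors, but never retry a cancellation
        retry=(
            retry_if_exception_type(ConnectionError)
            & retry_if_not_exception_type(OperationCancelled)
        ),
        wait=wait_fixed(0),
    )
    def flaky_call():
        raise ConnectionError('transient failure')  # always fails, to exercise retries

With _should_exit set, the first before_sleep callback raises OperationCancelled and the loop stops immediately; otherwise the call is retried up to the attempt limit and, because of reraise=True, the last ConnectionError propagates.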

File: agenthub/codeact_agent/codeact_agent.py

@@ -5,6 +5,7 @@ from agenthub.codeact_agent.action_parser import CodeActResponseParser
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State
 from openhands.core.config import AgentConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import ImageContent, Message, TextContent
 from openhands.events.action import (
@@ -211,8 +212,11 @@ class CodeActAgent(Agent):
                 'anthropic-beta': 'prompt-caching-2024-07-31',
             }
+        # TODO: move exception handling to agent_controller
         try:
             response = self.llm.completion(**params)
+        except OperationCancelled as e:
+            raise e
         except Exception as e:
             logger.error(f'{e}')
             error_message = '{}: {}'.format(type(e).__name__, str(e).split('\n')[0])

File: openhands/core/exceptions.py

@@ -77,3 +77,10 @@ class UserCancelledError(Exception):
 class MicroAgentValidationError(Exception):
     def __init__(self, message='Micro agent validation failed'):
         super().__init__(message)
+
+
+class OperationCancelled(Exception):
+    """Exception raised when an operation is cancelled (e.g. by a keyboard interrupt)."""
+
+    def __init__(self, message='Operation was cancelled'):
+        super().__init__(message)
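
Nothing in this diff shows a consumer of the new exception beyond CodeActAgent re-raising it. A hypothetical caller-side handler (run_step and agent_step are invented names, not part of this commit) might look like:

    from openhands.core.exceptions import OperationCancelled

    def run_step(agent_step):
        try:
            return agent_step()
        except OperationCancelled:
            # the user asked to stop: unwind quietly, no error stack trace
            return None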

File: openhands/llm/llm.py

@@ -24,15 +24,21 @@ from litellm.types.utils import CostPerToken
 from tenacity import (
     retry,
     retry_if_exception_type,
+    retry_if_not_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
 
-from openhands.core.exceptions import LLMResponseError, UserCancelledError
+from openhands.core.exceptions import (
+    LLMResponseError,
+    OperationCancelled,
+    UserCancelledError,
+)
 from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.core.metrics import Metrics
+from openhands.runtime.utils.shutdown_listener import should_exit
 
 __all__ = ['LLM']
@@ -169,13 +175,18 @@ class LLM:
         completion_unwrapped = self._completion
 
-        def attempt_on_error(retry_state):
-            """Custom attempt function for litellm completion."""
+        def log_retry_attempt(retry_state):
+            """With before_sleep, this is called before `custom_completion_wait` and
+            ONLY if the retry is triggered by an exception."""
+            if should_exit():
+                raise OperationCancelled(
+                    'Operation cancelled.'
+                )  # exits the @retry loop
+            exception = retry_state.outcome.exception()
             logger.error(
-                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
+                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                 exc_info=False,
             )
-            return None
 
         def custom_completion_wait(retry_state):
             """Custom wait function for litellm completion."""
@@ -211,10 +222,13 @@ class LLM:
             return exponential_wait(retry_state)
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         def wrapper(*args, **kwargs):
@@ -278,10 +292,13 @@ class LLM:
         async_completion_unwrapped = self._async_completion
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_completion_wrapper(*args, **kwargs):
@@ -351,10 +368,13 @@ class LLM:
             pass
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_acompletion_stream_wrapper(*args, **kwargs):
@@ -448,6 +468,9 @@ class LLM:
         return str(element)
 
     async def _call_acompletion(self, *args, **kwargs):
+        """This is a wrapper for the litellm acompletion function which
+        makes it mockable for testing.
+        """
         return await litellm.acompletion(*args, **kwargs)
 
     @property
@@ -528,10 +551,15 @@ class LLM:
         output_tokens = usage.get('completion_tokens')
 
         if input_tokens:
-            stats += 'Input tokens: ' + str(input_tokens) + '\n'
+            stats += 'Input tokens: ' + str(input_tokens)
 
         if output_tokens:
-            stats += 'Output tokens: ' + str(output_tokens) + '\n'
+            stats += (
+                (' | ' if input_tokens else '')
+                + 'Output tokens: '
+                + str(output_tokens)
+                + '\n'
+            )
 
         model_extra = usage.get('model_extra', {})
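
The reworked stats code joins both counts on one line instead of two. A quick check with invented values:

    input_tokens, output_tokens = 523, 87
    stats = ''
    if input_tokens:
        stats += 'Input tokens: ' + str(input_tokens)
    if output_tokens:
        stats += (
            (' | ' if input_tokens else '')
            + 'Output tokens: '
            + str(output_tokens)
            + '\n'
        )
    print(stats)  # Input tokens: 523 | Output tokens: 87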

File: tests/unit/test_llm.py

@@ -1,15 +1,38 @@
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
+from litellm.exceptions import (
+    APIConnectionError,
+    ContentPolicyViolationError,
+    InternalServerError,
+    OpenAIError,
+    RateLimitError,
+)
 
 from openhands.core.config import LLMConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.metrics import Metrics
 from openhands.llm.llm import LLM
 
 
+@pytest.fixture(autouse=True)
+def mock_logger(monkeypatch):
+    # suppress logging of completion data to file
+    mock_logger = MagicMock()
+    monkeypatch.setattr('openhands.llm.llm.llm_prompt_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.llm.llm_response_logger', mock_logger)
+    return mock_logger
+
+
 @pytest.fixture
 def default_config():
-    return LLMConfig(model='gpt-4o', api_key='test_key')
+    return LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
 
 
 def test_llm_init_with_default_config(default_config):
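
Note: the fixture pins small retry settings (num_retries=2, retry_min_wait=1, retry_max_wait=2) so the retry tests below run in seconds instead of waiting out the much larger production defaults.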
@@ -64,7 +87,7 @@ def test_llm_init_with_metrics():
 
 
 def test_llm_reset():
-    llm = LLM(LLMConfig(model='gpt-3.5-turbo', api_key='test_key'))
+    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
     initial_metrics = llm.metrics
     llm.reset()
     assert llm.metrics is not initial_metrics
@@ -73,7 +96,7 @@ def test_llm_reset():
 
 
 @patch('openhands.llm.llm.litellm.get_model_info')
 def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
-    default_config.model = 'openrouter:gpt-3.5-turbo'
+    default_config.model = 'openrouter:gpt-4o-mini'
     mock_get_model_info.return_value = {
         'max_input_tokens': 7000,
         'max_output_tokens': 1500,
@@ -81,4 +104,197 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
     llm = LLM(default_config)
     assert llm.config.max_input_tokens == 7000
     assert llm.config.max_output_tokens == 1500
-    mock_get_model_info.assert_called_once_with('openrouter:gpt-3.5-turbo')
+    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')
+
+
+# Tests involving completion and retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_with_mocked_logger(
+    mock_litellm_completion, default_config, mock_logger
+):
+    mock_litellm_completion.return_value = {
+        'choices': [{'message': {'content': 'Test response'}}]
+    }
+
+    llm = LLM(config=default_config)
+    response = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert response['choices'][0]['message']['content'] == 'Test response'
+    assert mock_litellm_completion.call_count == 1
+    mock_logger.debug.assert_called()
+
+
+@pytest.mark.parametrize(
+    'exception_class,extra_args,expected_retries',
+    [
+        (
+            APIConnectionError,
+            {'llm_provider': 'test_provider', 'model': 'test_model'},
+            2,
+        ),
+        (
+            ContentPolicyViolationError,
+            {'model': 'test_model', 'llm_provider': 'test_provider'},
+            2,
+        ),
+        (
+            InternalServerError,
+            {'llm_provider': 'test_provider', 'model': 'test_model'},
+            2,
+        ),
+        (OpenAIError, {}, 2),
+        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
+    ],
+)
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_retries(
+    mock_litellm_completion,
+    default_config,
+    exception_class,
+    extra_args,
+    expected_retries,
+):
+    mock_litellm_completion.side_effect = [
+        exception_class('Test error message', **extra_args),
+        {'choices': [{'message': {'content': 'Retry successful'}}]},
+    ]
+
+    llm = LLM(config=default_config)
+    response = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert response['choices'][0]['message']['content'] == 'Retry successful'
+    assert mock_litellm_completion.call_count == expected_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
+    with patch('time.sleep') as mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+
+        llm = LLM(config=default_config)
+        response = llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+        assert response['choices'][0]['message']['content'] == 'Retry successful'
+        assert mock_litellm_completion.call_count == 2
+
+        mock_sleep.assert_called_once()
+        wait_time = mock_sleep.call_args[0][0]
+        assert (
+            60 <= wait_time <= 240
+        ), f'Expected wait time between 60 and 240 seconds, but got {wait_time}'
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_exhausts_retries(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = APIConnectionError(
+        'Persistent error', llm_provider='test_provider', model='test_model'
+    )
+
+    llm = LLM(config=default_config)
+    with pytest.raises(APIConnectionError):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == llm.config.num_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_operation_cancelled(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
+    def side_effect(*args, **kwargs):
+        raise KeyboardInterrupt('Simulated KeyboardInterrupt')
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        try:
+            llm.completion(
+                messages=[{'role': 'user', 'content': 'Hello!'}],
+                stream=False,
+            )
+        except KeyboardInterrupt:
+            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
+    global _should_exit
+
+    def side_effect(*args, **kwargs):
+        global _should_exit
+        _should_exit = True
+        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    result = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert mock_litellm_completion.call_count == 1
+    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
+    assert _should_exit
+
+    _should_exit = False
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
+    mock_response = {
+        'choices': [{'message': {'content': 'This is a mocked response.'}}]
+    }
+    mock_litellm_completion.return_value = mock_response
+
+    test_llm = LLM(config=default_config)
+    response = test_llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+        drop_params=True,
+    )
+
+    # Assertions
+    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
+    mock_litellm_completion.assert_called_once()
+
+    # Check if the correct arguments were passed to litellm_completion
+    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
+    assert call_args['model'] == default_config.model
+    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
+    assert not call_args['stream']
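
Assuming the repository's usual test layout, the new tests can be run in isolation with something like:

    pytest tests/unit/test_llm.py -k "retries or cancelled or keyboard"

(the -k expression is illustrative; any subset of the new test names works).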