Show LLM retries and allow resume from rate-limit state (#6438)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>

commit fd73f4210e
parent 1bccfb3492

Summary: LLM gains an optional retry_listener callback that the tenacity retry decorator invokes before each retry sleep. AgentController and Session both pass a listener that emits a STATUS$LLM_RETRY status update ('Retrying LLM request, N / M') to the client. RateLimitError is now reported like other fatal LLM errors, and the frontend no longer disables the chat input while the agent is RATE_LIMITED, so the user can resume.
@@ -180,8 +180,7 @@ export function ChatInterface() {
         onStop={handleStop}
         isDisabled={
           curAgentState === AgentState.LOADING ||
-          curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
-          curAgentState === AgentState.RATE_LIMITED
+          curAgentState === AgentState.AWAITING_USER_CONFIRMATION
         }
         mode={curAgentState === AgentState.RUNNING ? "stop" : "submit"}
         value={messageToSend ?? undefined}
@@ -3816,6 +3816,9 @@
     "es": "Hubo un error al conectar con el entorno de ejecución. Por favor, actualice la página.",
     "tr": "Çalışma zamanına bağlanırken bir hata oluştu. Lütfen sayfayı yenileyin."
   },
+  "STATUS$LLM_RETRY": {
+    "en": "Retrying LLM request"
+  },
   "AGENT_ERROR$BAD_ACTION": {
     "en": "Agent tried to execute a malformed action.",
     "zh-CN": "错误的操作",
@@ -235,8 +235,10 @@ class AgentController:
                 f'report this error to the developers. Your session ID is {self.id}. '
                 f'Error type: {e.__class__.__name__}'
             )
-        if isinstance(e, litellm.AuthenticationError) or isinstance(
-            e, litellm.BadRequestError
+        if (
+            isinstance(e, litellm.AuthenticationError)
+            or isinstance(e, litellm.BadRequestError)
+            or isinstance(e, RateLimitError)
         ):
             reported = e
         await self._react_to_exception(reported)
@@ -530,7 +532,7 @@ class AgentController:
         agent_cls: Type[Agent] = Agent.get_cls(action.agent)
         agent_config = self.agent_configs.get(action.agent, self.agent.config)
         llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
-        llm = LLM(config=llm_config)
+        llm = LLM(config=llm_config, retry_listener=self._notify_on_llm_retry)
         delegate_agent = agent_cls(llm=llm, config=agent_config)
         state = State(
             inputs=action.inputs or {},
@@ -725,6 +727,13 @@ class AgentController:
             log_level = 'info' if LOG_ALL_EVENTS else 'debug'
             self.log(log_level, str(action), extra={'msg_type': 'ACTION'})

+    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
+        if self.status_callback is not None:
+            msg_id = 'STATUS$LLM_RETRY'
+            self.status_callback(
+                'info', msg_id, f'Retrying LLM request, {retries} / {max}'
+            )
+
     async def _handle_traffic_control(
         self, limit_type: str, current_value: float, max_value: float
     ) -> bool:
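The listener threaded through here is just a Callable[[int, int], None] taking (current_attempt, max_retries). A minimal sketch of that contract, with illustrative names only (RetryListener, print_listener, and notify are not part of OpenHands):

    from typing import Callable

    RetryListener = Callable[[int, int], None]

    def print_listener(retries: int, max_retries: int) -> None:
        print(f'Retrying LLM request, {retries} / {max_retries}')

    def notify(listener: RetryListener | None, attempt: int, limit: int) -> None:
        # Same guard as _notify_on_llm_retry above: only fire when a
        # listener is actually wired up.
        if listener is not None:
            listener(attempt, limit)

    notify(print_listener, 1, 2)  # -> Retrying LLM request, 1 / 2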
@@ -3,7 +3,7 @@ import os
 import time
 import warnings
 from functools import partial
-from typing import Any
+from typing import Any, Callable

 import requests
@@ -94,6 +94,7 @@ class LLM(RetryMixin, DebugMixin):
         self,
         config: LLMConfig,
         metrics: Metrics | None = None,
+        retry_listener: Callable[[int, int], None] | None = None,
     ):
         """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

@@ -111,7 +112,7 @@ class LLM(RetryMixin, DebugMixin):
         self.config: LLMConfig = copy.deepcopy(config)

         self.model_info: ModelInfo | None = None
-
+        self.retry_listener = retry_listener
         if self.config.log_completions:
             if self.config.log_completions_folder is None:
                 raise RuntimeError(
@@ -168,6 +169,7 @@ class LLM(RetryMixin, DebugMixin):
             retry_min_wait=self.config.retry_min_wait,
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
+            retry_listener=self.retry_listener,
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
@@ -28,9 +28,15 @@ class RetryMixin:
         retry_min_wait = kwargs.get('retry_min_wait')
         retry_max_wait = kwargs.get('retry_max_wait')
         retry_multiplier = kwargs.get('retry_multiplier')
+        retry_listener = kwargs.get('retry_listener')
+
+        def before_sleep(retry_state):
+            self.log_retry_attempt(retry_state)
+            if retry_listener:
+                retry_listener(retry_state.attempt_number, num_retries)

         return retry(
-            before_sleep=self.log_retry_attempt,
+            before_sleep=before_sleep,
             stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
             reraise=True,
             retry=(retry_if_exception_type(retry_exceptions)),
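A minimal, self-contained sketch of the before_sleep pattern this hunk relies on, built directly on tenacity. The names flaky_call and on_retry are illustrative; only the tenacity API is real. before_sleep runs after each failed attempt and just before tenacity sleeps, which is why it never fires for the final, successful attempt:

    from tenacity import retry, retry_if_exception_type, stop_after_attempt

    MAX_ATTEMPTS = 3
    calls = {'n': 0}

    def on_retry(retry_state):
        # retry_state.attempt_number counts from 1.
        print(f'Retrying, {retry_state.attempt_number} / {MAX_ATTEMPTS}')

    @retry(
        before_sleep=on_retry,
        stop=stop_after_attempt(MAX_ATTEMPTS),
        retry=retry_if_exception_type(ConnectionError),
        reraise=True,
    )
    def flaky_call():
        calls['n'] += 1
        if calls['n'] < MAX_ATTEMPTS:
            raise ConnectionError('simulated rate limit')
        return 'ok'

    print(flaky_call())  # prints two retry notices, then 'ok'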
@@ -104,7 +104,7 @@ class Session:

         # TODO: override other LLM config & agent config groups (#2075)

-        llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
+        llm = self._create_llm(agent_cls)
         agent_config = self.config.get_agent_config(agent_cls)

         if settings.enable_default_condenser:
@@ -142,6 +142,21 @@ class Session:
             )
             return

+    def _create_llm(self, agent_cls: str | None) -> LLM:
+        """
+        Initialize LLM, extracted for testing.
+        """
+        return LLM(
+            config=self.config.get_llm_config_from_agent(agent_cls),
+            retry_listener=self._notify_on_llm_retry,
+        )
+
+    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
+        msg_id = 'STATUS$LLM_RETRY'
+        self.queue_status_message(
+            'info', msg_id, f'Retrying LLM request, {retries} / {max}'
+        )
+
     def on_event(self, event: Event):
         asyncio.get_event_loop().run_until_complete(self._on_event(event))
@@ -220,7 +235,6 @@ class Session:
         """Sends a status message to the client."""
         if msg_type == 'error':
             await self.agent_session.stop_agent_loop_for_error()
-
         await self.send(
             {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
         )
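On the wire, a retry notification reaches the client as the status_update payload shown in this hunk. A minimal sketch of consuming it, assuming only that payload shape (handle_status_update is an illustrative name, not part of OpenHands):

    def handle_status_update(payload: dict) -> None:
        # Payload shape per Session.send above.
        if payload.get('status_update') and payload.get('id') == 'STATUS$LLM_RETRY':
            print(payload['message'])  # e.g. 'Retrying LLM request, 1 / 2'

    handle_status_update(
        {
            'status_update': True,
            'type': 'info',
            'id': 'STATUS$LLM_RETRY',
            'message': 'Retrying LLM request, 1 / 2',
        }
    )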
@@ -1,5 +1,5 @@
 import asyncio
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import ANY, AsyncMock, MagicMock
 from uuid import uuid4

 import pytest
@@ -564,6 +564,22 @@ async def test_run_controller_max_iterations_has_metrics():
     ), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'


+@pytest.mark.asyncio
+async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        status_callback=mock_status_callback,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller._notify_on_llm_retry(1, 2)
+    controller.status_callback.assert_called_once_with('info', 'STATUS$LLM_RETRY', ANY)
+    await controller.close()
+
+
 @pytest.mark.asyncio
 async def test_context_window_exceeded_error_handling(mock_agent, mock_event_stream):
     """Test that context window exceeded errors are handled correctly by truncating history."""
tests/unit/test_session.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+from unittest.mock import ANY, AsyncMock, patch
+
+import pytest
+from litellm.exceptions import (
+    RateLimitError,
+)
+
+from openhands.core.config.app_config import AppConfig
+from openhands.core.config.llm_config import LLMConfig
+from openhands.server.session.session import Session
+from openhands.storage.memory import InMemoryFileStore
+
+
+@pytest.fixture
+def mock_status_callback():
+    return AsyncMock()
+
+
+@pytest.fixture
+def mock_sio():
+    return AsyncMock()
+
+
+@pytest.fixture
+def default_llm_config():
+    return LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+
+@pytest.mark.asyncio
+@patch('openhands.llm.llm.litellm_completion')
+async def test_notify_on_llm_retry(
+    mock_litellm_completion, mock_sio, default_llm_config
+):
+    config = AppConfig()
+    config.set_llm_config(default_llm_config)
+    session = Session(
+        sid='..sid..',
+        file_store=InMemoryFileStore({}),
+        config=config,
+        sio=mock_sio,
+        user_id='..uid..',
+    )
+    session.queue_status_message = AsyncMock()
+
+    with patch('time.sleep') as _mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+        llm = session._create_llm('..cls..')
+
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+        assert mock_litellm_completion.call_count == 2
+        session.queue_status_message.assert_called_once_with(
+            'info', 'STATUS$LLM_RETRY', ANY
+        )
+    await session.close()
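The new test can be run in isolation with pytest tests/unit/test_session.py. Patching time.sleep matters because synchronous tenacity sleeps between attempts; with retry_min_wait=1 and retry_max_wait=2 from the fixture, an unpatched run would actually wait out the backoff.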