Reset a failed tool call (#5666)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Engel Nyst 2024-12-31 21:21:32 +01:00 committed by GitHub
parent 7ae1f768fc
commit a2e9e206e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 176 additions and 14 deletions

View File

@ -482,18 +482,7 @@ class CodeActAgent(Agent):
if message:
if message.role == 'user':
self.prompt_manager.enhance_message(message)
# handle error if the message is the SAME role as the previous message
# litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'}
# there shouldn't be two consecutive messages from the same role
# NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id
if (
messages
and messages[-1].role == message.role
and message.role != 'tool'
):
messages[-1].content.extend(message.content)
else:
messages.append(message)
messages.append(message)
if self.llm.is_caching_prompt_active():
# NOTE: this is only needed for anthropic

View File

@ -335,6 +335,28 @@ class AgentController:
def _reset(self) -> None:
    """Reset the controller, closing out any unanswered tool call first.

    If an action is still pending and exposes ``tool_call_metadata``, the
    history is searched for an ``Observation`` carrying the same metadata.
    When none exists, an ``ErrorObservation`` with that metadata is added to
    the event stream so the pending action found in history is paired with
    a tool result. Afterwards the pending action is cleared and the agent is
    reset.
    """
    pending = self._pending_action
    if pending and hasattr(pending, 'tool_call_metadata'):
        metadata = pending.tool_call_metadata

        # Is there already an observation answering this exact tool call?
        already_answered = any(
            isinstance(event, Observation)
            and event.tool_call_metadata == metadata
            for event in self.state.history
        )

        if not already_answered:
            # Synthesize the missing tool result so the agent can recognize
            # the action/observation pair.
            error_obs = ErrorObservation(content='The action has not been executed.')
            error_obs.tool_call_metadata = metadata
            error_obs._cause = pending.id  # type: ignore[attr-defined]
            self.event_stream.add_event(error_obs, EventSource.AGENT)

    # Clear the pending action; this runs when the agent is STOPPED or ERROR.
    self._pending_action = None
    self.agent.reset()

View File

@ -13,8 +13,8 @@ with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ModelInfo, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
@ -246,7 +246,9 @@ class LLM(RetryMixin, DebugMixin):
resp.choices[0].message = fn_call_response_message
message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
'message'
].get('tool_calls', [])
if tool_calls:
for tool_call in tool_calls:
fn_name = tool_call.function.name

View File

@ -387,3 +387,152 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
# In headless mode, throttling results in an error
assert controller.state.agent_state == AgentState.ERROR
await controller.close()
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
    """_reset() emits an ErrorObservation for a pending tool call that has no observation."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Arrange: an in-flight action that carries tool call metadata.
    action = CmdRunAction(command='test')
    action.tool_call_metadata = {
        'function': 'test_function',
        'args': {'arg1': 'value1'},
    }
    controller._pending_action = action

    # Act
    controller._reset()

    # Assert: exactly one event was pushed, and it is the synthesized error.
    mock_event_stream.add_event.assert_called_once()
    emitted_obs, emitted_source = mock_event_stream.add_event.call_args[0]
    assert isinstance(emitted_obs, ErrorObservation)
    assert emitted_obs.content == 'The action has not been executed.'
    assert emitted_obs.tool_call_metadata == action.tool_call_metadata
    assert emitted_obs._cause == action.id
    assert emitted_source == EventSource.AGENT

    # The pending action is cleared and the agent itself was reset once.
    assert controller._pending_action is None
    mock_agent.reset.assert_called_once()

    await controller.close()
@pytest.mark.asyncio
async def test_reset_with_pending_action_existing_observation(
    mock_agent, mock_event_stream
):
    """_reset() adds nothing when history already holds an observation for the pending tool call."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Arrange: an in-flight action that carries tool call metadata.
    action = CmdRunAction(command='test')
    action.tool_call_metadata = {
        'function': 'test_function',
        'args': {'arg1': 'value1'},
    }
    controller._pending_action = action

    # Arrange: history already contains an observation with the same metadata.
    prior_obs = ErrorObservation(content='Previous error')
    prior_obs.tool_call_metadata = action.tool_call_metadata
    controller.state.history.append(prior_obs)

    # Act
    controller._reset()

    # Assert: no extra observation was pushed to the stream.
    mock_event_stream.add_event.assert_not_called()

    # The pending action is cleared and the agent itself was reset once.
    assert controller._pending_action is None
    mock_agent.reset.assert_called_once()

    await controller.close()
@pytest.mark.asyncio
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
    """_reset() with nothing pending only resets the agent and emits no event."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Act: no pending action has been set on the controller.
    controller._reset()

    # Assert: the event stream was left untouched.
    mock_event_stream.add_event.assert_not_called()

    # There is still no pending action, and the agent was reset once.
    assert controller._pending_action is None
    mock_agent.reset.assert_called_once()

    await controller.close()
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_metadata(
    mock_agent, mock_event_stream, monkeypatch
):
    """Test reset() when there's a pending action without tool call metadata.

    The "no metadata" case is simulated by patching the builtin ``hasattr``
    so that it reports ``tool_call_metadata`` as missing for the pending
    action only; every other lookup is delegated to the real builtin.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Create a pending action without tool call metadata
    pending_action = CmdRunAction(command='test')

    # Mock hasattr to return False for tool_call_metadata
    original_hasattr = hasattr

    def mock_hasattr(obj, name):
        # Identity, not equality: the patch is process-wide while active, so
        # `==` would run arbitrary __eq__ implementations on every object
        # passed to hasattr and would also match unrelated objects that
        # merely compare equal to the pending action.
        if obj is pending_action and name == 'tool_call_metadata':
            return False
        return original_hasattr(obj, name)

    monkeypatch.setattr('builtins.hasattr', mock_hasattr)

    controller._pending_action = pending_action

    # Call reset
    controller._reset()

    # Verify that no ErrorObservation was added to the event stream
    mock_event_stream.add_event.assert_not_called()

    # Verify that pending action was reset
    assert controller._pending_action is None

    # Verify that agent.reset() was called
    mock_agent.reset.assert_called_once()

    await controller.close()