mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Reset a failed tool call (#5666)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
7ae1f768fc
commit
a2e9e206e8
@ -482,18 +482,7 @@ class CodeActAgent(Agent):
|
||||
if message:
|
||||
if message.role == 'user':
|
||||
self.prompt_manager.enhance_message(message)
|
||||
# handle error if the message is the SAME role as the previous message
|
||||
# litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'}
|
||||
# there shouldn't be two consecutive messages from the same role
|
||||
# NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id
|
||||
if (
|
||||
messages
|
||||
and messages[-1].role == message.role
|
||||
and message.role != 'tool'
|
||||
):
|
||||
messages[-1].content.extend(message.content)
|
||||
else:
|
||||
messages.append(message)
|
||||
messages.append(message)
|
||||
|
||||
if self.llm.is_caching_prompt_active():
|
||||
# NOTE: this is only needed for anthropic
|
||||
|
||||
@ -335,6 +335,28 @@ class AgentController:
|
||||
def _reset(self) -> None:
    """Reset the controller: settle any pending tool call, then reset the agent.

    If an action is still pending and carries ``tool_call_metadata``, the
    history is searched for an Observation with the same metadata. When none
    exists, an ErrorObservation carrying that metadata is added to the event
    stream, so the pending action found in history is not left incomplete
    without a tool result.
    """
    pending = self._pending_action
    if pending and hasattr(pending, 'tool_call_metadata'):
        # Is there already an observation answering this exact tool call?
        already_answered = any(
            isinstance(past_event, Observation)
            and past_event.tool_call_metadata == pending.tool_call_metadata
            for past_event in self.state.history
        )

        if not already_answered:
            # Synthesize the missing tool result and link it to the action.
            placeholder = ErrorObservation(content='The action has not been executed.')
            placeholder.tool_call_metadata = pending.tool_call_metadata
            placeholder._cause = pending.id  # type: ignore[attr-defined]
            self.event_stream.add_event(placeholder, EventSource.AGENT)

    # Drop the pending action; this runs when the agent is STOPPED or ERROR.
    self._pending_action = None
    self.agent.reset()
|
||||
|
||||
|
||||
@ -13,8 +13,8 @@ with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
import litellm
|
||||
|
||||
from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
|
||||
from litellm import Message as LiteLLMMessage
|
||||
from litellm import ModelInfo, PromptTokensDetails
|
||||
from litellm import completion as litellm_completion
|
||||
from litellm import completion_cost as litellm_completion_cost
|
||||
from litellm.exceptions import (
|
||||
@ -246,7 +246,9 @@ class LLM(RetryMixin, DebugMixin):
|
||||
resp.choices[0].message = fn_call_response_message
|
||||
|
||||
message_back: str = resp['choices'][0]['message']['content'] or ''
|
||||
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
|
||||
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
|
||||
'message'
|
||||
].get('tool_calls', [])
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
fn_name = tool_call.function.name
|
||||
|
||||
@ -387,3 +387,152 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
# In headless mode, throttling results in an error
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
    """_reset() emits an ErrorObservation for a pending tool call that has no observation yet."""
    agent_controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Arrange: a pending action that carries tool call metadata.
    action = CmdRunAction(command='test')
    action.tool_call_metadata = {
        'function': 'test_function',
        'args': {'arg1': 'value1'},
    }
    agent_controller._pending_action = action

    # Act
    agent_controller._reset()

    # Assert: exactly one event was published, and it is the synthesized error.
    mock_event_stream.add_event.assert_called_once()
    positional, _ = mock_event_stream.add_event.call_args
    published, origin = positional
    assert isinstance(published, ErrorObservation)
    assert published.content == 'The action has not been executed.'
    assert published.tool_call_metadata == action.tool_call_metadata
    assert published._cause == action.id
    assert origin == EventSource.AGENT

    # The pending action must be cleared ...
    assert agent_controller._pending_action is None

    # ... and the agent itself reset exactly once.
    mock_agent.reset.assert_called_once()
    await agent_controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_with_pending_action_existing_observation(
    mock_agent, mock_event_stream
):
    """_reset() adds nothing when history already holds an observation for the pending tool call."""
    agent_controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Arrange: a pending action that carries tool call metadata.
    action = CmdRunAction(command='test')
    action.tool_call_metadata = {
        'function': 'test_function',
        'args': {'arg1': 'value1'},
    }
    agent_controller._pending_action = action

    # Arrange: history already contains an observation with the same metadata.
    prior_obs = ErrorObservation(content='Previous error')
    prior_obs.tool_call_metadata = action.tool_call_metadata
    agent_controller.state.history.append(prior_obs)

    # Act
    agent_controller._reset()

    # Assert: no extra observation was pushed to the event stream.
    mock_event_stream.add_event.assert_not_called()

    # The pending action must be cleared ...
    assert agent_controller._pending_action is None

    # ... and the agent itself reset exactly once.
    mock_agent.reset.assert_called_once()
    await agent_controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
    """_reset() with nothing pending only resets the agent and publishes no event."""
    agent_controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Act: no pending action has been assigned on this controller.
    agent_controller._reset()

    # Assert: the event stream was left untouched.
    mock_event_stream.add_event.assert_not_called()

    # There is still no pending action ...
    assert agent_controller._pending_action is None

    # ... and the agent itself was reset exactly once.
    mock_agent.reset.assert_called_once()
    await agent_controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_metadata(
    mock_agent, mock_event_stream, monkeypatch
):
    """Test reset() when there's a pending action without tool call metadata.

    ``builtins.hasattr`` is patched so that the pending action appears to lack
    ``tool_call_metadata``; every other lookup is delegated to the real builtin.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Create a pending action without tool call metadata
    pending_action = CmdRunAction(command='test')
    # Mock hasattr to return False for tool_call_metadata
    original_hasattr = hasattr

    def mock_hasattr(obj, name):
        # Identity, not equality: this stub replaces the builtin for the whole
        # process, so `==` would run arbitrary __eq__ implementations on every
        # hasattr() call and could also match a different-but-equal object.
        if obj is pending_action and name == 'tool_call_metadata':
            return False
        return original_hasattr(obj, name)

    monkeypatch.setattr('builtins.hasattr', mock_hasattr)
    controller._pending_action = pending_action

    # Call reset
    controller._reset()

    # Verify that no ErrorObservation was added to the event stream
    mock_event_stream.add_event.assert_not_called()

    # Verify that pending action was reset
    assert controller._pending_action is None

    # Verify that agent.reset() was called
    mock_agent.reset.assert_called_once()
    await controller.close()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user