Refactor history/event stream (#3808)

This commit is contained in:
Engel Nyst 2024-11-05 03:36:14 +01:00 committed by GitHub
parent edfba4618a
commit eeb2342509
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 559 additions and 609 deletions

View File

@ -161,7 +161,7 @@ Pour créer un workflow d'évaluation pour votre benchmark, suivez ces étapes :
instruction=instruction,
test_result=evaluation_result,
metadata=metadata,
history=state.history.compatibility_for_eval_history_pairs(),
history=compatibility_for_eval_history_pairs(state.history),
metrics=state.metrics.get() if state.metrics else None,
error=state.last_error if state and state.last_error else None,
)
@ -260,7 +260,7 @@ def codeact_user_response(state: State | None) -> str:
# vérifier si l'agent a essayé de parler à l'utilisateur 3 fois, si oui, faire savoir à l'agent qu'il peut abandonner
user_msgs = [
event
for event in state.history.get_events()
for event in state.history
if isinstance(event, MessageAction) and event.source == 'user'
]
if len(user_msgs) >= 2:

View File

@ -158,7 +158,7 @@ OpenHands 的主要入口点在 `openhands/core/main.py` 中。以下是它工
instruction=instruction,
test_result=evaluation_result,
metadata=metadata,
history=state.history.compatibility_for_eval_history_pairs(),
history=compatibility_for_eval_history_pairs(state.history),
metrics=state.metrics.get() if state.metrics else None,
error=state.last_error if state and state.last_error else None,
)
@ -257,7 +257,7 @@ def codeact_user_response(state: State | None) -> str:
# 检查代理是否已尝试与用户对话 3 次,如果是,让代理知道它可以放弃
user_msgs = [
event
for event in state.history.get_events()
for event in state.history
if isinstance(event, MessageAction) and event.source == 'user'
]
if len(user_msgs) >= 2:

View File

@ -158,7 +158,7 @@ To create an evaluation workflow for your benchmark, follow these steps:
instruction=instruction,
test_result=evaluation_result,
metadata=metadata,
history=state.history.compatibility_for_eval_history_pairs(),
history=compatibility_for_eval_history_pairs(state.history),
metrics=state.metrics.get() if state.metrics else None,
error=state.last_error if state and state.last_error else None,
)
@ -257,7 +257,7 @@ def codeact_user_response(state: State | None) -> str:
# check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up
user_msgs = [
event
for event in state.history.get_events()
for event in state.history
if isinstance(event, MessageAction) and event.source == 'user'
]
if len(user_msgs) >= 2:

View File

@ -8,6 +8,7 @@ from evaluation.EDA.game import Q20Game, Q20GameCelebrity
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -34,7 +35,7 @@ def codeact_user_response_eda(state: State) -> str:
# retrieve the latest model message from history
if state.history:
model_guess = state.history.get_last_agent_message()
model_guess = state.get_last_agent_message()
assert game is not None, 'Game is not initialized.'
msg = game.generate_user_response(model_guess)
@ -139,7 +140,7 @@ def process_instance(
if state is None:
raise ValueError('State should not be None.')
final_message = state.history.get_last_agent_message()
final_message = state.get_last_agent_message()
logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
test_result = game.reward()
@ -148,7 +149,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -16,6 +16,7 @@ from evaluation.agent_bench.helper import (
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -242,7 +243,7 @@ def process_instance(
raw_ans = ''
# retrieve the last agent message or thought
for event in state.history.get_events(reverse=True):
for event in reversed(state.history):
if event.source == 'agent':
if isinstance(event, AgentFinishAction):
raw_ans = event.thought
@ -271,7 +272,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
metrics = state.metrics.get() if state.metrics else None

View File

@ -15,6 +15,7 @@ from evaluation.aider_bench.helper import (
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -250,7 +251,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
metrics = state.metrics.get() if state.metrics else None
# Save the output

View File

@ -13,6 +13,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -299,7 +300,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
test_result['generated'] = test_result['metadata']['1_copy_change_code']

View File

@ -16,6 +16,7 @@ from tqdm import tqdm
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -46,7 +47,7 @@ def codeact_user_response(state: State) -> str:
# check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up
user_msgs = [
event
for event in state.history.get_events()
for event in state.history
if isinstance(event, MessageAction) and event.source == 'user'
]
if len(user_msgs) > 2:
@ -431,7 +432,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -9,6 +9,7 @@ from datasets import load_dataset
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -89,7 +90,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# find the last delegate action
last_delegate_action = None

View File

@ -15,6 +15,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -173,14 +174,14 @@ def initialize_runtime(runtime: Runtime, data_files: list[str]):
def get_last_agent_finish_action(state: State) -> AgentFinishAction:
for event in state.history.get_events(reverse=True):
for event in reversed(state.history):
if isinstance(event, AgentFinishAction):
return event
return None
def get_last_message_action(state: State) -> MessageAction:
for event in state.history.get_events(reverse=True):
for event in reversed(state.history):
if isinstance(event, MessageAction):
return event
return None
@ -307,7 +308,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# DiscoveryBench Evaluation
eval_rec = run_eval_gold_vs_gen_NL_hypo_workflow(

View File

@ -12,6 +12,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -166,7 +167,7 @@ def process_instance(
model_answer_raw = ''
# get the last message or thought from the agent
for event in state.history.get_events(reverse=True):
for event in reversed(state.history):
if event.source == 'agent':
if isinstance(event, AgentFinishAction):
model_answer_raw = event.thought
@ -203,7 +204,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -10,6 +10,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -101,7 +102,7 @@ def process_instance(
raise ValueError('State should not be None.')
# retrieve the last message from the agent
model_answer_raw = state.history.get_last_agent_message()
model_answer_raw = state.get_last_agent_message()
# attempt to parse model_answer
ast_eval_fn = instance['ast_eval']
@ -114,7 +115,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
output = EvalOutput(
instance_id=instance_id,

View File

@ -28,6 +28,7 @@ from datasets import load_dataset
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -244,7 +245,7 @@ Ok now its time to start solving the question. Good luck!
'C': False,
'D': False,
}
for event in state.history.get_events(reverse=True):
for event in reversed(state.history):
if (
isinstance(event, AgentFinishAction)
and event.source != 'user'
@ -300,7 +301,7 @@ Ok now its time to start solving the question. Good luck!
instance_id=str(instance.instance_id),
instruction=instruction,
metadata=metadata,
history=state.history.compatibility_for_eval_history_pairs(),
history=compatibility_for_eval_history_pairs(state.history),
metrics=metrics,
error=state.last_error if state and state.last_error else None,
test_result={

View File

@ -21,6 +21,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -255,7 +256,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -129,7 +129,7 @@ def process_instance(
# # result evaluation
# # =============================================
histories = [event_to_dict(event) for event in state.history.get_events()]
histories = [event_to_dict(event) for event in state.history]
test_result: TestResult = test_class.verify_result(runtime, histories)
metrics = state.metrics.get() if state.metrics else None

View File

@ -8,6 +8,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -225,7 +226,7 @@ def process_instance(
raise ValueError('State should not be None.')
final_message = ''
for event in state.history.get_events(reverse=True):
for event in reversed(state.history):
if isinstance(event, AgentFinishAction):
final_message = event.thought
break
@ -247,7 +248,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -11,6 +11,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -182,7 +183,7 @@ def process_instance(
# Instruction is the first message from the USER
instruction = ''
for event in state.history.get_events():
for event in state.history:
if isinstance(event, MessageAction):
instruction = event.content
break
@ -194,7 +195,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -13,6 +13,7 @@ from evaluation.mint.tasks import Task
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -28,6 +29,7 @@ from openhands.core.config import (
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import (
Action,
CmdRunAction,
MessageAction,
)
@ -45,7 +47,10 @@ def codeact_user_response_mint(state: State, task: Task, task_config: dict[str,
task=task,
task_config=task_config,
)
last_action = state.history.get_last_action()
last_action = next(
(event for event in reversed(state.history) if isinstance(event, Action)),
None,
)
result_state: TaskState = env.step(last_action.message or '')
state.extra_data['task_state'] = result_state
@ -202,7 +207,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -24,6 +24,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -256,7 +257,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -10,6 +10,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -232,7 +233,7 @@ If the program uses some packages that are incompatible, please figure out alter
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -443,7 +443,8 @@ def process_instance(
if state is None:
raise ValueError('State should not be None.')
histories = [event_to_dict(event) for event in state.history.get_events()]
# NOTE: this is NO LONGER the event stream, but an agent history that includes delegate agent's events
histories = [event_to_dict(event) for event in state.history]
metrics = state.metrics.get() if state.metrics else None
# Save the output

View File

@ -9,6 +9,7 @@ from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -126,7 +127,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
raise ValueError('State should not be None.')
# retrieve the last message from the agent
model_answer_raw = state.history.get_last_agent_message()
model_answer_raw = state.get_last_agent_message()
# attempt to parse model_answer
correct = eval_answer(str(model_answer_raw), str(answer))
@ -137,7 +138,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -18,6 +18,9 @@ from openhands.core.logger import get_console_handler
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import Action
from openhands.events.action.message import MessageAction
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_dict
from openhands.events.utils import get_pairs_from_events
class EvalMetadata(BaseModel):
@ -112,7 +115,14 @@ def codeact_user_response(
if state.history:
# check if the last action has an answer, if so, early exit
if try_parse is not None:
last_action = state.history.get_last_action()
last_action = next(
(
event
for event in reversed(state.history)
if isinstance(event, Action)
),
None,
)
ans = try_parse(last_action)
if ans is not None:
return '/exit'
@ -120,7 +130,7 @@ def codeact_user_response(
# check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up
user_msgs = [
event
for event in state.history.get_events()
for event in state.history
if isinstance(event, MessageAction) and event.source == 'user'
]
if len(user_msgs) >= 2:
@ -428,3 +438,18 @@ def update_llm_config_for_completions_logging(
f'{llm_config.log_completions_folder}'
)
return llm_config
# history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
# we rebuild the pairs here
# for compatibility with the existing output format in evaluations
# remove this when it's no longer necessary
def compatibility_for_eval_history_pairs(
history: list[Event],
) -> list[tuple[dict, dict]]:
history_pairs = []
for action, observation in get_pairs_from_events(history):
history_pairs.append((event_to_dict(action), event_to_dict(observation)))
return history_pairs

View File

@ -10,6 +10,7 @@ import pandas as pd
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
compatibility_for_eval_history_pairs,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@ -166,7 +167,7 @@ def process_instance(
# Instruction is the first message from the USER
instruction = ''
for event in state.history.get_events():
for event in state.history:
if isinstance(event, MessageAction):
instruction = event.content
break
@ -178,7 +179,7 @@ def process_instance(
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = compatibility_for_eval_history_pairs(state.history)
# Save the output
output = EvalOutput(

View File

@ -150,13 +150,13 @@ class BrowsingAgent(Agent):
last_obs = None
last_action = None
if EVAL_MODE and len(state.history.get_events_as_list()) == 1:
if EVAL_MODE and len(state.history) == 1:
# for webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
# initialize and retrieve the first observation by issuing an noop OP
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
return BrowseInteractiveAction(browser_actions='noop()')
for event in state.history.get_events():
for event in state.history:
if isinstance(event, BrowseInteractiveAction):
prev_actions.append(event.browser_actions)
last_action = event

View File

@ -337,8 +337,8 @@ class CodeActAgent(Agent):
return self.pending_actions.popleft()
# if we're done, go back
latest_user_message = state.history.get_last_user_message()
if latest_user_message and latest_user_message.strip() == '/exit':
last_user_message = state.get_last_user_message()
if last_user_message and last_user_message.strip() == '/exit':
return AgentFinishAction()
# prepare what we want to send to the LLM
@ -419,7 +419,7 @@ class CodeActAgent(Agent):
pending_tool_call_action_messages: dict[str, Message] = {}
tool_call_id_to_message: dict[str, Message] = {}
events = list(state.history.get_events())
events = list(state.history)
for event in events:
# create a regular message from an event
if isinstance(event, Action):

View File

@ -154,8 +154,8 @@ class CodeActSWEAgent(Agent):
- AgentFinishAction() - end the interaction
"""
# if we're done, go back
latest_user_message = state.history.get_last_user_message()
if latest_user_message and latest_user_message.strip() == '/exit':
last_user_message = state.get_last_user_message()
if last_user_message and last_user_message.strip() == '/exit':
return AgentFinishAction()
# prepare what we want to send to the LLM
@ -176,7 +176,7 @@ class CodeActSWEAgent(Agent):
Message(role='user', content=[TextContent(text=self.in_context_example)]),
]
for event in state.history.get_events():
for event in state.history:
# create a regular message from an event
if isinstance(event, Action):
message = self.get_action_message(event)

View File

@ -2,7 +2,7 @@ from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction
from openhands.events.observation import AgentDelegateObservation
from openhands.events.observation import AgentDelegateObservation, Observation
from openhands.llm.llm import LLM
@ -41,7 +41,11 @@ class DelegatorAgent(Agent):
)
# last observation in history should be from the delegate
last_observation = state.history.get_last_observation()
last_observation = None
for event in reversed(state.history):
if isinstance(event, Observation):
last_observation = event
break
if not isinstance(last_observation, AgentDelegateObservation):
raise Exception('Last observation is not an AgentDelegateObservation')

View File

@ -164,7 +164,7 @@ class DummyAgent(Agent):
if 'observations' in prev_step and prev_step['observations']:
expected_observations = prev_step['observations']
hist_events = state.history.get_last_events(len(expected_observations))
hist_events = state.history[-len(expected_observations) :]
if len(hist_events) < len(expected_observations):
print(

View File

@ -8,10 +8,10 @@ from openhands.core.config import AgentConfig
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.utils import json
from openhands.events.action import Action
from openhands.events.event import Event
from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.event import event_to_memory
from openhands.llm.llm import LLM
from openhands.memory.history import ShortTermHistory
def parse_response(orig_response: str) -> Action:
@ -32,16 +32,14 @@ class MicroAgent(Agent):
prompt = ''
agent_definition: dict = {}
def history_to_json(
self, history: ShortTermHistory, max_events: int = 20, **kwargs
):
def history_to_json(self, history: list[Event], max_events: int = 20, **kwargs):
"""
Serialize and simplify history to str format
"""
processed_history = []
event_count = 0
for event in history.get_events(reverse=True):
for event in reversed(history):
if event_count >= max_events:
break
processed_history.append(

View File

@ -117,7 +117,7 @@ def get_hint(latest_action_id: str) -> str:
def get_prompt_and_images(
state: State, max_message_chars: int
) -> tuple[str, list[str]]:
) -> tuple[str, list[str] | None]:
"""Gets the prompt for the planner agent.
Formatted with the most recent action-observation pairs, current task, and hint based on last action
@ -136,7 +136,7 @@ def get_prompt_and_images(
latest_action: Action = NullAction()
# retrieve the latest HISTORY_SIZE events
for event_count, event in enumerate(state.history.get_events(reverse=True)):
for event_count, event in enumerate(reversed(state.history)):
if event_count >= HISTORY_SIZE:
break
if latest_action == NullAction() and isinstance(event, Action):

View File

@ -1,7 +1,7 @@
import asyncio
import copy
import traceback
from typing import Callable, Type
from typing import Callable, ClassVar, Type
import litellm
@ -36,6 +36,7 @@ from openhands.events.observation import (
AgentDelegateObservation,
AgentStateChangedObservation,
ErrorObservation,
NullObservation,
Observation,
)
from openhands.events.serialization.event import truncate_content
@ -61,6 +62,12 @@ class AgentController:
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action: Action | None = None
filter_out: ClassVar[tuple[type[Event], ...]] = (
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
)
def __init__(
self,
@ -121,8 +128,34 @@ class AgentController:
self.status_callback = status_callback
async def close(self):
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
Note that it's fairly important that this closes properly, otherwise the state is incomplete."""
await self.set_agent_state_to(AgentState.STOPPED)
# we made history, now is the time to rewrite it!
# the final state.history will be used by external scripts like evals, tests, etc.
# history will need to be complete WITH delegates events
# like the regular agent history, it does not include:
# - 'hidden' events, events with hidden=True
# - backend events (the default 'filtered out' types, types in self.filter_out)
start_id = self.state.start_id if self.state.start_id >= 0 else 0
end_id = (
self.state.end_id
if self.state.end_id >= 0
else self.event_stream.get_latest_event_id()
)
self.state.history = list(
self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
)
# unsubscribe from the event stream
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
def log(self, level: str, message: str, extra: dict | None = None):
@ -178,6 +211,11 @@ class AgentController:
"""
if hasattr(event, 'hidden') and event.hidden:
return
# if the event is not filtered out, add it to the history
if not any(isinstance(event, filter_type) for filter_type in self.filter_out):
self.state.history.append(event)
if isinstance(event, Action):
await self._handle_action(event)
elif isinstance(event, Observation):
@ -233,9 +271,6 @@ class AgentController:
if self.state.agent_state == AgentState.USER_REJECTED:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
return
if isinstance(observation, AgentDelegateObservation):
self.state.history.on_event(observation)
elif isinstance(observation, ErrorObservation):
if self.state.agent_state == AgentState.ERROR:
self.state.metrics.merge(self.state.local_metrics)
@ -362,6 +397,8 @@ class AgentController:
delegate_level=self.state.delegate_level + 1,
# global metrics should be shared between parent and child
metrics=self.state.metrics,
# start on top of the stream
start_id=self.event_stream.get_latest_event_id() + 1,
)
self.log(
'debug',
@ -463,9 +500,7 @@ class AgentController:
async def _delegate_step(self):
"""Executes a single step of the delegate agent."""
self.log('debug', 'Delegate not none, awaiting...')
await self.delegate._step() # type: ignore[union-attr]
self.log('debug', 'Delegate step done')
assert self.delegate is not None
delegate_state = self.delegate.get_agent_state()
self.log('debug', f'Delegate state: {delegate_state}')
@ -473,7 +508,7 @@ class AgentController:
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration
# emit AgentDelegateObservation when the delegate terminates due to error
# emit AgentDelegateObservation to mark delegate termination due to error
delegate_outputs = (
self.delegate.state.outputs if self.delegate.state else {}
)
@ -488,10 +523,6 @@ class AgentController:
self.delegate = None
self.delegateAction = None
self.event_stream.add_event(
ErrorObservation('Delegate agent encountered an error'),
EventSource.AGENT,
)
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
self.log('debug', 'Delegate agent has finished execution')
# retrieve delegate result
@ -574,8 +605,10 @@ class AgentController:
max_iterations: The maximum number of iterations allowed for the task.
confirmation_mode: Whether to enable confirmation mode.
"""
# state from the previous session, state from a parent agent, or a new state
# note that this is called twice when restoring a previous session, first with state=None
# state can come from:
# - the previous session, in which case it has history
# - from a parent agent, in which case it has no history
# - None / a new state
if state is None:
self.state = State(
inputs={},
@ -585,27 +618,109 @@ class AgentController:
else:
self.state = state
# when restored from a previous session, the State object will have history, start_id, and end_id
# connect it to the event stream
self.state.history.set_event_stream(self.event_stream)
if self.state.start_id <= -1:
self.state.start_id = 0
# if start_id was not set in State, we're starting fresh, at the top of the stream
start_id = self.state.start_id
if start_id == -1:
start_id = self.event_stream.get_latest_event_id() + 1
else:
self.log(
'debug', f'AgentController {self.id} restoring from event {start_id}'
'debug',
f'AgentController {self.id} initializing history from event {self.state.start_id}',
)
self._init_history()
def _init_history(self):
"""Initializes the agent's history from the event stream.
The history is a list of events that:
- Excludes events of types listed in self.filter_out
- Excludes events with hidden=True attribute
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
- Excludes all events between the action and observation
- Includes the delegate action and observation themselves
"""
# define range of events to fetch
# delegates start with a start_id and initially won't find any events
# otherwise we're restoring a previous session
start_id = self.state.start_id if self.state.start_id >= 0 else 0
end_id = (
self.state.end_id
if self.state.end_id >= 0
else self.event_stream.get_latest_event_id()
)
# sanity check
if start_id > end_id + 1:
self.log(
'debug',
f'start_id {start_id} is greater than end_id + 1 ({end_id + 1}). History will be empty.',
)
self.state.history = []
return
# Get all events, filtering out backend events and hidden events
events = list(
self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
)
# Find all delegate action/observation pairs
delegate_ranges: list[tuple[int, int]] = []
delegate_action_ids: list[int] = [] # stack of unmatched delegate action IDs
for event in events:
if isinstance(event, AgentDelegateAction):
delegate_action_ids.append(event.id)
# Note: we can get agent=event.agent and task=event.inputs.get('task','')
# if we need to track these in the future
elif isinstance(event, AgentDelegateObservation):
# Match with most recent unmatched delegate action
if not delegate_action_ids:
self.log(
'error',
f'Found AgentDelegateObservation without matching action at id={event.id}',
)
continue
action_id = delegate_action_ids.pop()
delegate_ranges.append((action_id, event.id))
# Filter out events between delegate action/observation pairs
if delegate_ranges:
filtered_events: list[Event] = []
current_idx = 0
for start_id, end_id in sorted(delegate_ranges):
# Add events before delegate range
filtered_events.extend(
event for event in events[current_idx:] if event.id < start_id
)
# Add delegate action and observation
filtered_events.extend(
event for event in events if event.id in (start_id, end_id)
)
# Update index to after delegate range
current_idx = next(
(i for i, e in enumerate(events) if e.id > end_id), len(events)
)
# Add any remaining events after last delegate range
filtered_events.extend(events[current_idx:])
self.state.history = filtered_events
else:
self.state.history = events
# make sure history is in sync
self.state.start_id = start_id
self.state.history.start_id = start_id
# if there was an end_id saved in State, set it in history
# currently not used, later useful for delegates
if self.state.end_id > -1:
self.state.history.end_id = self.state.end_id
def _is_stuck(self):
"""Checks if the agent or its delegate is stuck in a loop.

View File

@ -11,9 +11,8 @@ from openhands.events.action import (
MessageAction,
)
from openhands.events.action.agent import AgentFinishAction
from openhands.events.observation import ErrorObservation
from openhands.events.event import Event, EventSource
from openhands.llm.metrics import Metrics
from openhands.memory.history import ShortTermHistory
from openhands.storage.files import FileStore
@ -78,7 +77,7 @@ class State:
# max number of iterations for the current task
max_iterations: int = 100
confirmation_mode: bool = False
history: ShortTermHistory = field(default_factory=ShortTermHistory)
history: list[Event] = field(default_factory=list)
inputs: dict = field(default_factory=dict)
outputs: dict = field(default_factory=dict)
agent_state: AgentState = AgentState.LOADING
@ -94,6 +93,7 @@ class State:
start_id: int = -1
end_id: int = -1
almost_stuck: int = 0
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different
# evaluation tasks to store extra data needed to track the progress/state of the task.
extra_data: dict[str, Any] = field(default_factory=dict)
@ -116,7 +116,7 @@ class State:
pickled = base64.b64decode(encoded)
state = pickle.loads(pickled)
except Exception as e:
logger.warning(f'Failed to restore state from session: {e}')
logger.warning(f'Could not restore state from session: {e}')
raise e
# update state
@ -130,39 +130,40 @@ class State:
return state
def __getstate__(self):
# don't pickle history, it will be restored from the event stream
state = self.__dict__.copy()
# save the relevant data from recent history
# so that we can restore it when the state is restored
if 'history' in state:
state['start_id'] = state['history'].start_id
state['end_id'] = state['history'].end_id
# don't save history object itself
state.pop('history', None)
state['history'] = []
return state
def __setstate__(self, state):
self.__dict__.update(state)
# recreate the history object
# make sure we always have the attribute history
if not hasattr(self, 'history'):
self.history = ShortTermHistory()
self.history = []
self.history.start_id = self.start_id
self.history.end_id = self.end_id
def get_current_user_intent(self):
def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
last_user_message = None
last_user_message_image_urls: list[str] | None = []
for event in self.history.get_events(reverse=True):
for event in reversed(self.history):
if isinstance(event, MessageAction) and event.source == 'user':
last_user_message = event.content
last_user_message_image_urls = event.images_urls
elif isinstance(event, AgentFinishAction):
if last_user_message is not None:
return last_user_message
return last_user_message, None
return last_user_message, last_user_message_image_urls
def get_last_agent_message(self) -> str | None:
for event in reversed(self.history):
if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
return event.content
return None
def get_last_user_message(self) -> str | None:
for event in reversed(self.history):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
return event.content
return None

View File

@ -28,7 +28,7 @@ class StuckDetector:
# filter out MessageAction with source='user' from history
filtered_history = [
event
for event in self.state.history.get_events()
for event in self.state.history
if not (
(isinstance(event, MessageAction) and event.source == EventSource.USER)
or

View File

@ -38,7 +38,6 @@ class AppConfig:
e2b_api_key: The E2B API key.
disable_color: Whether to disable color. For terminals that don't support color.
debug: Whether to enable debugging.
enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
@ -67,7 +66,6 @@ class AppConfig:
disable_color: bool = False
jwt_secret: str = uuid.uuid4().hex
debug: bool = False
enable_cli_session: bool = False
file_uploads_max_file_size_mb: int = 0
file_uploads_restrict_file_types: bool = False
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])

View File

@ -125,16 +125,18 @@ async def run_controller(
runtime = create_runtime(config, sid=sid)
event_stream = runtime.event_stream
# restore cli session if enabled
# restore cli session if available
initial_state = None
if config.enable_cli_session:
try:
logger.debug(f'Restoring agent state from cli session {event_stream.sid}')
initial_state = State.restore_from_session(
event_stream.sid, event_stream.file_store
)
except Exception as e:
logger.debug(f'Error restoring state: {e}')
try:
logger.debug(
f'Trying to restore agent state from cli session {event_stream.sid} if available'
)
initial_state = State.restore_from_session(
event_stream.sid, event_stream.file_store
)
except Exception as e:
logger.debug(f'Cannot restore agent state: {e}')
# init controller with this initial state
controller = AgentController(
@ -157,7 +159,7 @@ async def run_controller(
)
# start event is a MessageAction with the task, either resumed or new
if config.enable_cli_session and initial_state is not None:
if initial_state is not None:
# we're resuming the previous session
event_stream.add_event(
MessageAction(
@ -168,7 +170,7 @@ async def run_controller(
),
EventSource.USER,
)
elif initial_state is None:
else:
# init with the provided actions
event_stream.add_event(initial_user_action, EventSource.USER)
@ -202,8 +204,9 @@ async def run_controller(
logger.error(f'Exception in main loop: {e}')
# save session when we're about to close
if config.enable_cli_session:
if config.file_store is not None and config.file_store != 'memory':
end_state = controller.get_state()
# NOTE: the saved state does not include delegates events
end_state.save_to_session(event_stream.sid, event_stream.file_store)
state = controller.get_state()
@ -212,10 +215,7 @@ async def run_controller(
if config.trajectories_path is not None:
file_path = os.path.join(config.trajectories_path, sid + '.json')
os.makedirs(os.path.dirname(file_path), exist_ok=True)
histories = [
event_to_trajectory(event)
for event in state.history.get_events(include_delegates=True)
]
histories = [event_to_trajectory(event) for event in state.history]
with open(file_path, 'w') as f:
json.dump(histories, f)

View File

@ -7,7 +7,7 @@ from openhands.events.action.action import Action, ActionSecurityRisk
@dataclass
class MessageAction(Action):
content: str
images_urls: list | None = None
images_urls: list[str] | None = None
wait_for_response: bool = False
action: str = ActionType.MESSAGE
security_risk: ActionSecurityRisk | None = None

View File

@ -83,12 +83,27 @@ class EventStream:
def get_events(
self,
start_id=0,
end_id=None,
reverse=False,
start_id: int = 0,
end_id: int | None = None,
reverse: bool = False,
filter_out_type: tuple[type[Event], ...] | None = None,
filter_hidden=False,
) -> Iterable[Event]:
"""
Retrieve events from the event stream, optionally filtering out events of a given type
and events marked as hidden.
Args:
start_id: The ID of the first event to retrieve. Defaults to 0.
end_id: The ID of the last event to retrieve. Defaults to the last event in the stream.
reverse: Whether to retrieve events in reverse order. Defaults to False.
filter_out_type: A tuple of event types to filter out. Typically used to filter out backend events from the agent.
filter_hidden: If True, filters out events with the 'hidden' attribute set to True.
Yields:
Events from the stream that match the criteria.
"""
def should_filter(event: Event):
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
return True

View File

@ -1,5 +1,4 @@
from openhands.memory.condenser import MemoryCondenser
from openhands.memory.history import ShortTermHistory
from openhands.memory.memory import LongTermMemory
__all__ = ['LongTermMemory', 'ShortTermHistory', 'MemoryCondenser']
__all__ = ['LongTermMemory', 'MemoryCondenser']

View File

@ -1,224 +0,0 @@
from typing import ClassVar, Iterable
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.agent import (
AgentDelegateAction,
ChangeAgentStateAction,
)
from openhands.events.action.empty import NullAction
from openhands.events.action.message import MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import event_to_dict
from openhands.events.stream import EventStream
from openhands.events.utils import get_pairs_from_events
class ShortTermHistory(list[Event]):
"""A list of events that represents the short-term memory of the agent.
This class provides methods to retrieve and filter the events in the history of the running agent from the event stream.
"""
start_id: int
end_id: int
_event_stream: EventStream
delegates: dict[tuple[int, int], tuple[str, str]]
filter_out: ClassVar[tuple[type[Event], ...]] = (
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
)
def __init__(self):
super().__init__()
self.start_id = -1
self.end_id = -1
self.delegates = {}
def set_event_stream(self, event_stream: EventStream):
self._event_stream = event_stream
def get_events_as_list(self, include_delegates: bool = False) -> list[Event]:
"""Return the history as a list of Event objects."""
return list(self.get_events(include_delegates=include_delegates))
def get_events(
self,
reverse: bool = False,
include_delegates: bool = False,
include_hidden=False,
) -> Iterable[Event]:
"""Return the events as a stream of Event objects."""
# TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
# or even if it is, because currently we don't add it to the summary
# iterate from start_id to end_id, or reverse
start_id = self.start_id if self.start_id != -1 else 0
end_id = (
self.end_id
if self.end_id != -1
else self._event_stream.get_latest_event_id()
)
for event in self._event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=reverse,
filter_out_type=self.filter_out,
):
if not include_hidden and hasattr(event, 'hidden') and event.hidden:
continue
# TODO add summaries
# and filter out events that were included in a summary
# filter out the events from a delegate of the current agent
if not include_delegates and not any(
# except for the delegate action and observation themselves, currently
# AgentDelegateAction has id = delegate_start
# AgentDelegateObservation has id = delegate_end
delegate_start < event.id < delegate_end
for delegate_start, delegate_end in self.delegates.keys()
):
yield event
elif include_delegates:
yield event
def get_last_action(self, end_id: int = -1) -> Action | None:
"""Return the last action from the event stream, filtered to exclude unwanted events."""
# from end_id in reverse, find the first action
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
last_action = next(
(
event
for event in self._event_stream.get_events(
end_id=end_id, reverse=True, filter_out_type=self.filter_out
)
if isinstance(event, Action)
),
None,
)
return last_action
def get_last_observation(self, end_id: int = -1) -> Observation | None:
"""Return the last observation from the event stream, filtered to exclude unwanted events."""
# from end_id in reverse, find the first observation
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
last_observation = next(
(
event
for event in self._event_stream.get_events(
end_id=end_id, reverse=True, filter_out_type=self.filter_out
)
if isinstance(event, Observation)
),
None,
)
return last_observation
def get_last_user_message(self) -> str:
"""Return the content of the last user message from the event stream."""
last_user_message = next(
(
event.content
for event in self._event_stream.get_events(reverse=True)
if isinstance(event, MessageAction) and event.source == EventSource.USER
),
None,
)
return last_user_message if last_user_message is not None else ''
def get_last_agent_message(self) -> str:
"""Return the content of the last agent message from the event stream."""
last_agent_message = next(
(
event.content
for event in self._event_stream.get_events(reverse=True)
if isinstance(event, MessageAction)
and event.source == EventSource.AGENT
),
None,
)
return last_agent_message if last_agent_message is not None else ''
def get_last_events(self, n: int) -> list[Event]:
"""Return the last n events from the event stream."""
# dummy agent is using this
# it should work, but it's not great to store temporary lists now just for a test
end_id = self._event_stream.get_latest_event_id()
start_id = max(0, end_id - n + 1)
return list(
event
for event in self._event_stream.get_events(
start_id=start_id,
end_id=end_id,
filter_out_type=self.filter_out,
)
)
def has_delegation(self) -> bool:
for event in self._event_stream.get_events():
if isinstance(event, AgentDelegateObservation):
return True
return False
def on_event(self, event: Event):
if not isinstance(event, AgentDelegateObservation):
return
logger.debug('AgentDelegateObservation received')
# figure out what this delegate's actions were
# from the last AgentDelegateAction to this AgentDelegateObservation
# and save their ids as start and end ids
# in order to use later to exclude them from parent stream
# or summarize them
delegate_end = event.id
delegate_start = -1
delegate_agent: str = ''
delegate_task: str = ''
for prev_event in self._event_stream.get_events(
end_id=event.id - 1, reverse=True
):
if isinstance(prev_event, AgentDelegateAction):
delegate_start = prev_event.id
delegate_agent = prev_event.agent
delegate_task = prev_event.inputs.get('task', '')
break
if delegate_start == -1:
logger.error(
f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
)
return
self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
logger.debug(
f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
)
# TODO remove me when unnecessary
# history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
# we rebuild the pairs here
# for compatibility with the existing output format in evaluations
def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
history_pairs = []
for action, observation in get_pairs_from_events(
self.get_events_as_list(include_delegates=True)
):
history_pairs.append((event_to_dict(action), event_to_dict(observation)))
return history_pairs

View File

@ -213,7 +213,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
if isinstance(obs, ErrorObservation):
return obs
if not isinstance(obs, FileWriteObservation):
raise ValueError(f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}')
raise ValueError(
f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
)
return FileEditObservation(
content=get_diff('', action.content, action.path),
path=action.path,
@ -222,7 +224,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
new_content=action.content,
)
if not isinstance(obs, FileReadObservation):
raise ValueError(f'Expected FileReadObservation, got {type(obs)}: {str(obs)}')
raise ValueError(
f'Expected FileReadObservation, got {type(obs)}: {str(obs)}'
)
original_file_content = obs.content
old_file_lines = original_file_content.split('\n')

View File

@ -181,7 +181,7 @@ def process_instance(
test_result = {}
if state is None:
raise ValueError('State should not be None.')
histories = [event_to_dict(event) for event in state.history.get_events()]
histories = [event_to_dict(event) for event in state.history]
metrics = state.metrics.get() if state.metrics else None
# Save the output

View File

@ -104,5 +104,5 @@ def test_error_observation_message(agent: CodeActAgent):
def test_unknown_observation_message(agent: CodeActAgent):
obs = Mock()
with pytest.raises(ValueError, match='Unknown observation type:'):
with pytest.raises(ValueError, match='Unknown observation type'):
agent.get_observation_message(obs, tool_call_id_to_message={})

View File

@ -17,8 +17,6 @@ from openhands.events.observation.commands import IPythonRunCellObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.stream import EventSource, EventStream
from openhands.events.utils import get_pairs_from_events
from openhands.memory.history import ShortTermHistory
from openhands.storage import get_file_store
@ -55,22 +53,21 @@ def event_stream(temp_dir):
class TestStuckDetector:
@pytest.fixture
def stuck_detector(self, event_stream):
def stuck_detector(self):
state = State(inputs={}, max_iterations=50)
state.history.set_event_stream(event_stream)
state.history = [] # Initialize history as an empty list
return StuckDetector(state)
def _impl_syntax_error_events(
self,
event_stream: EventStream,
state: State,
error_message: str,
random_line: bool,
incidents: int = 4,
):
for i in range(incidents):
ipython_action = IPythonRunCellAction(code=code_snippet)
event_stream.add_event(ipython_action, EventSource.AGENT)
state.history.append(ipython_action)
extra_number = (i + 1) * 10 if random_line else '42'
extra_line = '\n' * (i + 1) if random_line else ''
ipython_observation = IPythonRunCellObservation(
@ -79,15 +76,15 @@ class TestStuckDetector:
f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
code=code_snippet,
)
ipython_observation._cause = ipython_action._id
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
# ipython_observation._cause = ipython_action._id
state.history.append(ipython_observation)
def _impl_unterminated_string_error_events(
self, event_stream: EventStream, random_line: bool, incidents: int = 4
self, state: State, random_line: bool, incidents: int = 4
):
for i in range(incidents):
ipython_action = IPythonRunCellAction(code=code_snippet)
event_stream.add_event(ipython_action, EventSource.AGENT)
state.history.append(ipython_action)
line_number = (i + 1) * 10 if random_line else '1'
ipython_observation = IPythonRunCellObservation(
content=f'print(" Cell In[1], line {line_number}\nhello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
@ -95,34 +92,30 @@ class TestStuckDetector:
+ jupyter_line_2,
code=code_snippet,
)
ipython_observation._cause = ipython_action._id
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
# ipython_observation._cause = ipython_action._
state.history.append(ipython_observation)
def test_history_too_short(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
def test_history_too_short(self, stuck_detector: StuckDetector):
state = stuck_detector.state
message_action = MessageAction(content='Hello', wait_for_response=False)
message_action._source = EventSource.USER
observation = NullObservation(content='')
observation._cause = message_action.id
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(observation, EventSource.ENVIRONMENT)
# observation._cause = message_action.id
state.history.append(message_action)
state.history.append(observation)
cmd_action = CmdRunAction(command='ls')
event_stream.add_event(cmd_action, EventSource.AGENT)
state.history.append(cmd_action)
cmd_observation = CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation._cause = cmd_action._id
event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT)
# stuck_detector.state.history.set_event_stream(event_stream)
# cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
assert stuck_detector.is_stuck() is False
def test_is_stuck_repeating_action_observation(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector):
state = stuck_detector.state
message_action = MessageAction(content='Done', wait_for_response=False)
message_action._source = EventSource.USER
@ -130,135 +123,125 @@ class TestStuckDetector:
hello_observation = NullObservation('')
# 2 events
event_stream.add_event(hello_action, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
state.history.append(hello_action)
state.history.append(hello_observation)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
cmd_observation_1 = CmdOutputObservation(
content='', command='ls', command_id=cmd_action_1._id
)
cmd_action_1._id = 1
state.history.append(cmd_action_1)
cmd_observation_1 = CmdOutputObservation(content='', command='ls', command_id=1)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
state.history.append(cmd_observation_1)
# 4 events
cmd_action_2 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
cmd_observation_2 = CmdOutputObservation(
content='', command='ls', command_id=cmd_action_2._id
)
cmd_action_2._id = 2
state.history.append(cmd_action_2)
cmd_observation_2 = CmdOutputObservation(content='', command='ls', command_id=2)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
state.history.append(cmd_observation_2)
# 6 events
# random user message just because we can
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
state.history.append(message_action)
state.history.append(message_null_observation)
# 8 events
assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 2
cmd_action_3 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
cmd_observation_3 = CmdOutputObservation(
content='', command='ls', command_id=cmd_action_3._id
)
cmd_action_3._id = 3
state.history.append(cmd_action_3)
cmd_observation_3 = CmdOutputObservation(content='', command='ls', command_id=3)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
state.history.append(cmd_observation_3)
# 10 events
assert len(collect_events(event_stream)) == 10
assert len(list(stuck_detector.state.history.get_events())) == 8
assert len(state.history) == 10
assert (
len(
get_pairs_from_events(
stuck_detector.state.history.get_events_as_list(
include_delegates=True
)
)
)
== 5
)
len(state.history) == 10
) # Adjusted since history is a list and the controller is not running
# FIXME are we still testing this without this test?
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 5
# )
assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 1
cmd_action_4 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_4, EventSource.AGENT)
cmd_observation_4 = CmdOutputObservation(
content='', command='ls', command_id=cmd_action_4._id
)
cmd_action_4._id = 4
state.history.append(cmd_action_4)
cmd_observation_4 = CmdOutputObservation(content='', command='ls', command_id=4)
cmd_observation_4._cause = cmd_action_4._id
event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT)
state.history.append(cmd_observation_4)
# 12 events
assert len(collect_events(event_stream)) == 12
assert len(list(stuck_detector.state.history.get_events())) == 10
assert (
len(
get_pairs_from_events(
stuck_detector.state.history.get_events_as_list(
include_delegates=True
)
)
)
== 6
)
assert len(state.history) == 12
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 6
# )
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.state.almost_stuck == 0
mock_warning.assert_called_once_with('Action, Observation loop detected')
def test_is_stuck_repeating_action_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
# (action, error_observation), not necessarily the same error
message_action = MessageAction(content='Done', wait_for_response=False)
message_action._source = EventSource.USER
hello_action = MessageAction(content='Hello', wait_for_response=False)
hello_observation = NullObservation(content='')
event_stream.add_event(hello_action, EventSource.USER)
hello_observation._cause = hello_action._id
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
state.history.append(hello_action)
# hello_observation._cause = hello_action._id
state.history.append(hello_observation)
# 2 events
cmd_action_1 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
state.history.append(cmd_action_1)
error_observation_1 = ErrorObservation(content='Command not found')
error_observation_1._cause = cmd_action_1._id
event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT)
# error_observation_1._cause = cmd_action_1._id
state.history.append(error_observation_1)
# 4 events
cmd_action_2 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
state.history.append(cmd_action_2)
error_observation_2 = ErrorObservation(
content='Command still not found or another error'
)
error_observation_2._cause = cmd_action_2._id
event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT)
# error_observation_2._cause = cmd_action_2._id
state.history.append(error_observation_2)
# 6 events
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
state.history.append(message_action)
state.history.append(message_null_observation)
# 8 events
cmd_action_3 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
state.history.append(cmd_action_3)
error_observation_3 = ErrorObservation(content='Different error')
error_observation_3._cause = cmd_action_3._id
event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT)
# error_observation_3._cause = cmd_action_3._id
state.history.append(error_observation_3)
# 10 events
cmd_action_4 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_4, EventSource.AGENT)
state.history.append(cmd_action_4)
error_observation_4 = ErrorObservation(content='Command not found')
error_observation_4._cause = cmd_action_4._id
event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT)
# error_observation_4._cause = cmd_action_4._id
state.history.append(error_observation_4)
# 12 events
with patch('logging.Logger.warning') as mock_warning:
@ -267,11 +250,10 @@ class TestStuckDetector:
'Action, ErrorObservation loop detected'
)
def test_is_stuck_invalid_syntax_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
self._impl_syntax_error_events(
event_stream,
state,
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
random_line=False,
)
@ -280,10 +262,11 @@ class TestStuckDetector:
assert stuck_detector.is_stuck() is True
def test_is_not_stuck_invalid_syntax_error_random_lines(
self, stuck_detector: StuckDetector, event_stream: EventStream
self, stuck_detector: StuckDetector
):
state = stuck_detector.state
self._impl_syntax_error_events(
event_stream,
state,
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
random_line=True,
)
@ -292,10 +275,11 @@ class TestStuckDetector:
assert stuck_detector.is_stuck() is False
def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
self, stuck_detector: StuckDetector, event_stream: EventStream
self, stuck_detector: StuckDetector
):
state = stuck_detector.state
self._impl_syntax_error_events(
event_stream,
state,
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
random_line=True,
incidents=3,
@ -304,11 +288,10 @@ class TestStuckDetector:
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
def test_is_stuck_incomplete_input_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
self._impl_syntax_error_events(
event_stream,
state,
error_message='SyntaxError: incomplete input',
random_line=False,
)
@ -316,11 +299,10 @@ class TestStuckDetector:
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
def test_is_not_stuck_incomplete_input_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
self._impl_syntax_error_events(
event_stream,
state,
error_message='SyntaxError: incomplete input',
random_line=True,
)
@ -329,239 +311,241 @@ class TestStuckDetector:
assert stuck_detector.is_stuck() is False
def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
self, stuck_detector: StuckDetector, event_stream: EventStream
self, stuck_detector: StuckDetector
):
self._impl_unterminated_string_error_events(event_stream, random_line=True)
state = stuck_detector.state
self._impl_unterminated_string_error_events(state, random_line=True)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
self, stuck_detector: StuckDetector, event_stream: EventStream
self, stuck_detector: StuckDetector
):
state = stuck_detector.state
self._impl_unterminated_string_error_events(
event_stream, random_line=False, incidents=3
state, random_line=False, incidents=3
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
def test_is_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
self, stuck_detector: StuckDetector
):
self._impl_unterminated_string_error_events(event_stream, random_line=False)
state = stuck_detector.state
self._impl_unterminated_string_error_events(state, random_line=False)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
def test_is_not_stuck_ipython_syntax_error_not_at_end(
self, stuck_detector: StuckDetector, event_stream: EventStream
self, stuck_detector: StuckDetector
):
state = stuck_detector.state
# this test is to make sure we don't get false positives
# since the "at line x" is changing in between!
ipython_action_1 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_1, EventSource.AGENT)
state.history.append(ipython_action_1)
ipython_observation_1 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nThis is some additional output',
code='print("hello',
)
ipython_observation_1._cause = ipython_action_1._id
event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT)
# ipython_observation_1._cause = ipython_action_1._id
state.history.append(ipython_observation_1)
ipython_action_2 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_2, EventSource.AGENT)
state.history.append(ipython_action_2)
ipython_observation_2 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nToo much output here on and on',
code='print("hello',
)
ipython_observation_2._cause = ipython_action_2._id
event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT)
# ipython_observation_2._cause = ipython_action_2._id
state.history.append(ipython_observation_2)
ipython_action_3 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_3, EventSource.AGENT)
state.history.append(ipython_action_3)
ipython_observation_3 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)\nEnough',
code='print("hello',
)
ipython_observation_3._cause = ipython_action_3._id
event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT)
# ipython_observation_3._cause = ipython_action_3._id
state.history.append(ipython_observation_3)
ipython_action_4 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_4, EventSource.AGENT)
state.history.append(ipython_action_4)
ipython_observation_4 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)\nLast line of output',
code='print("hello',
)
ipython_observation_4._cause = ipython_action_4._id
event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT)
# ipython_observation_4._cause = ipython_action_4._id
state.history.append(ipython_observation_4)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is False
mock_warning.assert_not_called()
def test_is_stuck_repeating_action_observation_pattern(
self, stuck_detector: StuckDetector, event_stream: EventStream
self, stuck_detector: StuckDetector
):
state = stuck_detector.state
message_action = MessageAction(content='Come on', wait_for_response=False)
message_action._source = EventSource.USER
event_stream.add_event(message_action, EventSource.USER)
state.history.append(message_action)
message_observation = NullObservation(content='')
event_stream.add_event(message_observation, EventSource.ENVIRONMENT)
state.history.append(message_observation)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
state.history.append(cmd_action_1)
cmd_observation_1 = CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
# cmd_observation_1._cause = cmd_action_1._id
state.history.append(cmd_observation_1)
read_action_1 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_1, EventSource.AGENT)
state.history.append(read_action_1)
read_observation_1 = FileReadObservation(
content='File content', path='file1.txt'
)
read_observation_1._cause = read_action_1._id
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
# read_observation_1._cause = read_action_1._id
state.history.append(read_observation_1)
cmd_action_2 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
state.history.append(cmd_action_2)
cmd_observation_2 = CmdOutputObservation(
command_id=2, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
# cmd_observation_2._cause = cmd_action_2._id
state.history.append(cmd_observation_2)
read_action_2 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_2, EventSource.AGENT)
state.history.append(read_action_2)
read_observation_2 = FileReadObservation(
content='File content', path='file1.txt'
)
read_observation_2._cause = read_action_2._id
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
# read_observation_2._cause = read_action_2._id
state.history.append(read_observation_2)
message_action = MessageAction(content='Come on', wait_for_response=False)
event_stream.add_event(message_action, EventSource.USER)
message_action._source = EventSource.USER
state.history.append(message_action)
message_null_observation = NullObservation(content='')
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
state.history.append(message_null_observation)
cmd_action_3 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
state.history.append(cmd_action_3)
cmd_observation_3 = CmdOutputObservation(
command_id=3, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
# cmd_observation_3._cause = cmd_action_3._id
state.history.append(cmd_observation_3)
read_action_3 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_3, EventSource.AGENT)
state.history.append(read_action_3)
read_observation_3 = FileReadObservation(
content='File content', path='file1.txt'
)
read_observation_3._cause = read_action_3._id
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
# read_observation_3._cause = read_action_3._id
state.history.append(read_observation_3)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
mock_warning.assert_called_once_with('Action, Observation pattern detected')
def test_is_stuck_not_stuck(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
state = stuck_detector.state
message_action = MessageAction(content='Done', wait_for_response=False)
message_action._source = EventSource.USER
hello_action = MessageAction(content='Hello', wait_for_response=False)
event_stream.add_event(hello_action, EventSource.USER)
state.history.append(hello_action)
hello_observation = NullObservation(content='')
hello_observation._cause = hello_action._id
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
# hello_observation._cause = hello_action._id
state.history.append(hello_observation)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
state.history.append(cmd_action_1)
cmd_observation_1 = CmdOutputObservation(
command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
# cmd_observation_1._cause = cmd_action_1._id
state.history.append(cmd_observation_1)
read_action_1 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_1, EventSource.AGENT)
state.history.append(read_action_1)
read_observation_1 = FileReadObservation(
content='File content', path='file1.txt'
)
read_observation_1._cause = read_action_1._id
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
# read_observation_1._cause = read_action_1._id
state.history.append(read_observation_1)
cmd_action_2 = CmdRunAction(command='pwd')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
state.history.append(cmd_action_2)
cmd_observation_2 = CmdOutputObservation(
command_id=2, command='pwd', content='/home/user'
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
# cmd_observation_2._cause = cmd_action_2._id
state.history.append(cmd_observation_2)
read_action_2 = FileReadAction(path='file2.txt')
event_stream.add_event(read_action_2, EventSource.AGENT)
state.history.append(read_action_2)
read_observation_2 = FileReadObservation(
content='Another file content', path='file2.txt'
)
read_observation_2._cause = read_action_2._id
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
# read_observation_2._cause = read_action_2._id
state.history.append(read_observation_2)
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
state.history.append(message_action)
state.history.append(message_null_observation)
cmd_action_3 = CmdRunAction(command='pwd')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
state.history.append(cmd_action_3)
cmd_observation_3 = CmdOutputObservation(
command_id=cmd_action_3.id, command='pwd', content='/home/user'
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
# cmd_observation_3._cause = cmd_action_3._id
state.history.append(cmd_observation_3)
read_action_3 = FileReadAction(path='file2.txt')
event_stream.add_event(read_action_3, EventSource.AGENT)
state.history.append(read_action_3)
read_observation_3 = FileReadObservation(
content='Another file content', path='file2.txt'
)
read_observation_3._cause = read_action_3._id
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
# read_observation_3._cause = read_action_3._id
state.history.append(read_observation_3)
assert stuck_detector.is_stuck() is False
def test_is_stuck_monologue(self, stuck_detector, event_stream):
# Add events to the event stream
def test_is_stuck_monologue(self, stuck_detector):
state = stuck_detector.state
# Add events to the history list directly
message_action_1 = MessageAction(content='Hi there!')
event_stream.add_event(message_action_1, EventSource.USER)
message_action_1._source = EventSource.USER
state.history.append(message_action_1)
message_action_2 = MessageAction(content='Hi there!')
event_stream.add_event(message_action_2, EventSource.AGENT)
message_action_2._source = EventSource.AGENT
state.history.append(message_action_2)
message_action_3 = MessageAction(content='How are you?')
event_stream.add_event(message_action_3, EventSource.USER)
message_action_3._source = EventSource.USER
state.history.append(message_action_3)
cmd_kill_action = CmdRunAction(
command='echo 42', thought="I'm not stuck, he's stuck"
)
event_stream.add_event(cmd_kill_action, EventSource.AGENT)
state.history.append(cmd_kill_action)
message_action_4 = MessageAction(content="I'm doing well, thanks for asking.")
event_stream.add_event(message_action_4, EventSource.AGENT)
message_action_4._source = EventSource.AGENT
state.history.append(message_action_4)
message_action_5 = MessageAction(content="I'm doing well, thanks for asking.")
event_stream.add_event(message_action_5, EventSource.AGENT)
message_action_5._source = EventSource.AGENT
state.history.append(message_action_5)
message_action_6 = MessageAction(content="I'm doing well, thanks for asking.")
event_stream.add_event(message_action_6, EventSource.AGENT)
message_action_6._source = EventSource.AGENT
state.history.append(message_action_6)
assert stuck_detector.is_stuck()
@ -572,16 +556,15 @@ class TestStuckDetector:
command='storybook',
exit_code=0,
)
cmd_output_observation._cause = cmd_kill_action._id
event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT)
# cmd_output_observation._cause = cmd_kill_action._id
state.history.append(cmd_output_observation)
message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
event_stream.add_event(message_action_7, EventSource.AGENT)
message_action_7._source = EventSource.AGENT
state.history.append(message_action_7)
message_action_8 = MessageAction(content="I'm doing well, thanks for asking.")
event_stream.add_event(message_action_8, EventSource.AGENT)
message_action_8._source = EventSource.AGENT
state.history.append(message_action_8)
with patch('logging.Logger.warning'):
assert not stuck_detector.is_stuck()
@ -596,7 +579,6 @@ class TestAgentController:
)
controller.delegate = None
controller.state = Mock()
controller.state.history = ShortTermHistory()
return controller
def test_is_stuck_delegate_stuck(self, controller: AgentController):

View File

@ -10,10 +10,8 @@ from openhands.agenthub.micro.registry import all_microagents
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.events import EventSource
from openhands.events.action import MessageAction
from openhands.events.stream import EventStream
from openhands.memory.history import ShortTermHistory
from openhands.storage import get_file_store
@ -74,10 +72,10 @@ def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict
)
assert coder_agent is not None
# give it some history
task = 'This is a dummy task'
history = ShortTermHistory()
history.set_event_stream(event_stream)
event_stream.add_event(MessageAction(content=task), EventSource.USER)
history = list()
history.append(MessageAction(content=task))
summary = 'This is a dummy summary about this repo'
state = State(history=history, inputs={'summary': summary})
@ -119,10 +117,10 @@ def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: d
)
assert coder_agent is not None
# give it some history
task = 'This is a dummy task'
history = ShortTermHistory()
history.set_event_stream(event_stream)
event_stream.add_event(MessageAction(content=task), EventSource.USER)
history = list()
history.append(MessageAction(content=task))
# set state without codebase summary
state = State(history=history)

View File

@ -1,14 +1,12 @@
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import Mock, patch
import pytest
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.events import EventSource, EventStream
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.llm.llm import LLM
from openhands.storage import get_file_store
@pytest.fixture
@ -19,12 +17,6 @@ def mock_llm():
return llm
@pytest.fixture
def mock_event_stream(tmp_path):
file_store = get_file_store('local', str(tmp_path))
return EventStream('test_session', file_store)
@pytest.fixture(params=[False, True])
def codeact_agent(mock_llm, request):
config = AgentConfig()
@ -57,17 +49,28 @@ def response_mock(content: str):
return MockModelResponse(content)
def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
# Add some events to the stream
mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)
def test_get_messages_with_reminder(codeact_agent: CodeActAgent):
# Add some events to history
history = list()
message_action_1 = MessageAction('Initial user message')
message_action_1._source = 'user'
history.append(message_action_1)
message_action_2 = MessageAction('Sure!')
message_action_2._source = 'assistant'
history.append(message_action_2)
message_action_3 = MessageAction('Hello, agent!')
message_action_3._source = 'user'
history.append(message_action_3)
message_action_4 = MessageAction('Hello, user!')
message_action_4._source = 'assistant'
history.append(message_action_4)
message_action_5 = MessageAction('Laaaaaaaast!')
message_action_5._source = 'user'
history.append(message_action_5)
codeact_agent.reset()
messages = codeact_agent._get_messages(
Mock(history=mock_event_stream, max_iterations=5, iteration=0)
Mock(history=history, max_iterations=5, iteration=0)
)
assert (
@ -102,19 +105,20 @@ def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
)
def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
history = list()
# Add multiple user and agent messages
for i in range(15):
mock_event_stream.add_event(
MessageAction(f'User message {i}'), EventSource.USER
)
mock_event_stream.add_event(
MessageAction(f'Agent message {i}'), EventSource.AGENT
)
message_action_user = MessageAction(f'User message {i}')
message_action_user._source = 'user'
history.append(message_action_user)
message_action_agent = MessageAction(f'Agent message {i}')
message_action_agent._source = 'assistant'
history.append(message_action_agent)
codeact_agent.reset()
messages = codeact_agent._get_messages(
Mock(history=mock_event_stream, max_iterations=10, iteration=5)
Mock(history=history, max_iterations=10, iteration=5)
)
# Check that only the last two user messages have cache_prompt=True
@ -136,18 +140,23 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
assert cached_user_messages[3].content[0].text.startswith('User message 1')
def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
def test_get_messages_with_cmd_action(codeact_agent: CodeActAgent):
if codeact_agent.config.function_calling:
pytest.skip('Skipping this test for function calling')
history = list()
# Add a mix of actions and observations
message_action_1 = MessageAction(
"Let's list the contents of the current directory."
)
mock_event_stream.add_event(message_action_1, EventSource.USER)
message_action_1._source = 'user'
history.append(message_action_1)
cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
mock_event_stream.add_event(cmd_action_1, EventSource.AGENT)
cmd_action_1._source = 'agent'
cmd_action_1._id = 'cmd_1'
history.append(cmd_action_1)
cmd_observation_1 = CmdOutputObservation(
content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
@ -155,13 +164,17 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
command='ls -l',
exit_code=0,
)
mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
cmd_observation_1._source = 'user'
history.append(cmd_observation_1)
message_action_2 = MessageAction("Now, let's create a new directory.")
mock_event_stream.add_event(message_action_2, EventSource.AGENT)
message_action_2._source = 'agent'
history.append(message_action_2)
cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
mock_event_stream.add_event(cmd_action_2, EventSource.AGENT)
cmd_action_2._source = 'agent'
cmd_action_2._id = 'cmd_2'
history.append(cmd_action_2)
cmd_observation_2 = CmdOutputObservation(
content='',
@ -169,11 +182,12 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
command='mkdir new_directory',
exit_code=0,
)
mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
cmd_observation_2._source = 'user'
history.append(cmd_observation_2)
codeact_agent.reset()
messages = codeact_agent._get_messages(
Mock(history=mock_event_stream, max_iterations=5, iteration=0)
Mock(history=history, max_iterations=5, iteration=0)
)
# Assert the presence of key elements in the messages
@ -218,19 +232,17 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
def test_prompt_caching_headers(codeact_agent, mock_event_stream):
def test_prompt_caching_headers(codeact_agent: CodeActAgent):
history = list()
if codeact_agent.config.function_calling:
pytest.skip('Skipping this test for function calling')
# Setup
mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
mock_short_term_history = MagicMock()
mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!'
history.append(MessageAction('Hello, agent!'))
history.append(MessageAction('Hello, user!'))
mock_state = Mock()
mock_state.history = mock_short_term_history
mock_state.history = history
mock_state.max_iterations = 5
mock_state.iteration = 0