mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Refactor history/event stream (#3808)
This commit is contained in:
parent
edfba4618a
commit
eeb2342509
@ -161,7 +161,7 @@ Pour créer un workflow d'évaluation pour votre benchmark, suivez ces étapes :
|
||||
instruction=instruction,
|
||||
test_result=evaluation_result,
|
||||
metadata=metadata,
|
||||
history=state.history.compatibility_for_eval_history_pairs(),
|
||||
history=compatibility_for_eval_history_pairs(state.history),
|
||||
metrics=state.metrics.get() if state.metrics else None,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
)
|
||||
@ -260,7 +260,7 @@ def codeact_user_response(state: State | None) -> str:
|
||||
# vérifier si l'agent a essayé de parler à l'utilisateur 3 fois, si oui, faire savoir à l'agent qu'il peut abandonner
|
||||
user_msgs = [
|
||||
event
|
||||
for event in state.history.get_events()
|
||||
for event in state.history
|
||||
if isinstance(event, MessageAction) and event.source == 'user'
|
||||
]
|
||||
if len(user_msgs) >= 2:
|
||||
|
||||
@ -158,7 +158,7 @@ OpenHands 的主要入口点在 `openhands/core/main.py` 中。以下是它工
|
||||
instruction=instruction,
|
||||
test_result=evaluation_result,
|
||||
metadata=metadata,
|
||||
history=state.history.compatibility_for_eval_history_pairs(),
|
||||
history=compatibility_for_eval_history_pairs(state.history),
|
||||
metrics=state.metrics.get() if state.metrics else None,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
)
|
||||
@ -257,7 +257,7 @@ def codeact_user_response(state: State | None) -> str:
|
||||
# 检查代理是否已尝试与用户对话 3 次,如果是,让代理知道它可以放弃
|
||||
user_msgs = [
|
||||
event
|
||||
for event in state.history.get_events()
|
||||
for event in state.history
|
||||
if isinstance(event, MessageAction) and event.source == 'user'
|
||||
]
|
||||
if len(user_msgs) >= 2:
|
||||
|
||||
@ -158,7 +158,7 @@ To create an evaluation workflow for your benchmark, follow these steps:
|
||||
instruction=instruction,
|
||||
test_result=evaluation_result,
|
||||
metadata=metadata,
|
||||
history=state.history.compatibility_for_eval_history_pairs(),
|
||||
history=compatibility_for_eval_history_pairs(state.history),
|
||||
metrics=state.metrics.get() if state.metrics else None,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
)
|
||||
@ -257,7 +257,7 @@ def codeact_user_response(state: State | None) -> str:
|
||||
# check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up
|
||||
user_msgs = [
|
||||
event
|
||||
for event in state.history.get_events()
|
||||
for event in state.history
|
||||
if isinstance(event, MessageAction) and event.source == 'user'
|
||||
]
|
||||
if len(user_msgs) >= 2:
|
||||
|
||||
@ -8,6 +8,7 @@ from evaluation.EDA.game import Q20Game, Q20GameCelebrity
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -34,7 +35,7 @@ def codeact_user_response_eda(state: State) -> str:
|
||||
|
||||
# retrieve the latest model message from history
|
||||
if state.history:
|
||||
model_guess = state.history.get_last_agent_message()
|
||||
model_guess = state.get_last_agent_message()
|
||||
|
||||
assert game is not None, 'Game is not initialized.'
|
||||
msg = game.generate_user_response(model_guess)
|
||||
@ -139,7 +140,7 @@ def process_instance(
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
final_message = state.history.get_last_agent_message()
|
||||
final_message = state.get_last_agent_message()
|
||||
|
||||
logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
|
||||
test_result = game.reward()
|
||||
@ -148,7 +149,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -16,6 +16,7 @@ from evaluation.agent_bench.helper import (
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -242,7 +243,7 @@ def process_instance(
|
||||
raw_ans = ''
|
||||
|
||||
# retrieve the last agent message or thought
|
||||
for event in state.history.get_events(reverse=True):
|
||||
for event in reversed(state.history):
|
||||
if event.source == 'agent':
|
||||
if isinstance(event, AgentFinishAction):
|
||||
raw_ans = event.thought
|
||||
@ -271,7 +272,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from evaluation.aider_bench.helper import (
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -250,7 +251,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
|
||||
# Save the output
|
||||
|
||||
@ -13,6 +13,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -299,7 +300,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
test_result['generated'] = test_result['metadata']['1_copy_change_code']
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from tqdm import tqdm
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -46,7 +47,7 @@ def codeact_user_response(state: State) -> str:
|
||||
# check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up
|
||||
user_msgs = [
|
||||
event
|
||||
for event in state.history.get_events()
|
||||
for event in state.history
|
||||
if isinstance(event, MessageAction) and event.source == 'user'
|
||||
]
|
||||
if len(user_msgs) > 2:
|
||||
@ -431,7 +432,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -9,6 +9,7 @@ from datasets import load_dataset
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -89,7 +90,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# find the last delegate action
|
||||
last_delegate_action = None
|
||||
|
||||
@ -15,6 +15,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -173,14 +174,14 @@ def initialize_runtime(runtime: Runtime, data_files: list[str]):
|
||||
|
||||
|
||||
def get_last_agent_finish_action(state: State) -> AgentFinishAction:
|
||||
for event in state.history.get_events(reverse=True):
|
||||
for event in reversed(state.history):
|
||||
if isinstance(event, AgentFinishAction):
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def get_last_message_action(state: State) -> MessageAction:
|
||||
for event in state.history.get_events(reverse=True):
|
||||
for event in reversed(state.history):
|
||||
if isinstance(event, MessageAction):
|
||||
return event
|
||||
return None
|
||||
@ -307,7 +308,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# DiscoveryBench Evaluation
|
||||
eval_rec = run_eval_gold_vs_gen_NL_hypo_workflow(
|
||||
|
||||
@ -12,6 +12,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -166,7 +167,7 @@ def process_instance(
|
||||
|
||||
model_answer_raw = ''
|
||||
# get the last message or thought from the agent
|
||||
for event in state.history.get_events(reverse=True):
|
||||
for event in reversed(state.history):
|
||||
if event.source == 'agent':
|
||||
if isinstance(event, AgentFinishAction):
|
||||
model_answer_raw = event.thought
|
||||
@ -203,7 +204,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -10,6 +10,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -101,7 +102,7 @@ def process_instance(
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
# retrieve the last message from the agent
|
||||
model_answer_raw = state.history.get_last_agent_message()
|
||||
model_answer_raw = state.get_last_agent_message()
|
||||
|
||||
# attempt to parse model_answer
|
||||
ast_eval_fn = instance['ast_eval']
|
||||
@ -114,7 +115,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
output = EvalOutput(
|
||||
instance_id=instance_id,
|
||||
|
||||
@ -28,6 +28,7 @@ from datasets import load_dataset
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -244,7 +245,7 @@ Ok now its time to start solving the question. Good luck!
|
||||
'C': False,
|
||||
'D': False,
|
||||
}
|
||||
for event in state.history.get_events(reverse=True):
|
||||
for event in reversed(state.history):
|
||||
if (
|
||||
isinstance(event, AgentFinishAction)
|
||||
and event.source != 'user'
|
||||
@ -300,7 +301,7 @@ Ok now its time to start solving the question. Good luck!
|
||||
instance_id=str(instance.instance_id),
|
||||
instruction=instruction,
|
||||
metadata=metadata,
|
||||
history=state.history.compatibility_for_eval_history_pairs(),
|
||||
history=compatibility_for_eval_history_pairs(state.history),
|
||||
metrics=metrics,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
test_result={
|
||||
|
||||
@ -21,6 +21,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -255,7 +256,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -129,7 +129,7 @@ def process_instance(
|
||||
# # result evaluation
|
||||
# # =============================================
|
||||
|
||||
histories = [event_to_dict(event) for event in state.history.get_events()]
|
||||
histories = [event_to_dict(event) for event in state.history]
|
||||
test_result: TestResult = test_class.verify_result(runtime, histories)
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -225,7 +226,7 @@ def process_instance(
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
final_message = ''
|
||||
for event in state.history.get_events(reverse=True):
|
||||
for event in reversed(state.history):
|
||||
if isinstance(event, AgentFinishAction):
|
||||
final_message = event.thought
|
||||
break
|
||||
@ -247,7 +248,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -11,6 +11,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -182,7 +183,7 @@ def process_instance(
|
||||
|
||||
# Instruction is the first message from the USER
|
||||
instruction = ''
|
||||
for event in state.history.get_events():
|
||||
for event in state.history:
|
||||
if isinstance(event, MessageAction):
|
||||
instruction = event.content
|
||||
break
|
||||
@ -194,7 +195,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -13,6 +13,7 @@ from evaluation.mint.tasks import Task
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -28,6 +29,7 @@ from openhands.core.config import (
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
@ -45,7 +47,10 @@ def codeact_user_response_mint(state: State, task: Task, task_config: dict[str,
|
||||
task=task,
|
||||
task_config=task_config,
|
||||
)
|
||||
last_action = state.history.get_last_action()
|
||||
last_action = next(
|
||||
(event for event in reversed(state.history) if isinstance(event, Action)),
|
||||
None,
|
||||
)
|
||||
result_state: TaskState = env.step(last_action.message or '')
|
||||
|
||||
state.extra_data['task_state'] = result_state
|
||||
@ -202,7 +207,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -24,6 +24,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -256,7 +257,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -10,6 +10,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -232,7 +233,7 @@ If the program uses some packages that are incompatible, please figure out alter
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -443,7 +443,8 @@ def process_instance(
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
histories = [event_to_dict(event) for event in state.history.get_events()]
|
||||
# NOTE: this is NO LONGER the event stream, but an agent history that includes delegate agent's events
|
||||
histories = [event_to_dict(event) for event in state.history]
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
|
||||
# Save the output
|
||||
|
||||
@ -9,6 +9,7 @@ from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -126,7 +127,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
# retrieve the last message from the agent
|
||||
model_answer_raw = state.history.get_last_agent_message()
|
||||
model_answer_raw = state.get_last_agent_message()
|
||||
|
||||
# attempt to parse model_answer
|
||||
correct = eval_answer(str(model_answer_raw), str(answer))
|
||||
@ -137,7 +138,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -18,6 +18,9 @@ from openhands.core.logger import get_console_handler
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.events.utils import get_pairs_from_events
|
||||
|
||||
|
||||
class EvalMetadata(BaseModel):
|
||||
@ -112,7 +115,14 @@ def codeact_user_response(
|
||||
if state.history:
|
||||
# check if the last action has an answer, if so, early exit
|
||||
if try_parse is not None:
|
||||
last_action = state.history.get_last_action()
|
||||
last_action = next(
|
||||
(
|
||||
event
|
||||
for event in reversed(state.history)
|
||||
if isinstance(event, Action)
|
||||
),
|
||||
None,
|
||||
)
|
||||
ans = try_parse(last_action)
|
||||
if ans is not None:
|
||||
return '/exit'
|
||||
@ -120,7 +130,7 @@ def codeact_user_response(
|
||||
# check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up
|
||||
user_msgs = [
|
||||
event
|
||||
for event in state.history.get_events()
|
||||
for event in state.history
|
||||
if isinstance(event, MessageAction) and event.source == 'user'
|
||||
]
|
||||
if len(user_msgs) >= 2:
|
||||
@ -428,3 +438,18 @@ def update_llm_config_for_completions_logging(
|
||||
f'{llm_config.log_completions_folder}'
|
||||
)
|
||||
return llm_config
|
||||
|
||||
|
||||
# history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
|
||||
# we rebuild the pairs here
|
||||
# for compatibility with the existing output format in evaluations
|
||||
# remove this when it's no longer necessary
|
||||
def compatibility_for_eval_history_pairs(
|
||||
history: list[Event],
|
||||
) -> list[tuple[dict, dict]]:
|
||||
history_pairs = []
|
||||
|
||||
for action, observation in get_pairs_from_events(history):
|
||||
history_pairs.append((event_to_dict(action), event_to_dict(observation)))
|
||||
|
||||
return history_pairs
|
||||
|
||||
@ -10,6 +10,7 @@ import pandas as pd
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
compatibility_for_eval_history_pairs,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
@ -166,7 +167,7 @@ def process_instance(
|
||||
|
||||
# Instruction is the first message from the USER
|
||||
instruction = ''
|
||||
for event in state.history.get_events():
|
||||
for event in state.history:
|
||||
if isinstance(event, MessageAction):
|
||||
instruction = event.content
|
||||
break
|
||||
@ -178,7 +179,7 @@ def process_instance(
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
histories = compatibility_for_eval_history_pairs(state.history)
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
|
||||
@ -150,13 +150,13 @@ class BrowsingAgent(Agent):
|
||||
last_obs = None
|
||||
last_action = None
|
||||
|
||||
if EVAL_MODE and len(state.history.get_events_as_list()) == 1:
|
||||
if EVAL_MODE and len(state.history) == 1:
|
||||
# for webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||
# initialize and retrieve the first observation by issuing an noop OP
|
||||
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
|
||||
return BrowseInteractiveAction(browser_actions='noop()')
|
||||
|
||||
for event in state.history.get_events():
|
||||
for event in state.history:
|
||||
if isinstance(event, BrowseInteractiveAction):
|
||||
prev_actions.append(event.browser_actions)
|
||||
last_action = event
|
||||
|
||||
@ -337,8 +337,8 @@ class CodeActAgent(Agent):
|
||||
return self.pending_actions.popleft()
|
||||
|
||||
# if we're done, go back
|
||||
latest_user_message = state.history.get_last_user_message()
|
||||
if latest_user_message and latest_user_message.strip() == '/exit':
|
||||
last_user_message = state.get_last_user_message()
|
||||
if last_user_message and last_user_message.strip() == '/exit':
|
||||
return AgentFinishAction()
|
||||
|
||||
# prepare what we want to send to the LLM
|
||||
@ -419,7 +419,7 @@ class CodeActAgent(Agent):
|
||||
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
tool_call_id_to_message: dict[str, Message] = {}
|
||||
events = list(state.history.get_events())
|
||||
events = list(state.history)
|
||||
for event in events:
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
|
||||
@ -154,8 +154,8 @@ class CodeActSWEAgent(Agent):
|
||||
- AgentFinishAction() - end the interaction
|
||||
"""
|
||||
# if we're done, go back
|
||||
latest_user_message = state.history.get_last_user_message()
|
||||
if latest_user_message and latest_user_message.strip() == '/exit':
|
||||
last_user_message = state.get_last_user_message()
|
||||
if last_user_message and last_user_message.strip() == '/exit':
|
||||
return AgentFinishAction()
|
||||
|
||||
# prepare what we want to send to the LLM
|
||||
@ -176,7 +176,7 @@ class CodeActSWEAgent(Agent):
|
||||
Message(role='user', content=[TextContent(text=self.in_context_example)]),
|
||||
]
|
||||
|
||||
for event in state.history.get_events():
|
||||
for event in state.history:
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
message = self.get_action_message(event)
|
||||
|
||||
@ -2,7 +2,7 @@ from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction
|
||||
from openhands.events.observation import AgentDelegateObservation
|
||||
from openhands.events.observation import AgentDelegateObservation, Observation
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
@ -41,7 +41,11 @@ class DelegatorAgent(Agent):
|
||||
)
|
||||
|
||||
# last observation in history should be from the delegate
|
||||
last_observation = state.history.get_last_observation()
|
||||
last_observation = None
|
||||
for event in reversed(state.history):
|
||||
if isinstance(event, Observation):
|
||||
last_observation = event
|
||||
break
|
||||
|
||||
if not isinstance(last_observation, AgentDelegateObservation):
|
||||
raise Exception('Last observation is not an AgentDelegateObservation')
|
||||
|
||||
@ -164,7 +164,7 @@ class DummyAgent(Agent):
|
||||
|
||||
if 'observations' in prev_step and prev_step['observations']:
|
||||
expected_observations = prev_step['observations']
|
||||
hist_events = state.history.get_last_events(len(expected_observations))
|
||||
hist_events = state.history[-len(expected_observations) :]
|
||||
|
||||
if len(hist_events) < len(expected_observations):
|
||||
print(
|
||||
|
||||
@ -8,10 +8,10 @@ from openhands.core.config import AgentConfig
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.core.utils import json
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.action import action_from_dict
|
||||
from openhands.events.serialization.event import event_to_memory
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
|
||||
|
||||
def parse_response(orig_response: str) -> Action:
|
||||
@ -32,16 +32,14 @@ class MicroAgent(Agent):
|
||||
prompt = ''
|
||||
agent_definition: dict = {}
|
||||
|
||||
def history_to_json(
|
||||
self, history: ShortTermHistory, max_events: int = 20, **kwargs
|
||||
):
|
||||
def history_to_json(self, history: list[Event], max_events: int = 20, **kwargs):
|
||||
"""
|
||||
Serialize and simplify history to str format
|
||||
"""
|
||||
processed_history = []
|
||||
event_count = 0
|
||||
|
||||
for event in history.get_events(reverse=True):
|
||||
for event in reversed(history):
|
||||
if event_count >= max_events:
|
||||
break
|
||||
processed_history.append(
|
||||
|
||||
@ -117,7 +117,7 @@ def get_hint(latest_action_id: str) -> str:
|
||||
|
||||
def get_prompt_and_images(
|
||||
state: State, max_message_chars: int
|
||||
) -> tuple[str, list[str]]:
|
||||
) -> tuple[str, list[str] | None]:
|
||||
"""Gets the prompt for the planner agent.
|
||||
|
||||
Formatted with the most recent action-observation pairs, current task, and hint based on last action
|
||||
@ -136,7 +136,7 @@ def get_prompt_and_images(
|
||||
latest_action: Action = NullAction()
|
||||
|
||||
# retrieve the latest HISTORY_SIZE events
|
||||
for event_count, event in enumerate(state.history.get_events(reverse=True)):
|
||||
for event_count, event in enumerate(reversed(state.history)):
|
||||
if event_count >= HISTORY_SIZE:
|
||||
break
|
||||
if latest_action == NullAction() and isinstance(event, Action):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import traceback
|
||||
from typing import Callable, Type
|
||||
from typing import Callable, ClassVar, Type
|
||||
|
||||
import litellm
|
||||
|
||||
@ -36,6 +36,7 @@ from openhands.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
AgentStateChangedObservation,
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
@ -61,6 +62,12 @@ class AgentController:
|
||||
parent: 'AgentController | None' = None
|
||||
delegate: 'AgentController | None' = None
|
||||
_pending_action: Action | None = None
|
||||
filter_out: ClassVar[tuple[type[Event], ...]] = (
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -121,8 +128,34 @@ class AgentController:
|
||||
self.status_callback = status_callback
|
||||
|
||||
async def close(self):
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
|
||||
Note that it's fairly important that this closes properly, otherwise the state is incomplete."""
|
||||
await self.set_agent_state_to(AgentState.STOPPED)
|
||||
|
||||
# we made history, now is the time to rewrite it!
|
||||
# the final state.history will be used by external scripts like evals, tests, etc.
|
||||
# history will need to be complete WITH delegates events
|
||||
# like the regular agent history, it does not include:
|
||||
# - 'hidden' events, events with hidden=True
|
||||
# - backend events (the default 'filtered out' types, types in self.filter_out)
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else self.event_stream.get_latest_event_id()
|
||||
)
|
||||
self.state.history = list(
|
||||
self.event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter_out_type=self.filter_out,
|
||||
filter_hidden=True,
|
||||
)
|
||||
)
|
||||
|
||||
# unsubscribe from the event stream
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
|
||||
|
||||
def log(self, level: str, message: str, extra: dict | None = None):
|
||||
@ -178,6 +211,11 @@ class AgentController:
|
||||
"""
|
||||
if hasattr(event, 'hidden') and event.hidden:
|
||||
return
|
||||
|
||||
# if the event is not filtered out, add it to the history
|
||||
if not any(isinstance(event, filter_type) for filter_type in self.filter_out):
|
||||
self.state.history.append(event)
|
||||
|
||||
if isinstance(event, Action):
|
||||
await self._handle_action(event)
|
||||
elif isinstance(event, Observation):
|
||||
@ -233,9 +271,6 @@ class AgentController:
|
||||
if self.state.agent_state == AgentState.USER_REJECTED:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
return
|
||||
|
||||
if isinstance(observation, AgentDelegateObservation):
|
||||
self.state.history.on_event(observation)
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
if self.state.agent_state == AgentState.ERROR:
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
@ -362,6 +397,8 @@ class AgentController:
|
||||
delegate_level=self.state.delegate_level + 1,
|
||||
# global metrics should be shared between parent and child
|
||||
metrics=self.state.metrics,
|
||||
# start on top of the stream
|
||||
start_id=self.event_stream.get_latest_event_id() + 1,
|
||||
)
|
||||
self.log(
|
||||
'debug',
|
||||
@ -463,9 +500,7 @@ class AgentController:
|
||||
|
||||
async def _delegate_step(self):
|
||||
"""Executes a single step of the delegate agent."""
|
||||
self.log('debug', 'Delegate not none, awaiting...')
|
||||
await self.delegate._step() # type: ignore[union-attr]
|
||||
self.log('debug', 'Delegate step done')
|
||||
assert self.delegate is not None
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
self.log('debug', f'Delegate state: {delegate_state}')
|
||||
@ -473,7 +508,7 @@ class AgentController:
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# emit AgentDelegateObservation when the delegate terminates due to error
|
||||
# emit AgentDelegateObservation to mark delegate termination due to error
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
@ -488,10 +523,6 @@ class AgentController:
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
|
||||
self.event_stream.add_event(
|
||||
ErrorObservation('Delegate agent encountered an error'),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
self.log('debug', 'Delegate agent has finished execution')
|
||||
# retrieve delegate result
|
||||
@ -574,8 +605,10 @@ class AgentController:
|
||||
max_iterations: The maximum number of iterations allowed for the task.
|
||||
confirmation_mode: Whether to enable confirmation mode.
|
||||
"""
|
||||
# state from the previous session, state from a parent agent, or a new state
|
||||
# note that this is called twice when restoring a previous session, first with state=None
|
||||
# state can come from:
|
||||
# - the previous session, in which case it has history
|
||||
# - from a parent agent, in which case it has no history
|
||||
# - None / a new state
|
||||
if state is None:
|
||||
self.state = State(
|
||||
inputs={},
|
||||
@ -585,27 +618,109 @@ class AgentController:
|
||||
else:
|
||||
self.state = state
|
||||
|
||||
# when restored from a previous session, the State object will have history, start_id, and end_id
|
||||
# connect it to the event stream
|
||||
self.state.history.set_event_stream(self.event_stream)
|
||||
if self.state.start_id <= -1:
|
||||
self.state.start_id = 0
|
||||
|
||||
# if start_id was not set in State, we're starting fresh, at the top of the stream
|
||||
start_id = self.state.start_id
|
||||
if start_id == -1:
|
||||
start_id = self.event_stream.get_latest_event_id() + 1
|
||||
else:
|
||||
self.log(
|
||||
'debug', f'AgentController {self.id} restoring from event {start_id}'
|
||||
'debug',
|
||||
f'AgentController {self.id} initializing history from event {self.state.start_id}',
|
||||
)
|
||||
|
||||
self._init_history()
|
||||
|
||||
def _init_history(self):
|
||||
"""Initializes the agent's history from the event stream.
|
||||
|
||||
The history is a list of events that:
|
||||
- Excludes events of types listed in self.filter_out
|
||||
- Excludes events with hidden=True attribute
|
||||
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
||||
- Excludes all events between the action and observation
|
||||
- Includes the delegate action and observation themselves
|
||||
"""
|
||||
|
||||
# define range of events to fetch
|
||||
# delegates start with a start_id and initially won't find any events
|
||||
# otherwise we're restoring a previous session
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else self.event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
# sanity check
|
||||
if start_id > end_id + 1:
|
||||
self.log(
|
||||
'debug',
|
||||
f'start_id {start_id} is greater than end_id + 1 ({end_id + 1}). History will be empty.',
|
||||
)
|
||||
self.state.history = []
|
||||
return
|
||||
|
||||
# Get all events, filtering out backend events and hidden events
|
||||
events = list(
|
||||
self.event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter_out_type=self.filter_out,
|
||||
filter_hidden=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Find all delegate action/observation pairs
|
||||
delegate_ranges: list[tuple[int, int]] = []
|
||||
delegate_action_ids: list[int] = [] # stack of unmatched delegate action IDs
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, AgentDelegateAction):
|
||||
delegate_action_ids.append(event.id)
|
||||
# Note: we can get agent=event.agent and task=event.inputs.get('task','')
|
||||
# if we need to track these in the future
|
||||
|
||||
elif isinstance(event, AgentDelegateObservation):
|
||||
# Match with most recent unmatched delegate action
|
||||
if not delegate_action_ids:
|
||||
self.log(
|
||||
'error',
|
||||
f'Found AgentDelegateObservation without matching action at id={event.id}',
|
||||
)
|
||||
continue
|
||||
|
||||
action_id = delegate_action_ids.pop()
|
||||
delegate_ranges.append((action_id, event.id))
|
||||
|
||||
# Filter out events between delegate action/observation pairs
|
||||
if delegate_ranges:
|
||||
filtered_events: list[Event] = []
|
||||
current_idx = 0
|
||||
|
||||
for start_id, end_id in sorted(delegate_ranges):
|
||||
# Add events before delegate range
|
||||
filtered_events.extend(
|
||||
event for event in events[current_idx:] if event.id < start_id
|
||||
)
|
||||
|
||||
# Add delegate action and observation
|
||||
filtered_events.extend(
|
||||
event for event in events if event.id in (start_id, end_id)
|
||||
)
|
||||
|
||||
# Update index to after delegate range
|
||||
current_idx = next(
|
||||
(i for i, e in enumerate(events) if e.id > end_id), len(events)
|
||||
)
|
||||
|
||||
# Add any remaining events after last delegate range
|
||||
filtered_events.extend(events[current_idx:])
|
||||
|
||||
self.state.history = filtered_events
|
||||
else:
|
||||
self.state.history = events
|
||||
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
self.state.history.start_id = start_id
|
||||
|
||||
# if there was an end_id saved in State, set it in history
|
||||
# currently not used, later useful for delegates
|
||||
if self.state.end_id > -1:
|
||||
self.state.history.end_id = self.state.end_id
|
||||
|
||||
def _is_stuck(self):
|
||||
"""Checks if the agent or its delegate is stuck in a loop.
|
||||
|
||||
@ -11,9 +11,8 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import AgentFinishAction
|
||||
from openhands.events.observation import ErrorObservation
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
@ -78,7 +77,7 @@ class State:
|
||||
# max number of iterations for the current task
|
||||
max_iterations: int = 100
|
||||
confirmation_mode: bool = False
|
||||
history: ShortTermHistory = field(default_factory=ShortTermHistory)
|
||||
history: list[Event] = field(default_factory=list)
|
||||
inputs: dict = field(default_factory=dict)
|
||||
outputs: dict = field(default_factory=dict)
|
||||
agent_state: AgentState = AgentState.LOADING
|
||||
@ -94,6 +93,7 @@ class State:
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
almost_stuck: int = 0
|
||||
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
|
||||
# NOTE: This will never be used by the controller, but it can be used by different
|
||||
# evaluation tasks to store extra data needed to track the progress/state of the task.
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
@ -116,7 +116,7 @@ class State:
|
||||
pickled = base64.b64decode(encoded)
|
||||
state = pickle.loads(pickled)
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to restore state from session: {e}')
|
||||
logger.warning(f'Could not restore state from session: {e}')
|
||||
raise e
|
||||
|
||||
# update state
|
||||
@ -130,39 +130,40 @@ class State:
|
||||
return state
|
||||
|
||||
def __getstate__(self):
|
||||
# don't pickle history, it will be restored from the event stream
|
||||
state = self.__dict__.copy()
|
||||
|
||||
# save the relevant data from recent history
|
||||
# so that we can restore it when the state is restored
|
||||
if 'history' in state:
|
||||
state['start_id'] = state['history'].start_id
|
||||
state['end_id'] = state['history'].end_id
|
||||
|
||||
# don't save history object itself
|
||||
state.pop('history', None)
|
||||
state['history'] = []
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
|
||||
# recreate the history object
|
||||
# make sure we always have the attribute history
|
||||
if not hasattr(self, 'history'):
|
||||
self.history = ShortTermHistory()
|
||||
self.history = []
|
||||
|
||||
self.history.start_id = self.start_id
|
||||
self.history.end_id = self.end_id
|
||||
|
||||
|
||||
def get_current_user_intent(self):
|
||||
def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
|
||||
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
last_user_message = None
|
||||
last_user_message_image_urls: list[str] | None = []
|
||||
for event in self.history.get_events(reverse=True):
|
||||
for event in reversed(self.history):
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
last_user_message = event.content
|
||||
last_user_message_image_urls = event.images_urls
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
if last_user_message is not None:
|
||||
return last_user_message
|
||||
return last_user_message, None
|
||||
|
||||
return last_user_message, last_user_message_image_urls
|
||||
|
||||
def get_last_agent_message(self) -> str | None:
|
||||
for event in reversed(self.history):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
|
||||
return event.content
|
||||
return None
|
||||
|
||||
def get_last_user_message(self) -> str | None:
|
||||
for event in reversed(self.history):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
return event.content
|
||||
return None
|
||||
|
||||
@ -28,7 +28,7 @@ class StuckDetector:
|
||||
# filter out MessageAction with source='user' from history
|
||||
filtered_history = [
|
||||
event
|
||||
for event in self.state.history.get_events()
|
||||
for event in self.state.history
|
||||
if not (
|
||||
(isinstance(event, MessageAction) and event.source == EventSource.USER)
|
||||
or
|
||||
|
||||
@ -38,7 +38,6 @@ class AppConfig:
|
||||
e2b_api_key: The E2B API key.
|
||||
disable_color: Whether to disable color. For terminals that don't support color.
|
||||
debug: Whether to enable debugging.
|
||||
enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
|
||||
file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
|
||||
file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
|
||||
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
|
||||
@ -67,7 +66,6 @@ class AppConfig:
|
||||
disable_color: bool = False
|
||||
jwt_secret: str = uuid.uuid4().hex
|
||||
debug: bool = False
|
||||
enable_cli_session: bool = False
|
||||
file_uploads_max_file_size_mb: int = 0
|
||||
file_uploads_restrict_file_types: bool = False
|
||||
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
|
||||
|
||||
@ -125,16 +125,18 @@ async def run_controller(
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
# restore cli session if enabled
|
||||
|
||||
# restore cli session if available
|
||||
initial_state = None
|
||||
if config.enable_cli_session:
|
||||
try:
|
||||
logger.debug(f'Restoring agent state from cli session {event_stream.sid}')
|
||||
initial_state = State.restore_from_session(
|
||||
event_stream.sid, event_stream.file_store
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f'Error restoring state: {e}')
|
||||
try:
|
||||
logger.debug(
|
||||
f'Trying to restore agent state from cli session {event_stream.sid} if available'
|
||||
)
|
||||
initial_state = State.restore_from_session(
|
||||
event_stream.sid, event_stream.file_store
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f'Cannot restore agent state: {e}')
|
||||
|
||||
# init controller with this initial state
|
||||
controller = AgentController(
|
||||
@ -157,7 +159,7 @@ async def run_controller(
|
||||
)
|
||||
|
||||
# start event is a MessageAction with the task, either resumed or new
|
||||
if config.enable_cli_session and initial_state is not None:
|
||||
if initial_state is not None:
|
||||
# we're resuming the previous session
|
||||
event_stream.add_event(
|
||||
MessageAction(
|
||||
@ -168,7 +170,7 @@ async def run_controller(
|
||||
),
|
||||
EventSource.USER,
|
||||
)
|
||||
elif initial_state is None:
|
||||
else:
|
||||
# init with the provided actions
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
|
||||
@ -202,8 +204,9 @@ async def run_controller(
|
||||
logger.error(f'Exception in main loop: {e}')
|
||||
|
||||
# save session when we're about to close
|
||||
if config.enable_cli_session:
|
||||
if config.file_store is not None and config.file_store != 'memory':
|
||||
end_state = controller.get_state()
|
||||
# NOTE: the saved state does not include delegates events
|
||||
end_state.save_to_session(event_stream.sid, event_stream.file_store)
|
||||
|
||||
state = controller.get_state()
|
||||
@ -212,10 +215,7 @@ async def run_controller(
|
||||
if config.trajectories_path is not None:
|
||||
file_path = os.path.join(config.trajectories_path, sid + '.json')
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
histories = [
|
||||
event_to_trajectory(event)
|
||||
for event in state.history.get_events(include_delegates=True)
|
||||
]
|
||||
histories = [event_to_trajectory(event) for event in state.history]
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(histories, f)
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
@dataclass
|
||||
class MessageAction(Action):
|
||||
content: str
|
||||
images_urls: list | None = None
|
||||
images_urls: list[str] | None = None
|
||||
wait_for_response: bool = False
|
||||
action: str = ActionType.MESSAGE
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
|
||||
@ -83,12 +83,27 @@ class EventStream:
|
||||
|
||||
def get_events(
|
||||
self,
|
||||
start_id=0,
|
||||
end_id=None,
|
||||
reverse=False,
|
||||
start_id: int = 0,
|
||||
end_id: int | None = None,
|
||||
reverse: bool = False,
|
||||
filter_out_type: tuple[type[Event], ...] | None = None,
|
||||
filter_hidden=False,
|
||||
) -> Iterable[Event]:
|
||||
"""
|
||||
Retrieve events from the event stream, optionally filtering out events of a given type
|
||||
and events marked as hidden.
|
||||
|
||||
Args:
|
||||
start_id: The ID of the first event to retrieve. Defaults to 0.
|
||||
end_id: The ID of the last event to retrieve. Defaults to the last event in the stream.
|
||||
reverse: Whether to retrieve events in reverse order. Defaults to False.
|
||||
filter_out_type: A tuple of event types to filter out. Typically used to filter out backend events from the agent.
|
||||
filter_hidden: If True, filters out events with the 'hidden' attribute set to True.
|
||||
|
||||
Yields:
|
||||
Events from the stream that match the criteria.
|
||||
"""
|
||||
|
||||
def should_filter(event: Event):
|
||||
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
|
||||
return True
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from openhands.memory.condenser import MemoryCondenser
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
from openhands.memory.memory import LongTermMemory
|
||||
|
||||
__all__ = ['LongTermMemory', 'ShortTermHistory', 'MemoryCondenser']
|
||||
__all__ = ['LongTermMemory', 'MemoryCondenser']
|
||||
|
||||
@ -1,224 +0,0 @@
|
||||
from typing import ClassVar, Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.agent import (
|
||||
AgentDelegateAction,
|
||||
ChangeAgentStateAction,
|
||||
)
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.events.utils import get_pairs_from_events
|
||||
|
||||
|
||||
class ShortTermHistory(list[Event]):
|
||||
"""A list of events that represents the short-term memory of the agent.
|
||||
|
||||
This class provides methods to retrieve and filter the events in the history of the running agent from the event stream.
|
||||
"""
|
||||
|
||||
start_id: int
|
||||
end_id: int
|
||||
_event_stream: EventStream
|
||||
delegates: dict[tuple[int, int], tuple[str, str]]
|
||||
filter_out: ClassVar[tuple[type[Event], ...]] = (
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.start_id = -1
|
||||
self.end_id = -1
|
||||
self.delegates = {}
|
||||
|
||||
def set_event_stream(self, event_stream: EventStream):
|
||||
self._event_stream = event_stream
|
||||
|
||||
def get_events_as_list(self, include_delegates: bool = False) -> list[Event]:
|
||||
"""Return the history as a list of Event objects."""
|
||||
return list(self.get_events(include_delegates=include_delegates))
|
||||
|
||||
def get_events(
|
||||
self,
|
||||
reverse: bool = False,
|
||||
include_delegates: bool = False,
|
||||
include_hidden=False,
|
||||
) -> Iterable[Event]:
|
||||
"""Return the events as a stream of Event objects."""
|
||||
# TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
|
||||
# or even if it is, because currently we don't add it to the summary
|
||||
|
||||
# iterate from start_id to end_id, or reverse
|
||||
start_id = self.start_id if self.start_id != -1 else 0
|
||||
end_id = (
|
||||
self.end_id
|
||||
if self.end_id != -1
|
||||
else self._event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
for event in self._event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=reverse,
|
||||
filter_out_type=self.filter_out,
|
||||
):
|
||||
if not include_hidden and hasattr(event, 'hidden') and event.hidden:
|
||||
continue
|
||||
# TODO add summaries
|
||||
# and filter out events that were included in a summary
|
||||
|
||||
# filter out the events from a delegate of the current agent
|
||||
if not include_delegates and not any(
|
||||
# except for the delegate action and observation themselves, currently
|
||||
# AgentDelegateAction has id = delegate_start
|
||||
# AgentDelegateObservation has id = delegate_end
|
||||
delegate_start < event.id < delegate_end
|
||||
for delegate_start, delegate_end in self.delegates.keys()
|
||||
):
|
||||
yield event
|
||||
elif include_delegates:
|
||||
yield event
|
||||
|
||||
def get_last_action(self, end_id: int = -1) -> Action | None:
|
||||
"""Return the last action from the event stream, filtered to exclude unwanted events."""
|
||||
# from end_id in reverse, find the first action
|
||||
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
|
||||
|
||||
last_action = next(
|
||||
(
|
||||
event
|
||||
for event in self._event_stream.get_events(
|
||||
end_id=end_id, reverse=True, filter_out_type=self.filter_out
|
||||
)
|
||||
if isinstance(event, Action)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_action
|
||||
|
||||
def get_last_observation(self, end_id: int = -1) -> Observation | None:
|
||||
"""Return the last observation from the event stream, filtered to exclude unwanted events."""
|
||||
# from end_id in reverse, find the first observation
|
||||
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
|
||||
|
||||
last_observation = next(
|
||||
(
|
||||
event
|
||||
for event in self._event_stream.get_events(
|
||||
end_id=end_id, reverse=True, filter_out_type=self.filter_out
|
||||
)
|
||||
if isinstance(event, Observation)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_observation
|
||||
|
||||
def get_last_user_message(self) -> str:
|
||||
"""Return the content of the last user message from the event stream."""
|
||||
last_user_message = next(
|
||||
(
|
||||
event.content
|
||||
for event in self._event_stream.get_events(reverse=True)
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_user_message if last_user_message is not None else ''
|
||||
|
||||
def get_last_agent_message(self) -> str:
|
||||
"""Return the content of the last agent message from the event stream."""
|
||||
last_agent_message = next(
|
||||
(
|
||||
event.content
|
||||
for event in self._event_stream.get_events(reverse=True)
|
||||
if isinstance(event, MessageAction)
|
||||
and event.source == EventSource.AGENT
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_agent_message if last_agent_message is not None else ''
|
||||
|
||||
def get_last_events(self, n: int) -> list[Event]:
|
||||
"""Return the last n events from the event stream."""
|
||||
# dummy agent is using this
|
||||
# it should work, but it's not great to store temporary lists now just for a test
|
||||
end_id = self._event_stream.get_latest_event_id()
|
||||
start_id = max(0, end_id - n + 1)
|
||||
|
||||
return list(
|
||||
event
|
||||
for event in self._event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
filter_out_type=self.filter_out,
|
||||
)
|
||||
)
|
||||
|
||||
def has_delegation(self) -> bool:
|
||||
for event in self._event_stream.get_events():
|
||||
if isinstance(event, AgentDelegateObservation):
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_event(self, event: Event):
|
||||
if not isinstance(event, AgentDelegateObservation):
|
||||
return
|
||||
|
||||
logger.debug('AgentDelegateObservation received')
|
||||
|
||||
# figure out what this delegate's actions were
|
||||
# from the last AgentDelegateAction to this AgentDelegateObservation
|
||||
# and save their ids as start and end ids
|
||||
# in order to use later to exclude them from parent stream
|
||||
# or summarize them
|
||||
delegate_end = event.id
|
||||
delegate_start = -1
|
||||
delegate_agent: str = ''
|
||||
delegate_task: str = ''
|
||||
for prev_event in self._event_stream.get_events(
|
||||
end_id=event.id - 1, reverse=True
|
||||
):
|
||||
if isinstance(prev_event, AgentDelegateAction):
|
||||
delegate_start = prev_event.id
|
||||
delegate_agent = prev_event.agent
|
||||
delegate_task = prev_event.inputs.get('task', '')
|
||||
break
|
||||
|
||||
if delegate_start == -1:
|
||||
logger.error(
|
||||
f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
|
||||
)
|
||||
return
|
||||
|
||||
self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
|
||||
logger.debug(
|
||||
f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
|
||||
)
|
||||
|
||||
# TODO remove me when unnecessary
|
||||
# history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
|
||||
# we rebuild the pairs here
|
||||
# for compatibility with the existing output format in evaluations
|
||||
def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
|
||||
history_pairs = []
|
||||
|
||||
for action, observation in get_pairs_from_events(
|
||||
self.get_events_as_list(include_delegates=True)
|
||||
):
|
||||
history_pairs.append((event_to_dict(action), event_to_dict(observation)))
|
||||
|
||||
return history_pairs
|
||||
@ -213,7 +213,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
if isinstance(obs, ErrorObservation):
|
||||
return obs
|
||||
if not isinstance(obs, FileWriteObservation):
|
||||
raise ValueError(f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}')
|
||||
raise ValueError(
|
||||
f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
|
||||
)
|
||||
return FileEditObservation(
|
||||
content=get_diff('', action.content, action.path),
|
||||
path=action.path,
|
||||
@ -222,7 +224,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
new_content=action.content,
|
||||
)
|
||||
if not isinstance(obs, FileReadObservation):
|
||||
raise ValueError(f'Expected FileReadObservation, got {type(obs)}: {str(obs)}')
|
||||
raise ValueError(
|
||||
f'Expected FileReadObservation, got {type(obs)}: {str(obs)}'
|
||||
)
|
||||
|
||||
original_file_content = obs.content
|
||||
old_file_lines = original_file_content.split('\n')
|
||||
|
||||
@ -181,7 +181,7 @@ def process_instance(
|
||||
test_result = {}
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
histories = [event_to_dict(event) for event in state.history.get_events()]
|
||||
histories = [event_to_dict(event) for event in state.history]
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
|
||||
# Save the output
|
||||
|
||||
@ -104,5 +104,5 @@ def test_error_observation_message(agent: CodeActAgent):
|
||||
def test_unknown_observation_message(agent: CodeActAgent):
|
||||
obs = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown observation type:'):
|
||||
with pytest.raises(ValueError, match='Unknown observation type'):
|
||||
agent.get_observation_message(obs, tool_call_id_to_message={})
|
||||
|
||||
@ -17,8 +17,6 @@ from openhands.events.observation.commands import IPythonRunCellObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.stream import EventSource, EventStream
|
||||
from openhands.events.utils import get_pairs_from_events
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@ -55,22 +53,21 @@ def event_stream(temp_dir):
|
||||
|
||||
class TestStuckDetector:
|
||||
@pytest.fixture
|
||||
def stuck_detector(self, event_stream):
|
||||
def stuck_detector(self):
|
||||
state = State(inputs={}, max_iterations=50)
|
||||
state.history.set_event_stream(event_stream)
|
||||
|
||||
state.history = [] # Initialize history as an empty list
|
||||
return StuckDetector(state)
|
||||
|
||||
def _impl_syntax_error_events(
|
||||
self,
|
||||
event_stream: EventStream,
|
||||
state: State,
|
||||
error_message: str,
|
||||
random_line: bool,
|
||||
incidents: int = 4,
|
||||
):
|
||||
for i in range(incidents):
|
||||
ipython_action = IPythonRunCellAction(code=code_snippet)
|
||||
event_stream.add_event(ipython_action, EventSource.AGENT)
|
||||
state.history.append(ipython_action)
|
||||
extra_number = (i + 1) * 10 if random_line else '42'
|
||||
extra_line = '\n' * (i + 1) if random_line else ''
|
||||
ipython_observation = IPythonRunCellObservation(
|
||||
@ -79,15 +76,15 @@ class TestStuckDetector:
|
||||
f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
|
||||
code=code_snippet,
|
||||
)
|
||||
ipython_observation._cause = ipython_action._id
|
||||
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
|
||||
# ipython_observation._cause = ipython_action._id
|
||||
state.history.append(ipython_observation)
|
||||
|
||||
def _impl_unterminated_string_error_events(
|
||||
self, event_stream: EventStream, random_line: bool, incidents: int = 4
|
||||
self, state: State, random_line: bool, incidents: int = 4
|
||||
):
|
||||
for i in range(incidents):
|
||||
ipython_action = IPythonRunCellAction(code=code_snippet)
|
||||
event_stream.add_event(ipython_action, EventSource.AGENT)
|
||||
state.history.append(ipython_action)
|
||||
line_number = (i + 1) * 10 if random_line else '1'
|
||||
ipython_observation = IPythonRunCellObservation(
|
||||
content=f'print(" Cell In[1], line {line_number}\nhello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
|
||||
@ -95,34 +92,30 @@ class TestStuckDetector:
|
||||
+ jupyter_line_2,
|
||||
code=code_snippet,
|
||||
)
|
||||
ipython_observation._cause = ipython_action._id
|
||||
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
|
||||
# ipython_observation._cause = ipython_action._
|
||||
state.history.append(ipython_observation)
|
||||
|
||||
def test_history_too_short(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
):
|
||||
def test_history_too_short(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
message_action = MessageAction(content='Hello', wait_for_response=False)
|
||||
message_action._source = EventSource.USER
|
||||
observation = NullObservation(content='')
|
||||
observation._cause = message_action.id
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(observation, EventSource.ENVIRONMENT)
|
||||
# observation._cause = message_action.id
|
||||
state.history.append(message_action)
|
||||
state.history.append(observation)
|
||||
|
||||
cmd_action = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action, EventSource.AGENT)
|
||||
state.history.append(cmd_action)
|
||||
cmd_observation = CmdOutputObservation(
|
||||
command_id=1, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation._cause = cmd_action._id
|
||||
event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
# cmd_observation._cause = cmd_action._id
|
||||
state.history.append(cmd_observation)
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_stuck_repeating_action_observation(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
):
|
||||
def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
message_action = MessageAction(content='Done', wait_for_response=False)
|
||||
message_action._source = EventSource.USER
|
||||
|
||||
@ -130,135 +123,125 @@ class TestStuckDetector:
|
||||
hello_observation = NullObservation('')
|
||||
|
||||
# 2 events
|
||||
event_stream.add_event(hello_action, EventSource.USER)
|
||||
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
|
||||
state.history.append(hello_action)
|
||||
state.history.append(hello_observation)
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
cmd_observation_1 = CmdOutputObservation(
|
||||
content='', command='ls', command_id=cmd_action_1._id
|
||||
)
|
||||
cmd_action_1._id = 1
|
||||
state.history.append(cmd_action_1)
|
||||
cmd_observation_1 = CmdOutputObservation(content='', command='ls', command_id=1)
|
||||
cmd_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
state.history.append(cmd_observation_1)
|
||||
# 4 events
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
cmd_observation_2 = CmdOutputObservation(
|
||||
content='', command='ls', command_id=cmd_action_2._id
|
||||
)
|
||||
cmd_action_2._id = 2
|
||||
state.history.append(cmd_action_2)
|
||||
cmd_observation_2 = CmdOutputObservation(content='', command='ls', command_id=2)
|
||||
cmd_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
state.history.append(cmd_observation_2)
|
||||
# 6 events
|
||||
|
||||
# random user message just because we can
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
state.history.append(message_action)
|
||||
state.history.append(message_null_observation)
|
||||
# 8 events
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
assert stuck_detector.state.almost_stuck == 2
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
cmd_observation_3 = CmdOutputObservation(
|
||||
content='', command='ls', command_id=cmd_action_3._id
|
||||
)
|
||||
cmd_action_3._id = 3
|
||||
state.history.append(cmd_action_3)
|
||||
cmd_observation_3 = CmdOutputObservation(content='', command='ls', command_id=3)
|
||||
cmd_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
|
||||
state.history.append(cmd_observation_3)
|
||||
# 10 events
|
||||
|
||||
assert len(collect_events(event_stream)) == 10
|
||||
assert len(list(stuck_detector.state.history.get_events())) == 8
|
||||
assert len(state.history) == 10
|
||||
assert (
|
||||
len(
|
||||
get_pairs_from_events(
|
||||
stuck_detector.state.history.get_events_as_list(
|
||||
include_delegates=True
|
||||
)
|
||||
)
|
||||
)
|
||||
== 5
|
||||
)
|
||||
len(state.history) == 10
|
||||
) # Adjusted since history is a list and the controller is not running
|
||||
|
||||
# FIXME are we still testing this without this test?
|
||||
# assert (
|
||||
# len(
|
||||
# get_pairs_from_events(state.history)
|
||||
# )
|
||||
# == 5
|
||||
# )
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
assert stuck_detector.state.almost_stuck == 1
|
||||
|
||||
cmd_action_4 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_4, EventSource.AGENT)
|
||||
cmd_observation_4 = CmdOutputObservation(
|
||||
content='', command='ls', command_id=cmd_action_4._id
|
||||
)
|
||||
cmd_action_4._id = 4
|
||||
state.history.append(cmd_action_4)
|
||||
cmd_observation_4 = CmdOutputObservation(content='', command='ls', command_id=4)
|
||||
cmd_observation_4._cause = cmd_action_4._id
|
||||
event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT)
|
||||
state.history.append(cmd_observation_4)
|
||||
# 12 events
|
||||
|
||||
assert len(collect_events(event_stream)) == 12
|
||||
assert len(list(stuck_detector.state.history.get_events())) == 10
|
||||
assert (
|
||||
len(
|
||||
get_pairs_from_events(
|
||||
stuck_detector.state.history.get_events_as_list(
|
||||
include_delegates=True
|
||||
)
|
||||
)
|
||||
)
|
||||
== 6
|
||||
)
|
||||
assert len(state.history) == 12
|
||||
# assert (
|
||||
# len(
|
||||
# get_pairs_from_events(state.history)
|
||||
# )
|
||||
# == 6
|
||||
# )
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
assert stuck_detector.state.almost_stuck == 0
|
||||
mock_warning.assert_called_once_with('Action, Observation loop detected')
|
||||
|
||||
def test_is_stuck_repeating_action_error(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
):
|
||||
def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
# (action, error_observation), not necessarily the same error
|
||||
message_action = MessageAction(content='Done', wait_for_response=False)
|
||||
message_action._source = EventSource.USER
|
||||
|
||||
hello_action = MessageAction(content='Hello', wait_for_response=False)
|
||||
hello_observation = NullObservation(content='')
|
||||
event_stream.add_event(hello_action, EventSource.USER)
|
||||
hello_observation._cause = hello_action._id
|
||||
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
|
||||
state.history.append(hello_action)
|
||||
# hello_observation._cause = hello_action._id
|
||||
state.history.append(hello_observation)
|
||||
# 2 events
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='invalid_command')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
state.history.append(cmd_action_1)
|
||||
error_observation_1 = ErrorObservation(content='Command not found')
|
||||
error_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT)
|
||||
# error_observation_1._cause = cmd_action_1._id
|
||||
state.history.append(error_observation_1)
|
||||
# 4 events
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='invalid_command')
|
||||
event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
state.history.append(cmd_action_2)
|
||||
error_observation_2 = ErrorObservation(
|
||||
content='Command still not found or another error'
|
||||
)
|
||||
error_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT)
|
||||
# error_observation_2._cause = cmd_action_2._id
|
||||
state.history.append(error_observation_2)
|
||||
# 6 events
|
||||
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
state.history.append(message_action)
|
||||
state.history.append(message_null_observation)
|
||||
# 8 events
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='invalid_command')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
state.history.append(cmd_action_3)
|
||||
error_observation_3 = ErrorObservation(content='Different error')
|
||||
error_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT)
|
||||
# error_observation_3._cause = cmd_action_3._id
|
||||
state.history.append(error_observation_3)
|
||||
# 10 events
|
||||
|
||||
cmd_action_4 = CmdRunAction(command='invalid_command')
|
||||
event_stream.add_event(cmd_action_4, EventSource.AGENT)
|
||||
state.history.append(cmd_action_4)
|
||||
error_observation_4 = ErrorObservation(content='Command not found')
|
||||
error_observation_4._cause = cmd_action_4._id
|
||||
event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT)
|
||||
# error_observation_4._cause = cmd_action_4._id
|
||||
state.history.append(error_observation_4)
|
||||
# 12 events
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
@ -267,11 +250,10 @@ class TestStuckDetector:
|
||||
'Action, ErrorObservation loop detected'
|
||||
)
|
||||
|
||||
def test_is_stuck_invalid_syntax_error(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
):
|
||||
def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
self._impl_syntax_error_events(
|
||||
event_stream,
|
||||
state,
|
||||
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
||||
random_line=False,
|
||||
)
|
||||
@ -280,10 +262,11 @@ class TestStuckDetector:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
|
||||
def test_is_not_stuck_invalid_syntax_error_random_lines(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
state = stuck_detector.state
|
||||
self._impl_syntax_error_events(
|
||||
event_stream,
|
||||
state,
|
||||
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
||||
random_line=True,
|
||||
)
|
||||
@ -292,10 +275,11 @@ class TestStuckDetector:
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
state = stuck_detector.state
|
||||
self._impl_syntax_error_events(
|
||||
event_stream,
|
||||
state,
|
||||
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
||||
random_line=True,
|
||||
incidents=3,
|
||||
@ -304,11 +288,10 @@ class TestStuckDetector:
|
||||
with patch('logging.Logger.warning'):
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_stuck_incomplete_input_error(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
):
|
||||
def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
self._impl_syntax_error_events(
|
||||
event_stream,
|
||||
state,
|
||||
error_message='SyntaxError: incomplete input',
|
||||
random_line=False,
|
||||
)
|
||||
@ -316,11 +299,10 @@ class TestStuckDetector:
|
||||
with patch('logging.Logger.warning'):
|
||||
assert stuck_detector.is_stuck() is True
|
||||
|
||||
def test_is_not_stuck_incomplete_input_error(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
):
|
||||
def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
self._impl_syntax_error_events(
|
||||
event_stream,
|
||||
state,
|
||||
error_message='SyntaxError: incomplete input',
|
||||
random_line=True,
|
||||
)
|
||||
@ -329,239 +311,241 @@ class TestStuckDetector:
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
self._impl_unterminated_string_error_events(event_stream, random_line=True)
|
||||
state = stuck_detector.state
|
||||
self._impl_unterminated_string_error_events(state, random_line=True)
|
||||
|
||||
with patch('logging.Logger.warning'):
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
state = stuck_detector.state
|
||||
self._impl_unterminated_string_error_events(
|
||||
event_stream, random_line=False, incidents=3
|
||||
state, random_line=False, incidents=3
|
||||
)
|
||||
|
||||
with patch('logging.Logger.warning'):
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_stuck_ipython_unterminated_string_error(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
self._impl_unterminated_string_error_events(event_stream, random_line=False)
|
||||
state = stuck_detector.state
|
||||
self._impl_unterminated_string_error_events(state, random_line=False)
|
||||
|
||||
with patch('logging.Logger.warning'):
|
||||
assert stuck_detector.is_stuck() is True
|
||||
|
||||
def test_is_not_stuck_ipython_syntax_error_not_at_end(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
state = stuck_detector.state
|
||||
# this test is to make sure we don't get false positives
|
||||
# since the "at line x" is changing in between!
|
||||
ipython_action_1 = IPythonRunCellAction(code='print("hello')
|
||||
event_stream.add_event(ipython_action_1, EventSource.AGENT)
|
||||
state.history.append(ipython_action_1)
|
||||
ipython_observation_1 = IPythonRunCellObservation(
|
||||
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nThis is some additional output',
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_1._cause = ipython_action_1._id
|
||||
event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT)
|
||||
# ipython_observation_1._cause = ipython_action_1._id
|
||||
state.history.append(ipython_observation_1)
|
||||
|
||||
ipython_action_2 = IPythonRunCellAction(code='print("hello')
|
||||
event_stream.add_event(ipython_action_2, EventSource.AGENT)
|
||||
state.history.append(ipython_action_2)
|
||||
ipython_observation_2 = IPythonRunCellObservation(
|
||||
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nToo much output here on and on',
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_2._cause = ipython_action_2._id
|
||||
event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT)
|
||||
# ipython_observation_2._cause = ipython_action_2._id
|
||||
state.history.append(ipython_observation_2)
|
||||
|
||||
ipython_action_3 = IPythonRunCellAction(code='print("hello')
|
||||
event_stream.add_event(ipython_action_3, EventSource.AGENT)
|
||||
state.history.append(ipython_action_3)
|
||||
ipython_observation_3 = IPythonRunCellObservation(
|
||||
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)\nEnough',
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_3._cause = ipython_action_3._id
|
||||
event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT)
|
||||
# ipython_observation_3._cause = ipython_action_3._id
|
||||
state.history.append(ipython_observation_3)
|
||||
|
||||
ipython_action_4 = IPythonRunCellAction(code='print("hello')
|
||||
event_stream.add_event(ipython_action_4, EventSource.AGENT)
|
||||
state.history.append(ipython_action_4)
|
||||
ipython_observation_4 = IPythonRunCellObservation(
|
||||
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)\nLast line of output',
|
||||
code='print("hello',
|
||||
)
|
||||
ipython_observation_4._cause = ipython_action_4._id
|
||||
event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT)
|
||||
# ipython_observation_4._cause = ipython_action_4._id
|
||||
state.history.append(ipython_observation_4)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is False
|
||||
mock_warning.assert_not_called()
|
||||
|
||||
def test_is_stuck_repeating_action_observation_pattern(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
state = stuck_detector.state
|
||||
message_action = MessageAction(content='Come on', wait_for_response=False)
|
||||
message_action._source = EventSource.USER
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
state.history.append(message_action)
|
||||
message_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_observation, EventSource.ENVIRONMENT)
|
||||
state.history.append(message_observation)
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
state.history.append(cmd_action_1)
|
||||
cmd_observation_1 = CmdOutputObservation(
|
||||
command_id=1, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
# cmd_observation_1._cause = cmd_action_1._id
|
||||
state.history.append(cmd_observation_1)
|
||||
|
||||
read_action_1 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_1, EventSource.AGENT)
|
||||
state.history.append(read_action_1)
|
||||
read_observation_1 = FileReadObservation(
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_1._cause = read_action_1._id
|
||||
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
|
||||
# read_observation_1._cause = read_action_1._id
|
||||
state.history.append(read_observation_1)
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
state.history.append(cmd_action_2)
|
||||
cmd_observation_2 = CmdOutputObservation(
|
||||
command_id=2, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
# cmd_observation_2._cause = cmd_action_2._id
|
||||
state.history.append(cmd_observation_2)
|
||||
|
||||
read_action_2 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_2, EventSource.AGENT)
|
||||
state.history.append(read_action_2)
|
||||
read_observation_2 = FileReadObservation(
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_2._cause = read_action_2._id
|
||||
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
|
||||
# read_observation_2._cause = read_action_2._id
|
||||
state.history.append(read_observation_2)
|
||||
|
||||
message_action = MessageAction(content='Come on', wait_for_response=False)
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
message_action._source = EventSource.USER
|
||||
state.history.append(message_action)
|
||||
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
state.history.append(message_null_observation)
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
state.history.append(cmd_action_3)
|
||||
cmd_observation_3 = CmdOutputObservation(
|
||||
command_id=3, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
|
||||
# cmd_observation_3._cause = cmd_action_3._id
|
||||
state.history.append(cmd_observation_3)
|
||||
|
||||
read_action_3 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_3, EventSource.AGENT)
|
||||
state.history.append(read_action_3)
|
||||
read_observation_3 = FileReadObservation(
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_3._cause = read_action_3._id
|
||||
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
|
||||
# read_observation_3._cause = read_action_3._id
|
||||
state.history.append(read_observation_3)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
mock_warning.assert_called_once_with('Action, Observation pattern detected')
|
||||
|
||||
def test_is_stuck_not_stuck(
|
||||
self, stuck_detector: StuckDetector, event_stream: EventStream
|
||||
):
|
||||
def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
message_action = MessageAction(content='Done', wait_for_response=False)
|
||||
message_action._source = EventSource.USER
|
||||
|
||||
hello_action = MessageAction(content='Hello', wait_for_response=False)
|
||||
event_stream.add_event(hello_action, EventSource.USER)
|
||||
state.history.append(hello_action)
|
||||
hello_observation = NullObservation(content='')
|
||||
hello_observation._cause = hello_action._id
|
||||
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
|
||||
# hello_observation._cause = hello_action._id
|
||||
state.history.append(hello_observation)
|
||||
|
||||
cmd_action_1 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
state.history.append(cmd_action_1)
|
||||
cmd_observation_1 = CmdOutputObservation(
|
||||
command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
cmd_observation_1._cause = cmd_action_1._id
|
||||
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
# cmd_observation_1._cause = cmd_action_1._id
|
||||
state.history.append(cmd_observation_1)
|
||||
|
||||
read_action_1 = FileReadAction(path='file1.txt')
|
||||
event_stream.add_event(read_action_1, EventSource.AGENT)
|
||||
state.history.append(read_action_1)
|
||||
read_observation_1 = FileReadObservation(
|
||||
content='File content', path='file1.txt'
|
||||
)
|
||||
read_observation_1._cause = read_action_1._id
|
||||
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
|
||||
# read_observation_1._cause = read_action_1._id
|
||||
state.history.append(read_observation_1)
|
||||
|
||||
cmd_action_2 = CmdRunAction(command='pwd')
|
||||
event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
state.history.append(cmd_action_2)
|
||||
cmd_observation_2 = CmdOutputObservation(
|
||||
command_id=2, command='pwd', content='/home/user'
|
||||
)
|
||||
cmd_observation_2._cause = cmd_action_2._id
|
||||
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
# cmd_observation_2._cause = cmd_action_2._id
|
||||
state.history.append(cmd_observation_2)
|
||||
|
||||
read_action_2 = FileReadAction(path='file2.txt')
|
||||
event_stream.add_event(read_action_2, EventSource.AGENT)
|
||||
state.history.append(read_action_2)
|
||||
read_observation_2 = FileReadObservation(
|
||||
content='Another file content', path='file2.txt'
|
||||
)
|
||||
read_observation_2._cause = read_action_2._id
|
||||
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
|
||||
# read_observation_2._cause = read_action_2._id
|
||||
state.history.append(read_observation_2)
|
||||
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
state.history.append(message_action)
|
||||
state.history.append(message_null_observation)
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='pwd')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
state.history.append(cmd_action_3)
|
||||
cmd_observation_3 = CmdOutputObservation(
|
||||
command_id=cmd_action_3.id, command='pwd', content='/home/user'
|
||||
)
|
||||
cmd_observation_3._cause = cmd_action_3._id
|
||||
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
|
||||
# cmd_observation_3._cause = cmd_action_3._id
|
||||
state.history.append(cmd_observation_3)
|
||||
|
||||
read_action_3 = FileReadAction(path='file2.txt')
|
||||
event_stream.add_event(read_action_3, EventSource.AGENT)
|
||||
state.history.append(read_action_3)
|
||||
read_observation_3 = FileReadObservation(
|
||||
content='Another file content', path='file2.txt'
|
||||
)
|
||||
read_observation_3._cause = read_action_3._id
|
||||
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
|
||||
# read_observation_3._cause = read_action_3._id
|
||||
state.history.append(read_observation_3)
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_stuck_monologue(self, stuck_detector, event_stream):
|
||||
# Add events to the event stream
|
||||
def test_is_stuck_monologue(self, stuck_detector):
|
||||
state = stuck_detector.state
|
||||
# Add events to the history list directly
|
||||
message_action_1 = MessageAction(content='Hi there!')
|
||||
event_stream.add_event(message_action_1, EventSource.USER)
|
||||
message_action_1._source = EventSource.USER
|
||||
|
||||
state.history.append(message_action_1)
|
||||
message_action_2 = MessageAction(content='Hi there!')
|
||||
event_stream.add_event(message_action_2, EventSource.AGENT)
|
||||
message_action_2._source = EventSource.AGENT
|
||||
|
||||
state.history.append(message_action_2)
|
||||
message_action_3 = MessageAction(content='How are you?')
|
||||
event_stream.add_event(message_action_3, EventSource.USER)
|
||||
message_action_3._source = EventSource.USER
|
||||
state.history.append(message_action_3)
|
||||
|
||||
cmd_kill_action = CmdRunAction(
|
||||
command='echo 42', thought="I'm not stuck, he's stuck"
|
||||
)
|
||||
event_stream.add_event(cmd_kill_action, EventSource.AGENT)
|
||||
state.history.append(cmd_kill_action)
|
||||
|
||||
message_action_4 = MessageAction(content="I'm doing well, thanks for asking.")
|
||||
event_stream.add_event(message_action_4, EventSource.AGENT)
|
||||
message_action_4._source = EventSource.AGENT
|
||||
|
||||
state.history.append(message_action_4)
|
||||
message_action_5 = MessageAction(content="I'm doing well, thanks for asking.")
|
||||
event_stream.add_event(message_action_5, EventSource.AGENT)
|
||||
message_action_5._source = EventSource.AGENT
|
||||
|
||||
state.history.append(message_action_5)
|
||||
message_action_6 = MessageAction(content="I'm doing well, thanks for asking.")
|
||||
event_stream.add_event(message_action_6, EventSource.AGENT)
|
||||
message_action_6._source = EventSource.AGENT
|
||||
state.history.append(message_action_6)
|
||||
|
||||
assert stuck_detector.is_stuck()
|
||||
|
||||
@ -572,16 +556,15 @@ class TestStuckDetector:
|
||||
command='storybook',
|
||||
exit_code=0,
|
||||
)
|
||||
cmd_output_observation._cause = cmd_kill_action._id
|
||||
event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT)
|
||||
# cmd_output_observation._cause = cmd_kill_action._id
|
||||
state.history.append(cmd_output_observation)
|
||||
|
||||
message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
|
||||
event_stream.add_event(message_action_7, EventSource.AGENT)
|
||||
message_action_7._source = EventSource.AGENT
|
||||
|
||||
state.history.append(message_action_7)
|
||||
message_action_8 = MessageAction(content="I'm doing well, thanks for asking.")
|
||||
event_stream.add_event(message_action_8, EventSource.AGENT)
|
||||
message_action_8._source = EventSource.AGENT
|
||||
state.history.append(message_action_8)
|
||||
|
||||
with patch('logging.Logger.warning'):
|
||||
assert not stuck_detector.is_stuck()
|
||||
@ -596,7 +579,6 @@ class TestAgentController:
|
||||
)
|
||||
controller.delegate = None
|
||||
controller.state = Mock()
|
||||
controller.state.history = ShortTermHistory()
|
||||
return controller
|
||||
|
||||
def test_is_stuck_delegate_stuck(self, controller: AgentController):
|
||||
|
||||
@ -10,10 +10,8 @@ from openhands.agenthub.micro.registry import all_microagents
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@ -74,10 +72,10 @@ def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict
|
||||
)
|
||||
assert coder_agent is not None
|
||||
|
||||
# give it some history
|
||||
task = 'This is a dummy task'
|
||||
history = ShortTermHistory()
|
||||
history.set_event_stream(event_stream)
|
||||
event_stream.add_event(MessageAction(content=task), EventSource.USER)
|
||||
history = list()
|
||||
history.append(MessageAction(content=task))
|
||||
|
||||
summary = 'This is a dummy summary about this repo'
|
||||
state = State(history=history, inputs={'summary': summary})
|
||||
@ -119,10 +117,10 @@ def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: d
|
||||
)
|
||||
assert coder_agent is not None
|
||||
|
||||
# give it some history
|
||||
task = 'This is a dummy task'
|
||||
history = ShortTermHistory()
|
||||
history.set_event_stream(event_stream)
|
||||
event_stream.add_event(MessageAction(content=task), EventSource.USER)
|
||||
history = list()
|
||||
history.append(MessageAction(content=task))
|
||||
|
||||
# set state without codebase summary
|
||||
state = State(history=history)
|
||||
|
||||
@ -1,14 +1,12 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -19,12 +17,6 @@ def mock_llm():
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream(tmp_path):
|
||||
file_store = get_file_store('local', str(tmp_path))
|
||||
return EventStream('test_session', file_store)
|
||||
|
||||
|
||||
@pytest.fixture(params=[False, True])
|
||||
def codeact_agent(mock_llm, request):
|
||||
config = AgentConfig()
|
||||
@ -57,17 +49,28 @@ def response_mock(content: str):
|
||||
return MockModelResponse(content)
|
||||
|
||||
|
||||
def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
|
||||
# Add some events to the stream
|
||||
mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
|
||||
mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
|
||||
mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
|
||||
mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
|
||||
mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)
|
||||
def test_get_messages_with_reminder(codeact_agent: CodeActAgent):
|
||||
# Add some events to history
|
||||
history = list()
|
||||
message_action_1 = MessageAction('Initial user message')
|
||||
message_action_1._source = 'user'
|
||||
history.append(message_action_1)
|
||||
message_action_2 = MessageAction('Sure!')
|
||||
message_action_2._source = 'assistant'
|
||||
history.append(message_action_2)
|
||||
message_action_3 = MessageAction('Hello, agent!')
|
||||
message_action_3._source = 'user'
|
||||
history.append(message_action_3)
|
||||
message_action_4 = MessageAction('Hello, user!')
|
||||
message_action_4._source = 'assistant'
|
||||
history.append(message_action_4)
|
||||
message_action_5 = MessageAction('Laaaaaaaast!')
|
||||
message_action_5._source = 'user'
|
||||
history.append(message_action_5)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(
|
||||
Mock(history=mock_event_stream, max_iterations=5, iteration=0)
|
||||
Mock(history=history, max_iterations=5, iteration=0)
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -102,19 +105,20 @@ def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
|
||||
)
|
||||
|
||||
|
||||
def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
|
||||
def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
history = list()
|
||||
# Add multiple user and agent messages
|
||||
for i in range(15):
|
||||
mock_event_stream.add_event(
|
||||
MessageAction(f'User message {i}'), EventSource.USER
|
||||
)
|
||||
mock_event_stream.add_event(
|
||||
MessageAction(f'Agent message {i}'), EventSource.AGENT
|
||||
)
|
||||
message_action_user = MessageAction(f'User message {i}')
|
||||
message_action_user._source = 'user'
|
||||
history.append(message_action_user)
|
||||
message_action_agent = MessageAction(f'Agent message {i}')
|
||||
message_action_agent._source = 'assistant'
|
||||
history.append(message_action_agent)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(
|
||||
Mock(history=mock_event_stream, max_iterations=10, iteration=5)
|
||||
Mock(history=history, max_iterations=10, iteration=5)
|
||||
)
|
||||
|
||||
# Check that only the last two user messages have cache_prompt=True
|
||||
@ -136,18 +140,23 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
|
||||
assert cached_user_messages[3].content[0].text.startswith('User message 1')
|
||||
|
||||
|
||||
def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
def test_get_messages_with_cmd_action(codeact_agent: CodeActAgent):
|
||||
if codeact_agent.config.function_calling:
|
||||
pytest.skip('Skipping this test for function calling')
|
||||
|
||||
history = list()
|
||||
|
||||
# Add a mix of actions and observations
|
||||
message_action_1 = MessageAction(
|
||||
"Let's list the contents of the current directory."
|
||||
)
|
||||
mock_event_stream.add_event(message_action_1, EventSource.USER)
|
||||
message_action_1._source = 'user'
|
||||
history.append(message_action_1)
|
||||
|
||||
cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
|
||||
mock_event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
cmd_action_1._source = 'agent'
|
||||
cmd_action_1._id = 'cmd_1'
|
||||
history.append(cmd_action_1)
|
||||
|
||||
cmd_observation_1 = CmdOutputObservation(
|
||||
content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
|
||||
@ -155,13 +164,17 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
command='ls -l',
|
||||
exit_code=0,
|
||||
)
|
||||
mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
|
||||
cmd_observation_1._source = 'user'
|
||||
history.append(cmd_observation_1)
|
||||
|
||||
message_action_2 = MessageAction("Now, let's create a new directory.")
|
||||
mock_event_stream.add_event(message_action_2, EventSource.AGENT)
|
||||
message_action_2._source = 'agent'
|
||||
history.append(message_action_2)
|
||||
|
||||
cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
|
||||
mock_event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
cmd_action_2._source = 'agent'
|
||||
cmd_action_2._id = 'cmd_2'
|
||||
history.append(cmd_action_2)
|
||||
|
||||
cmd_observation_2 = CmdOutputObservation(
|
||||
content='',
|
||||
@ -169,11 +182,12 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
command='mkdir new_directory',
|
||||
exit_code=0,
|
||||
)
|
||||
mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
|
||||
cmd_observation_2._source = 'user'
|
||||
history.append(cmd_observation_2)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(
|
||||
Mock(history=mock_event_stream, max_iterations=5, iteration=0)
|
||||
Mock(history=history, max_iterations=5, iteration=0)
|
||||
)
|
||||
|
||||
# Assert the presence of key elements in the messages
|
||||
@ -218,19 +232,17 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
|
||||
|
||||
|
||||
def test_prompt_caching_headers(codeact_agent, mock_event_stream):
|
||||
def test_prompt_caching_headers(codeact_agent: CodeActAgent):
|
||||
history = list()
|
||||
if codeact_agent.config.function_calling:
|
||||
pytest.skip('Skipping this test for function calling')
|
||||
|
||||
# Setup
|
||||
mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
|
||||
mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
|
||||
|
||||
mock_short_term_history = MagicMock()
|
||||
mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!'
|
||||
history.append(MessageAction('Hello, agent!'))
|
||||
history.append(MessageAction('Hello, user!'))
|
||||
|
||||
mock_state = Mock()
|
||||
mock_state.history = mock_short_term_history
|
||||
mock_state.history = history
|
||||
mock_state.max_iterations = 5
|
||||
mock_state.iteration = 0
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user