From da23189e4c1de94240c4246b1ea7b0d4594e855b Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Tue, 15 Oct 2024 14:31:49 -0500 Subject: [PATCH] refactor: move get_pairs from memory to shared utils (#4411) --- openhands/events/utils.py | 56 +++++++++++++++++++++++++++++++++++++ openhands/memory/history.py | 54 +++-------------------------------- tests/unit/test_is_stuck.py | 23 +++++++++++++-- 3 files changed, 81 insertions(+), 52 deletions(-) create mode 100644 openhands/events/utils.py diff --git a/openhands/events/utils.py b/openhands/events/utils.py new file mode 100644 index 0000000000..6c8cc415f6 --- /dev/null +++ b/openhands/events/utils.py @@ -0,0 +1,56 @@ +from openhands.core.logger import openhands_logger as logger +from openhands.events.action.action import Action +from openhands.events.action.empty import NullAction +from openhands.events.event import Event +from openhands.events.observation.commands import CmdOutputObservation +from openhands.events.observation.empty import NullObservation +from openhands.events.observation.observation import Observation + + +def get_pairs_from_events(events: list[Event]) -> list[tuple[Action, Observation]]: + """Return the history as a list of tuples (action, observation).""" + tuples: list[tuple[Action, Observation]] = [] + action_map: dict[int, Action] = {} + observation_map: dict[int, Observation] = {} + + # runnable actions are set as cause of observations + # (MessageAction, NullObservation) for source=USER + # (MessageAction, NullObservation) for source=AGENT + # (other_action?, NullObservation) + # (NullAction, CmdOutputObservation) background CmdOutputObservations + + for event in events: + if event.id is None or event.id == -1: + logger.debug(f'Event {event} has no ID') + + if isinstance(event, Action): + action_map[event.id] = event + + if isinstance(event, Observation): + if event.cause is None or event.cause == -1: + logger.debug(f'Observation {event} has no cause') + + if event.cause is None: + # runnable actions are set as cause of observations + # NullObservations have no cause + continue + + observation_map[event.cause] = event + + for action_id, action in action_map.items(): + observation = observation_map.get(action_id) + if observation: + # observation with a cause + tuples.append((action, observation)) + else: + tuples.append((action, NullObservation(''))) + + for cause_id, observation in observation_map.items(): + if cause_id not in action_map: + if isinstance(observation, NullObservation): + continue + if not isinstance(observation, CmdOutputObservation): + logger.debug(f'Observation {observation} has no cause') + tuples.append((NullAction(), observation)) + + return tuples.copy() diff --git a/openhands/memory/history.py b/openhands/memory/history.py index 89e50d67e4..1e4cfb8b5f 100644 --- a/openhands/memory/history.py +++ b/openhands/memory/history.py @@ -10,12 +10,12 @@ from openhands.events.action.empty import NullAction from openhands.events.action.message import MessageAction from openhands.events.event import Event, EventSource from openhands.events.observation.agent import AgentStateChangedObservation -from openhands.events.observation.commands import CmdOutputObservation from openhands.events.observation.delegate import AgentDelegateObservation from openhands.events.observation.empty import NullObservation from openhands.events.observation.observation import Observation from openhands.events.serialization.event import event_to_dict from openhands.events.stream import EventStream +from openhands.events.utils import get_pairs_from_events class ShortTermHistory(list[Event]): @@ -216,55 +216,9 @@ class ShortTermHistory(list[Event]): def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]: history_pairs = [] - for action, observation in self.get_pairs(): + for action, observation in get_pairs_from_events( + self.get_events_as_list(include_delegates=True) + ): history_pairs.append((event_to_dict(action), event_to_dict(observation))) return history_pairs - - def get_pairs(self) -> list[tuple[Action, Observation]]: - """Return the history as a list of tuples (action, observation).""" - tuples: list[tuple[Action, Observation]] = [] - action_map: dict[int, Action] = {} - observation_map: dict[int, Observation] = {} - - # runnable actions are set as cause of observations - # (MessageAction, NullObservation) for source=USER - # (MessageAction, NullObservation) for source=AGENT - # (other_action?, NullObservation) - # (NullAction, CmdOutputObservation) background CmdOutputObservations - - for event in self.get_events_as_list(include_delegates=True): - if event.id is None or event.id == -1: - logger.debug(f'Event {event} has no ID') - - if isinstance(event, Action): - action_map[event.id] = event - - if isinstance(event, Observation): - if event.cause is None or event.cause == -1: - logger.debug(f'Observation {event} has no cause') - - if event.cause is None: - # runnable actions are set as cause of observations - # NullObservations have no cause - continue - - observation_map[event.cause] = event - - for action_id, action in action_map.items(): - observation = observation_map.get(action_id) - if observation: - # observation with a cause - tuples.append((action, observation)) - else: - tuples.append((action, NullObservation(''))) - - for cause_id, observation in observation_map.items(): - if cause_id not in action_map: - if isinstance(observation, NullObservation): - continue - if not isinstance(observation, CmdOutputObservation): - logger.debug(f'Observation {observation} has no cause') - tuples.append((NullAction(), observation)) - - return tuples.copy() diff --git a/tests/unit/test_is_stuck.py b/tests/unit/test_is_stuck.py index 5e23a84928..4a13307521 100644 --- a/tests/unit/test_is_stuck.py +++ b/tests/unit/test_is_stuck.py @@ -17,6 +17,7 @@ from openhands.events.observation.commands import IPythonRunCellObservation from openhands.events.observation.empty import NullObservation from openhands.events.observation.error import ErrorObservation from openhands.events.stream import EventSource, EventStream +from openhands.events.utils import get_pairs_from_events from openhands.memory.history import ShortTermHistory from openhands.storage import get_file_store @@ -170,7 +171,16 @@ class TestStuckDetector: assert len(collect_events(event_stream)) == 10 assert len(list(stuck_detector.state.history.get_events())) == 8 - assert len(stuck_detector.state.history.get_pairs()) == 5 + assert ( + len( + get_pairs_from_events( + stuck_detector.state.history.get_events_as_list( + include_delegates=True + ) + ) + ) + == 5 + ) assert stuck_detector.is_stuck() is False assert stuck_detector.state.almost_stuck == 1 @@ -186,7 +196,16 @@ class TestStuckDetector: assert len(collect_events(event_stream)) == 12 assert len(list(stuck_detector.state.history.get_events())) == 10 - assert len(stuck_detector.state.history.get_pairs()) == 6 + assert ( + len( + get_pairs_from_events( + stuck_detector.state.history.get_events_as_list( + include_delegates=True + ) + ) + ) + == 6 + ) with patch('logging.Logger.warning') as mock_warning: assert stuck_detector.is_stuck() is True