[Arch proposal] ENVIRONMENT event source (#4584)

Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
Engel Nyst 2024-10-31 19:33:13 +01:00 committed by GitHub
parent db4e1dbbec
commit 0687608feb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 70 additions and 52 deletions

View File

@ -156,7 +156,7 @@ class AgentController:
if exception is not None and isinstance(exception, litellm.AuthenticationError):
detail = 'Please check your credentials. Is your API key correct?'
self.event_stream.add_event(
ErrorObservation(f'{message}:{detail}'), EventSource.USER
ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT
)
async def start_step_loop(self):
@ -346,7 +346,8 @@ class AgentController:
self.state.agent_state = new_state
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
AgentStateChangedObservation('', self.state.agent_state),
EventSource.ENVIRONMENT,
)
if new_state == AgentState.INIT and self.state.resume_state:
@ -423,7 +424,8 @@ class AgentController:
if self._is_stuck():
# This need to go BEFORE report_error to sync metrics
self.event_stream.add_event(
FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
FatalErrorObservation('Agent got stuck in a loop'),
EventSource.ENVIRONMENT,
)
return

View File

@ -61,7 +61,7 @@ def display_event(event: Event):
if hasattr(event, 'thought'):
display_message(event.thought)
if isinstance(event, MessageAction):
if event.source != EventSource.USER:
if event.source == EventSource.AGENT:
display_message(event.content)
if isinstance(event, CmdRunAction):
display_command(event.command)
@ -131,7 +131,7 @@ async def main():
next_message = input('How can I help? >> ')
if next_message == 'exit':
event_stream.add_event(
ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
)
return
action = MessageAction(content=next_message)

View File

@ -49,6 +49,8 @@ class ImageContent(Content):
class Message(BaseModel):
# NOTE: this is not the same as EventSource
# These are the roles in the LLM's APIs
role: Literal['user', 'system', 'assistant', 'tool']
content: list[TextContent | ImageContent] = Field(default_factory=list)
cache_enabled: bool = False

View File

@ -9,6 +9,7 @@ from openhands.llm.metrics import Metrics
class EventSource(str, Enum):
AGENT = 'agent'
USER = 'user'
ENVIRONMENT = 'environment'
@dataclass

View File

@ -136,6 +136,8 @@ class Runtime(FileEditRuntimeMixin):
)
observation._cause = event.id # type: ignore[attr-defined]
observation.tool_call_metadata = event.tool_call_metadata
# this might be unnecessary, since source should be set by the event stream when we're here
source = event.source if event.source else EventSource.AGENT
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]

View File

@ -147,6 +147,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
new_event = action_from_dict(
{'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
)
# we should confirm only on agent actions
event_source = event.source if event.source else EventSource.AGENT
await call_sync_from_async(self.event_stream.add_event, new_event, event_source)

View File

@ -118,7 +118,7 @@ class AgentSession:
agent_configs=agent_configs,
)
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)
if self.controller:
self.controller.agent_task = self.controller.start_step_loop()

View File

@ -73,10 +73,11 @@ class Session:
async def _initialize_agent(self, data: dict):
self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
)
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
)
# Extract the agent-relevant arguments from the request
args = {key: value for key, value in data.get('args', {}).items()}
@ -138,12 +139,19 @@ class Session:
return
if event.source == EventSource.AGENT:
await self.send(event_to_dict(event))
elif event.source == EventSource.USER and isinstance(
# NOTE: ipython observations are not sent here currently
elif event.source == EventSource.ENVIRONMENT and isinstance(
event, CmdOutputObservation
):
await self.send(event_to_dict(event))
# feedback from the environment to agent actions is understood as agent events by the UI
event_dict = event_to_dict(event)
event_dict['source'] = EventSource.AGENT
await self.send(event_dict)
elif isinstance(event, ErrorObservation):
await self.send(event_to_dict(event))
# send error events as agent events to the UI
event_dict = event_to_dict(event)
event_dict['source'] = EventSource.AGENT
await self.send(event_dict)
async def dispatch(self, data: dict):
action = data.get('action', '')

View File

@ -207,7 +207,9 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
'Non fatal error here to trigger loop'
)
non_fatal_error_obs._cause = event.id
await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
await event_stream.async_add_event(
non_fatal_error_obs, EventSource.ENVIRONMENT
)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
runtime.event_stream = event_stream

View File

@ -80,7 +80,7 @@ class TestStuckDetector:
code=code_snippet,
)
ipython_observation._cause = ipython_action._id
event_stream.add_event(ipython_observation, EventSource.USER)
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
def _impl_unterminated_string_error_events(
self, event_stream: EventStream, random_line: bool, incidents: int = 4
@ -96,7 +96,7 @@ class TestStuckDetector:
code=code_snippet,
)
ipython_observation._cause = ipython_action._id
event_stream.add_event(ipython_observation, EventSource.USER)
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
def test_history_too_short(
self, stuck_detector: StuckDetector, event_stream: EventStream
@ -106,7 +106,7 @@ class TestStuckDetector:
observation = NullObservation(content='')
observation._cause = message_action.id
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(observation, EventSource.USER)
event_stream.add_event(observation, EventSource.ENVIRONMENT)
cmd_action = CmdRunAction(command='ls')
event_stream.add_event(cmd_action, EventSource.AGENT)
@ -114,7 +114,7 @@ class TestStuckDetector:
command_id=1, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation._cause = cmd_action._id
event_stream.add_event(cmd_observation, EventSource.USER)
event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT)
# stuck_detector.state.history.set_event_stream(event_stream)
@ -131,7 +131,7 @@ class TestStuckDetector:
# 2 events
event_stream.add_event(hello_action, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
@ -139,7 +139,7 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_1._id
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.USER)
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
# 4 events
cmd_action_2 = CmdRunAction(command='ls')
@ -148,13 +148,13 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_2._id
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.USER)
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
# 6 events
# random user message just because we can
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
# 8 events
assert stuck_detector.is_stuck() is False
@ -166,7 +166,7 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_3._id
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.USER)
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
# 10 events
assert len(collect_events(event_stream)) == 10
@ -191,7 +191,7 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_4._id
)
cmd_observation_4._cause = cmd_action_4._id
event_stream.add_event(cmd_observation_4, EventSource.USER)
event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT)
# 12 events
assert len(collect_events(event_stream)) == 12
@ -223,14 +223,14 @@ class TestStuckDetector:
hello_observation = NullObservation(content='')
event_stream.add_event(hello_action, EventSource.USER)
hello_observation._cause = hello_action._id
event_stream.add_event(hello_observation, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
# 2 events
cmd_action_1 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
error_observation_1 = ErrorObservation(content='Command not found')
error_observation_1._cause = cmd_action_1._id
event_stream.add_event(error_observation_1, EventSource.USER)
event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT)
# 4 events
cmd_action_2 = CmdRunAction(command='invalid_command')
@ -239,26 +239,26 @@ class TestStuckDetector:
content='Command still not found or another error'
)
error_observation_2._cause = cmd_action_2._id
event_stream.add_event(error_observation_2, EventSource.USER)
event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT)
# 6 events
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
# 8 events
cmd_action_3 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
error_observation_3 = ErrorObservation(content='Different error')
error_observation_3._cause = cmd_action_3._id
event_stream.add_event(error_observation_3, EventSource.USER)
event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT)
# 10 events
cmd_action_4 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_4, EventSource.AGENT)
error_observation_4 = ErrorObservation(content='Command not found')
error_observation_4._cause = cmd_action_4._id
event_stream.add_event(error_observation_4, EventSource.USER)
event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT)
# 12 events
with patch('logging.Logger.warning') as mock_warning:
@ -366,7 +366,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_1._cause = ipython_action_1._id
event_stream.add_event(ipython_observation_1, EventSource.USER)
event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT)
ipython_action_2 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_2, EventSource.AGENT)
@ -375,7 +375,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_2._cause = ipython_action_2._id
event_stream.add_event(ipython_observation_2, EventSource.USER)
event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT)
ipython_action_3 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_3, EventSource.AGENT)
@ -384,7 +384,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_3._cause = ipython_action_3._id
event_stream.add_event(ipython_observation_3, EventSource.USER)
event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT)
ipython_action_4 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_4, EventSource.AGENT)
@ -393,7 +393,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_4._cause = ipython_action_4._id
event_stream.add_event(ipython_observation_4, EventSource.USER)
event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is False
@ -406,7 +406,7 @@ class TestStuckDetector:
message_action._source = EventSource.USER
event_stream.add_event(message_action, EventSource.USER)
message_observation = NullObservation(content='')
event_stream.add_event(message_observation, EventSource.USER)
event_stream.add_event(message_observation, EventSource.ENVIRONMENT)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
@ -414,7 +414,7 @@ class TestStuckDetector:
command_id=1, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.USER)
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
read_action_1 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_1, EventSource.AGENT)
@ -422,7 +422,7 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_1._cause = read_action_1._id
event_stream.add_event(read_observation_1, EventSource.USER)
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
cmd_action_2 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
@ -430,7 +430,7 @@ class TestStuckDetector:
command_id=2, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.USER)
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
read_action_2 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_2, EventSource.AGENT)
@ -438,12 +438,12 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_2._cause = read_action_2._id
event_stream.add_event(read_observation_2, EventSource.USER)
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
# one more message to break the pattern
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
cmd_action_3 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
@ -451,7 +451,7 @@ class TestStuckDetector:
command_id=3, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.USER)
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
read_action_3 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_3, EventSource.AGENT)
@ -459,7 +459,7 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_3._cause = read_action_3._id
event_stream.add_event(read_observation_3, EventSource.USER)
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
@ -475,7 +475,7 @@ class TestStuckDetector:
event_stream.add_event(hello_action, EventSource.USER)
hello_observation = NullObservation(content='')
hello_observation._cause = hello_action._id
event_stream.add_event(hello_observation, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
@ -483,7 +483,7 @@ class TestStuckDetector:
command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.USER)
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
read_action_1 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_1, EventSource.AGENT)
@ -491,7 +491,7 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_1._cause = read_action_1._id
event_stream.add_event(read_observation_1, EventSource.USER)
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
cmd_action_2 = CmdRunAction(command='pwd')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
@ -499,7 +499,7 @@ class TestStuckDetector:
command_id=2, command='pwd', content='/home/user'
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.USER)
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
read_action_2 = FileReadAction(path='file2.txt')
event_stream.add_event(read_action_2, EventSource.AGENT)
@ -507,11 +507,11 @@ class TestStuckDetector:
content='Another file content', path='file2.txt'
)
read_observation_2._cause = read_action_2._id
event_stream.add_event(read_observation_2, EventSource.USER)
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
cmd_action_3 = CmdRunAction(command='pwd')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
@ -519,7 +519,7 @@ class TestStuckDetector:
command_id=cmd_action_3.id, command='pwd', content='/home/user'
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.USER)
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
read_action_3 = FileReadAction(path='file2.txt')
event_stream.add_event(read_action_3, EventSource.AGENT)
@ -527,7 +527,7 @@ class TestStuckDetector:
content='Another file content', path='file2.txt'
)
read_observation_3._cause = read_action_3._id
event_stream.add_event(read_observation_3, EventSource.USER)
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
assert stuck_detector.is_stuck() is False
@ -572,7 +572,7 @@ class TestStuckDetector:
exit_code=0,
)
cmd_output_observation._cause = cmd_kill_action._id
event_stream.add_event(cmd_output_observation, EventSource.USER)
event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT)
message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
event_stream.add_event(message_action_7, EventSource.AGENT)

View File

@ -88,7 +88,7 @@ def _create_observation_event(observation: str) -> Event:
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.USER
event._source = EventSource.ENVIRONMENT
event.observation = observation
return event

View File

@ -155,7 +155,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
command='ls -l',
exit_code=0,
)
mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
message_action_2 = MessageAction("Now, let's create a new directory.")
mock_event_stream.add_event(message_action_2, EventSource.AGENT)
@ -169,7 +169,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
command='mkdir new_directory',
exit_code=0,
)
mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
codeact_agent.reset()
messages = codeact_agent._get_messages(