Logger fixes for openhands-resolver (#4710)

Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
Rohit Malhotra 2024-11-05 11:49:32 -05:00 committed by GitHub
parent df9e9fca5a
commit 436ecb80a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 56 additions and 30 deletions

View File

@ -108,7 +108,7 @@ class AgentController:
# subscribe to the event stream
self.event_stream = event_stream
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
# state from the previous session, state from a parent agent, or a fresh state
@ -156,7 +156,7 @@ class AgentController:
)
# unsubscribe from the event stream
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
def log(self, level: str, message: str, extra: dict | None = None):
"""Logs a message to the agent controller's logger.
@ -403,6 +403,8 @@ class AgentController:
'debug',
f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
)
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
self.delegate = AgentController(
sid=self.id + '-delegate',
agent=delegate_agent,
@ -519,6 +521,11 @@ class AgentController:
# close the delegate upon error
await self.delegate.close()
# resubscribe parent when delegate is finished
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
self.delegate = None
self.delegateAction = None
@ -533,6 +540,11 @@ class AgentController:
# close delegate controller: we must close the delegate controller before adding new events
await self.delegate.close()
# resubscribe parent when delegate is finished
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
# update delegate result observation
# TODO: replace this with AI-generated summary (#2395)
formatted_output = ', '.join(

View File

@ -2,6 +2,7 @@ import asyncio
import logging
import sys
from typing import Type
from uuid import uuid4
from termcolor import colored
@ -150,7 +151,7 @@ async def main():
]:
await prompt_for_next_task()
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
await runtime.connect()

View File

@ -186,7 +186,7 @@ async def run_controller(
action = MessageAction(content=message)
event_stream.add_event(action, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)
await runtime.connect()

View File

@ -17,6 +17,7 @@ from openhands.utils.async_utils import call_sync_from_async
class EventStreamSubscriber(str, Enum):
AGENT_CONTROLLER = 'agent_controller'
SECURITY_ANALYZER = 'security_analyzer'
RESOLVER = 'openhands_resolver'
SERVER = 'server'
RUNTIME = 'runtime'
MAIN = 'main'
@ -50,9 +51,9 @@ class AsyncEventStreamWrapper:
class EventStream:
sid: str
file_store: FileStore
# For each subscriber ID, there is a stack of callback functions - useful
# when there are agent delegates
_subscribers: dict[str, list[Callable]] = field(default_factory=dict)
# For each subscriber ID, there is a map of callback functions - useful
# when there are multiple listeners
_subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
_cur_id: int = 0
_lock: threading.Lock = field(default_factory=threading.Lock)
@ -148,22 +149,29 @@ class EventStream:
def get_latest_event_id(self) -> int:
return self._cur_id - 1
def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
if id in self._subscribers:
if append:
self._subscribers[id].append(callback)
else:
raise ValueError('Subscriber already exists: ' + id)
else:
self._subscribers[id] = [callback]
def subscribe(
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
):
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
def unsubscribe(self, id: EventStreamSubscriber):
if id not in self._subscribers:
logger.warning('Subscriber not found during unsubscribe: ' + id)
else:
self._subscribers[id].pop()
if len(self._subscribers[id]) == 0:
del self._subscribers[id]
if callback_id in self._subscribers[subscriber_id]:
raise ValueError(
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
)
self._subscribers[subscriber_id][callback_id] = callback
def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
return
if callback_id not in self._subscribers[subscriber_id]:
logger.warning(f'Callback not found during unsubscribe: {callback_id}')
return
del self._subscribers[subscriber_id][callback_id]
def add_event(self, event: Event, source: EventSource):
try:
@ -188,9 +196,10 @@ class EventStream:
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
tasks = []
for key in sorted(self._subscribers.keys()):
stack = self._subscribers[key]
callback = stack[-1]
tasks.append(asyncio.create_task(callback(event)))
callbacks = self._subscribers[key]
for callback_id in callbacks:
callback = callbacks[callback_id]
tasks.append(asyncio.create_task(callback(event)))
if tasks:
await asyncio.wait(tasks)

View File

@ -86,7 +86,9 @@ class Runtime(FileEditRuntimeMixin):
):
self.sid = sid
self.event_stream = event_stream
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
self.event_stream.subscribe(
EventStreamSubscriber.RUNTIME, self.on_event, self.sid
)
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
self.status_callback = status_callback
self.attach_to_existing = attach_to_existing

View File

@ -1,4 +1,5 @@
from typing import Any
from uuid import uuid4
from fastapi import Request
@ -19,7 +20,7 @@ class SecurityAnalyzer:
"""
self.event_stream = event_stream
self.event_stream.subscribe(
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event, str(uuid4())
)
async def on_event(self, event: Event) -> None:

View File

@ -44,7 +44,7 @@ class Session:
sid, file_store, status_callback=self.queue_status_message
)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event
EventStreamSubscriber.SERVER, self.on_event, self.sid
)
self.config = config
self.loop = asyncio.get_event_loop()

View File

@ -1,5 +1,6 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock
from uuid import uuid4
import pytest
@ -143,7 +144,7 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
error_obs._cause = event.id
event_stream.add_event(error_obs, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
state = await run_controller(
@ -188,7 +189,7 @@ async def test_run_controller_stop_with_stuck():
non_fatal_error_obs._cause = event.id
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
state = await run_controller(