mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Logger fixes for openhands-resolver (#4710)
Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
parent
df9e9fca5a
commit
436ecb80a3
@ -108,7 +108,7 @@ class AgentController:
|
||||
# subscribe to the event stream
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
|
||||
# state from the previous session, state from a parent agent, or a fresh state
|
||||
@ -156,7 +156,7 @@ class AgentController:
|
||||
)
|
||||
|
||||
# unsubscribe from the event stream
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
|
||||
|
||||
def log(self, level: str, message: str, extra: dict | None = None):
|
||||
"""Logs a message to the agent controller's logger.
|
||||
@ -403,6 +403,8 @@ class AgentController:
|
||||
'debug',
|
||||
f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
|
||||
)
|
||||
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
|
||||
self.delegate = AgentController(
|
||||
sid=self.id + '-delegate',
|
||||
agent=delegate_agent,
|
||||
@ -519,6 +521,11 @@ class AgentController:
|
||||
|
||||
# close the delegate upon error
|
||||
await self.delegate.close()
|
||||
|
||||
# resubscribe parent when delegate is finished
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
|
||||
@ -533,6 +540,11 @@ class AgentController:
|
||||
# close delegate controller: we must close the delegate controller before adding new events
|
||||
await self.delegate.close()
|
||||
|
||||
# resubscribe parent when delegate is finished
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
|
||||
# update delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
|
||||
@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from typing import Type
|
||||
from uuid import uuid4
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
@ -150,7 +151,7 @@ async def main():
|
||||
]:
|
||||
await prompt_for_next_task()
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||
|
||||
await runtime.connect()
|
||||
|
||||
|
||||
@ -186,7 +186,7 @@ async def run_controller(
|
||||
action = MessageAction(content=message)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)
|
||||
|
||||
await runtime.connect()
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from openhands.utils.async_utils import call_sync_from_async
|
||||
class EventStreamSubscriber(str, Enum):
|
||||
AGENT_CONTROLLER = 'agent_controller'
|
||||
SECURITY_ANALYZER = 'security_analyzer'
|
||||
RESOLVER = 'openhands_resolver'
|
||||
SERVER = 'server'
|
||||
RUNTIME = 'runtime'
|
||||
MAIN = 'main'
|
||||
@ -50,9 +51,9 @@ class AsyncEventStreamWrapper:
|
||||
class EventStream:
|
||||
sid: str
|
||||
file_store: FileStore
|
||||
# For each subscriber ID, there is a stack of callback functions - useful
|
||||
# when there are agent delegates
|
||||
_subscribers: dict[str, list[Callable]] = field(default_factory=dict)
|
||||
# For each subscriber ID, there is a map of callback functions - useful
|
||||
# when there are multiple listeners
|
||||
_subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
|
||||
_cur_id: int = 0
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
@ -148,22 +149,29 @@ class EventStream:
|
||||
def get_latest_event_id(self) -> int:
|
||||
return self._cur_id - 1
|
||||
|
||||
def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
|
||||
if id in self._subscribers:
|
||||
if append:
|
||||
self._subscribers[id].append(callback)
|
||||
else:
|
||||
raise ValueError('Subscriber already exists: ' + id)
|
||||
else:
|
||||
self._subscribers[id] = [callback]
|
||||
def subscribe(
|
||||
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
|
||||
):
|
||||
if subscriber_id not in self._subscribers:
|
||||
self._subscribers[subscriber_id] = {}
|
||||
|
||||
def unsubscribe(self, id: EventStreamSubscriber):
|
||||
if id not in self._subscribers:
|
||||
logger.warning('Subscriber not found during unsubscribe: ' + id)
|
||||
else:
|
||||
self._subscribers[id].pop()
|
||||
if len(self._subscribers[id]) == 0:
|
||||
del self._subscribers[id]
|
||||
if callback_id in self._subscribers[subscriber_id]:
|
||||
raise ValueError(
|
||||
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
|
||||
)
|
||||
|
||||
self._subscribers[subscriber_id][callback_id] = callback
|
||||
|
||||
def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
|
||||
return
|
||||
|
||||
if callback_id not in self._subscribers[subscriber_id]:
|
||||
logger.warning(f'Callback not found during unsubscribe: {callback_id}')
|
||||
return
|
||||
|
||||
del self._subscribers[subscriber_id][callback_id]
|
||||
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
try:
|
||||
@ -188,9 +196,10 @@ class EventStream:
|
||||
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
|
||||
tasks = []
|
||||
for key in sorted(self._subscribers.keys()):
|
||||
stack = self._subscribers[key]
|
||||
callback = stack[-1]
|
||||
tasks.append(asyncio.create_task(callback(event)))
|
||||
callbacks = self._subscribers[key]
|
||||
for callback_id in callbacks:
|
||||
callback = callbacks[callback_id]
|
||||
tasks.append(asyncio.create_task(callback(event)))
|
||||
if tasks:
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
|
||||
@ -86,7 +86,9 @@ class Runtime(FileEditRuntimeMixin):
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.RUNTIME, self.on_event, self.sid
|
||||
)
|
||||
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
|
||||
self.status_callback = status_callback
|
||||
self.attach_to_existing = attach_to_existing
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
@ -19,7 +20,7 @@ class SecurityAnalyzer:
|
||||
"""
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event, str(uuid4())
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event) -> None:
|
||||
|
||||
@ -44,7 +44,7 @@ class Session:
|
||||
sid, file_store, status_callback=self.queue_status_message
|
||||
)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event
|
||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||
)
|
||||
self.config = config
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
@ -143,7 +144,7 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
|
||||
error_obs._cause = event.id
|
||||
event_stream.add_event(error_obs, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
state = await run_controller(
|
||||
@ -188,7 +189,7 @@ async def test_run_controller_stop_with_stuck():
|
||||
non_fatal_error_obs._cause = event.id
|
||||
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
state = await run_controller(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user