Fix CLI and headless after changes to eventstream (#5949)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
Robert Brennan 2025-01-01 00:05:35 -05:00 committed by GitHub
parent 2ec2f2538f
commit f3885cadc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 28 additions and 13 deletions

View File

@ -91,7 +91,7 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)
async def main():
async def main(loop):
"""Runs the agent in CLI mode"""
parser = get_parser()
@ -112,7 +112,7 @@ async def main():
logger.setLevel(logging.WARNING)
config = load_app_config(config_file=args.config_file)
sid = 'cli'
sid = str(uuid4())
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
@ -150,7 +150,6 @@ async def main():
async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
next_message = await loop.run_in_executor(
None, lambda: input('How can I help? >> ')
)
@ -165,13 +164,12 @@ async def main():
event_stream.add_event(action, EventSource.USER)
async def prompt_for_user_confirmation():
loop = asyncio.get_event_loop()
user_confirmation = await loop.run_in_executor(
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
)
return user_confirmation.lower() == 'y'
async def on_event(event: Event):
async def on_event_async(event: Event):
display_event(event, config)
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in [
@ -193,6 +191,9 @@ async def main():
ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
)
def on_event(event: Event) -> None:
loop.create_task(on_event_async(event))
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
await runtime.connect()
@ -208,7 +209,7 @@ if __name__ == '__main__':
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
loop.run_until_complete(main(loop))
except KeyboardInterrupt:
print('Received keyboard interrupt, shutting down...')
except ConnectionRefusedError as e:

View File

@ -385,7 +385,7 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument(
'-n',
'--name',
default='default',
default='',
type=str,
help='Name for the session',
)

View File

@ -182,7 +182,7 @@ async def run_controller(
# init with the provided actions
event_stream.add_event(initial_user_action, EventSource.USER)
async def on_event(event: Event):
def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.AWAITING_USER_INPUT:
if exit_on_message:

View File

@ -229,10 +229,24 @@ class EventStream:
for callback_id in callbacks:
callback = callbacks[callback_id]
pool = self._thread_pools[key][callback_id]
pool.submit(callback, event)
future = pool.submit(callback, event)
future.add_done_callback(self._make_error_handler(callback_id, key))
def _callback(self, callback: Callable, event: Event):
asyncio.run(callback(event))
def _make_error_handler(self, callback_id: str, subscriber_id: str):
def _handle_callback_error(fut):
try:
# This will raise any exception that occurred during callback execution
fut.result()
except Exception as e:
logger.error(
f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
exc_info=True,
stack_info=True,
)
# Re-raise in the main thread so the error is not swallowed
raise e
return _handle_callback_error
def filtered_events_by_source(self, source: EventSource):
for event in self.get_events():

View File

@ -202,7 +202,7 @@ async def process_issue(
runtime = create_runtime(config)
await runtime.connect()
async def on_event(evt):
def on_event(evt):
logger.info(evt)
runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))

View File

@ -18,7 +18,7 @@ def test_parser_default_values():
assert args.eval_num_workers == 4
assert args.eval_note is None
assert args.llm_config is None
assert args.name == 'default'
assert args.name == ''
assert not args.no_auto_continue