mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Fix CLI and headless after changes to eventstream (#5949)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
parent
2ec2f2538f
commit
f3885cadc1
@ -91,7 +91,7 @@ def display_event(event: Event, config: AppConfig):
|
||||
display_confirmation(event.confirmation_state)
|
||||
|
||||
|
||||
async def main():
|
||||
async def main(loop):
|
||||
"""Runs the agent in CLI mode"""
|
||||
|
||||
parser = get_parser()
|
||||
@ -112,7 +112,7 @@ async def main():
|
||||
|
||||
logger.setLevel(logging.WARNING)
|
||||
config = load_app_config(config_file=args.config_file)
|
||||
sid = 'cli'
|
||||
sid = str(uuid4())
|
||||
|
||||
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
|
||||
agent_config = config.get_agent_config(config.default_agent)
|
||||
@ -150,7 +150,6 @@ async def main():
|
||||
|
||||
async def prompt_for_next_task():
|
||||
# Run input() in a thread pool to avoid blocking the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
next_message = await loop.run_in_executor(
|
||||
None, lambda: input('How can I help? >> ')
|
||||
)
|
||||
@ -165,13 +164,12 @@ async def main():
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
async def prompt_for_user_confirmation():
|
||||
loop = asyncio.get_event_loop()
|
||||
user_confirmation = await loop.run_in_executor(
|
||||
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
|
||||
)
|
||||
return user_confirmation.lower() == 'y'
|
||||
|
||||
async def on_event(event: Event):
|
||||
async def on_event_async(event: Event):
|
||||
display_event(event, config)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in [
|
||||
@ -193,6 +191,9 @@ async def main():
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
|
||||
)
|
||||
|
||||
def on_event(event: Event) -> None:
|
||||
loop.create_task(on_event_async(event))
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||
|
||||
await runtime.connect()
|
||||
@ -208,7 +209,7 @@ if __name__ == '__main__':
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
loop.run_until_complete(main(loop))
|
||||
except KeyboardInterrupt:
|
||||
print('Received keyboard interrupt, shutting down...')
|
||||
except ConnectionRefusedError as e:
|
||||
|
||||
@ -385,7 +385,7 @@ def get_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--name',
|
||||
default='default',
|
||||
default='',
|
||||
type=str,
|
||||
help='Name for the session',
|
||||
)
|
||||
|
||||
@ -182,7 +182,7 @@ async def run_controller(
|
||||
# init with the provided actions
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
|
||||
async def on_event(event: Event):
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.AWAITING_USER_INPUT:
|
||||
if exit_on_message:
|
||||
|
||||
@ -229,10 +229,24 @@ class EventStream:
|
||||
for callback_id in callbacks:
|
||||
callback = callbacks[callback_id]
|
||||
pool = self._thread_pools[key][callback_id]
|
||||
pool.submit(callback, event)
|
||||
future = pool.submit(callback, event)
|
||||
future.add_done_callback(self._make_error_handler(callback_id, key))
|
||||
|
||||
def _callback(self, callback: Callable, event: Event):
|
||||
asyncio.run(callback(event))
|
||||
def _make_error_handler(self, callback_id: str, subscriber_id: str):
|
||||
def _handle_callback_error(fut):
|
||||
try:
|
||||
# This will raise any exception that occurred during callback execution
|
||||
fut.result()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
|
||||
exc_info=True,
|
||||
stack_info=True,
|
||||
)
|
||||
# Re-raise in the main thread so the error is not swallowed
|
||||
raise e
|
||||
|
||||
return _handle_callback_error
|
||||
|
||||
def filtered_events_by_source(self, source: EventSource):
|
||||
for event in self.get_events():
|
||||
|
||||
@ -202,7 +202,7 @@ async def process_issue(
|
||||
runtime = create_runtime(config)
|
||||
await runtime.connect()
|
||||
|
||||
async def on_event(evt):
|
||||
def on_event(evt):
|
||||
logger.info(evt)
|
||||
|
||||
runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||
|
||||
@ -18,7 +18,7 @@ def test_parser_default_values():
|
||||
assert args.eval_num_workers == 4
|
||||
assert args.eval_note is None
|
||||
assert args.llm_config is None
|
||||
assert args.name == 'default'
|
||||
assert args.name == ''
|
||||
assert not args.no_auto_continue
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user