Support multiline and default user messages (#5400)

This commit is contained in:
Engel Nyst 2024-12-05 21:03:18 +01:00 committed by GitHub
parent c3ddb26e43
commit 1146b6248b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 2 deletions

View File

@ -376,6 +376,11 @@ def get_parser() -> argparse.ArgumentParser:
type=str,
help='The comma-separated list (in quotes) of IDs of the instances to evaluate',
)
parser.add_argument(
'--no-auto-continue',
action='store_true',
help='Disable automatic "continue" responses. Will read from stdin instead.',
)
return parser

View File

@ -188,7 +188,9 @@ async def run_controller(
if exit_on_message:
message = '/exit'
elif fake_user_response_fn is None:
message = input('Request user input >> ')
# read until EOF (Ctrl+D on Unix, Ctrl+Z on Windows)
print('Request user input (press Ctrl+D/Z when done) >> ')
message = sys.stdin.read().rstrip()
else:
message = fake_user_response_fn(controller.get_state())
action = MessageAction(content=message)
@ -241,6 +243,17 @@ def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
return f'{session_name}-{hash_str[:16]}'
def auto_continue_response(
state: State,
encapsulate_solution: bool = False,
try_parse: Callable[[Action | None], str] | None = None,
) -> str:
"""Default function to generate user responses.
Returns 'continue' to tell the agent to proceed without asking for more input.
"""
return 'continue'
if __name__ == '__main__':
args = parse_arguments()
@ -284,5 +297,8 @@ if __name__ == '__main__':
config=config,
initial_user_action=initial_user_action,
sid=sid,
fake_user_response_fn=None
if args.no_auto_continue
else auto_continue_response,
)
)

View File

@ -19,6 +19,7 @@ def test_parser_default_values():
assert args.eval_note is None
assert args.llm_config is None
assert args.name == 'default'
assert not args.no_auto_continue
def test_parser_custom_values():
@ -49,6 +50,7 @@ def test_parser_custom_values():
'gpt4',
'-n',
'test_session',
'--no-auto-continue',
]
)
@ -64,6 +66,7 @@ def test_parser_custom_values():
assert args.eval_note == 'Test run'
assert args.llm_config == 'gpt4'
assert args.name == 'test_session'
assert args.no_auto_continue
def test_parser_file_overrides_task():
@ -124,10 +127,11 @@ def test_help_message(capsys):
'-l LLM_CONFIG, --llm-config LLM_CONFIG',
'-n NAME, --name NAME',
'--config-file CONFIG_FILE',
'--no-auto-continue',
]
for element in expected_elements:
assert element in help_output, f"Expected '{element}' to be in the help message"
option_count = help_output.count(' -')
assert option_count == 15, f'Expected 15 options, found {option_count}'
assert option_count == 16, f'Expected 16 options, found {option_count}'