reduce the duplication in run_controller (#3217)

This commit is contained in:
Xingyao Wang 2024-08-02 10:12:34 +08:00 committed by GitHub
parent 8b4ad35cda
commit 001195a3ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 23 additions and 46 deletions

View File

@ -118,13 +118,12 @@ def process_instance(
instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
agent.__class__.__name__
],

View File

@ -120,12 +120,11 @@ def process_instance(
logger.info(f'Init script result: {init_res}')
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=FAKE_RESPONSES[agent.__class__.__name__],
agent=agent,
sandbox=sandbox,

View File

@ -166,12 +166,11 @@ def process_instance(
sid = instance.test_case_id.replace('/', '__')
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
agent.__class__.__name__
],

View File

@ -209,13 +209,13 @@ def process_instance(
)
# NOTE: You can actually set slightly different instruction for different agents
instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
agent.__class__.__name__
],

View File

@ -67,12 +67,11 @@ def process_instance(
f'NOTE: You should copy the "query" as is into the <execute_browse> tag. DO NOT change ANYTHING in the query.'
)
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
agent=agent,
sid=env_id,
)

View File

@ -116,12 +116,11 @@ def process_instance(
logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
agent.__class__.__name__
],

View File

@ -110,12 +110,11 @@ def process_instance(agent, question_id, question, metadata, reset_logger: bool
# logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
agent.__class__.__name__
),

View File

@ -226,12 +226,11 @@ Ok now its time to start solving the question. Good luck!
"""
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
agent.__class__.__name__
),

View File

@ -179,12 +179,11 @@ def process_instance(
instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
agent.__class__.__name__
),

View File

@ -185,12 +185,11 @@ def process_instance(
exit_code, command_output = sandbox.execute('pip install scitools-pyke')
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
agent.__class__.__name__
),

View File

@ -85,12 +85,11 @@ def process_instance(
}
}
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str='PLACEHOLDER_GOAL',
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
runtime_tools_config=runtime_tools_config,
agent=agent,
sandbox=get_sandbox(),

View File

@ -148,12 +148,11 @@ def process_instance(
},
)
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=fake_user_response_fn,
agent=agent,
sandbox=sandbox,

View File

@ -155,12 +155,11 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]
# Run the agent
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
agent.__class__.__name__
),

View File

@ -280,12 +280,11 @@ IMPORTANT TIPS:
instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
agent.__class__.__name__
],

View File

@ -76,12 +76,11 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
# logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
# Here's how you can run the agent (similar to the `main` function) and get the final task state
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
agent.__class__.__name__
],

View File

@ -86,12 +86,11 @@ def process_instance(
}
}
config.max_iterations = metadata.max_iterations
state: State | None = asyncio.run(
run_controller(
config=config,
task_str='PLACEHOLDER_GOAL',
max_iterations=metadata.max_iterations,
max_budget_per_task=config.max_budget_per_task,
runtime_tools_config=runtime_tools_config,
agent=agent,
sandbox=get_sandbox(),

View File

@ -40,8 +40,6 @@ def read_task_from_stdin() -> str:
async def run_controller(
config: AppConfig,
task_str: str,
max_iterations: int | None = None,
max_budget_per_task: float | None = None,
exit_on_message: bool = False,
fake_user_response_fn: Callable[[State | None], str] | None = None,
sandbox: Sandbox | None = None,
@ -56,8 +54,6 @@ async def run_controller(
Args:
config: The app config.
task_str: The task to run.
max_iterations: The maximum number of iterations to run.
max_budget_per_task: The maximum budget per task.
exit_on_message: quit if agent asks for a message from user (optional)
fake_user_response_fn: An optional function that receives the current state (could be None) and returns a fake user response.
sandbox: (will be deprecated) An optional sandbox to run the agent in.
@ -72,7 +68,6 @@ async def run_controller(
agent = agent_cls(
llm=LLM(config=config.get_llm_config_from_agent(config.default_agent))
)
max_iterations = max_iterations or config.max_iterations
# Logging
logger.info(
@ -96,8 +91,8 @@ async def run_controller(
# init controller with this initial state
controller = AgentController(
agent=agent,
max_iterations=max_iterations,
max_budget_per_task=max_budget_per_task,
max_iterations=config.max_iterations,
max_budget_per_task=config.max_budget_per_task,
agent_to_llm_config=config.get_agent_to_llm_config_map(),
event_stream=event_stream,
initial_state=initial_state,
@ -212,17 +207,14 @@ if __name__ == '__main__':
config.default_agent = args.agent_cls
# if max budget per task is not sent on the command line, use the config value
max_budget_per_task = (
args.max_budget_per_task
if args.max_budget_per_task
else config.max_budget_per_task
)
if args.max_budget_per_task is not None:
config.max_budget_per_task = args.max_budget_per_task
if args.max_iterations is not None:
config.max_iterations = args.max_iterations
asyncio.run(
run_controller(
config=config,
task_str=task_str,
max_iterations=args.max_iterations,
max_budget_per_task=args.max_budget_per_task,
)
)