mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
(refactor) Make Runtime class synchronous (#3661)
* change runtime to be synchronous * fix test runtime with the new interface * fix arg * fix eval * fix missing config attribute * fix plugins * fix on_event by reverting it back to async * update upload_file endpoint * fix argument to upload file * remove unnecessary async for eval; fix evaluation run in parallel * use asyncio to run controller for eval * revert file upload * truncate eval test result output
This commit is contained in:
parent
b0e52f121c
commit
090c911a50
@ -126,7 +126,7 @@ To create an evaluation workflow for your benchmark, follow these steps:
|
||||
|
||||
3. Initialize the runtime and set up the evaluation environment:
|
||||
```python
|
||||
async def initialize_runtime(runtime: Runtime, instance: pd.Series):
|
||||
def initialize_runtime(runtime: Runtime, instance: pd.Series):
|
||||
# Set up your evaluation environment here
|
||||
# For example, setting environment variables, preparing files, etc.
|
||||
pass
|
||||
@ -134,14 +134,14 @@ To create an evaluation workflow for your benchmark, follow these steps:
|
||||
|
||||
4. Create a function to process each instance:
|
||||
```python
|
||||
async def process_instance(instance: pd.Series, metadata: EvalMetadata) -> EvalOutput:
|
||||
def process_instance(instance: pd.Series, metadata: EvalMetadata) -> EvalOutput:
|
||||
config = get_config(instance, metadata)
|
||||
runtime = await create_runtime(config, sid=instance.instance_id)
|
||||
await initialize_runtime(runtime, instance)
|
||||
runtime = create_runtime(config, sid=instance.instance_id)
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
instruction = get_instruction(instance, metadata)
|
||||
|
||||
state = await run_controller(
|
||||
state = run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
|
||||
@ -74,7 +74,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -117,13 +117,17 @@ async def process_instance(
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
runtime = await create_runtime(config, sid=instance['text'].strip())
|
||||
runtime = create_runtime(config, sid=instance['text'].strip())
|
||||
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
# If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
@ -214,12 +218,10 @@ if __name__ == '__main__':
|
||||
eda_dataset.to_pandas(), output_file, args.eval_n_limit
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
prepared_dataset,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
prepared_dataset,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
|
||||
@ -56,7 +56,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
@ -70,12 +70,12 @@ async def initialize_runtime(
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
init_cmd = instance.init
|
||||
@ -85,7 +85,7 @@ async def initialize_runtime(
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
host_script_path = os.path.join(tmpdir, script_name)
|
||||
create_sh_file(host_script_path, init_cmd)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
host_script_path,
|
||||
'/workspace',
|
||||
)
|
||||
@ -93,14 +93,14 @@ async def initialize_runtime(
|
||||
logger.info(f'Running init script: {script_name}')
|
||||
action = CmdRunAction(command=f'chmod +x ./{script_name} && ./{script_name}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
|
||||
) -> dict[str, Any]:
|
||||
@ -121,7 +121,7 @@ async def complete_runtime(
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
host_script_path = os.path.join(tmpdir, script_name)
|
||||
create_sh_file(host_script_path, get_agent_result_cmd)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
host_script_path,
|
||||
'/workspace',
|
||||
)
|
||||
@ -132,7 +132,7 @@ async def complete_runtime(
|
||||
keep_prompt=False,
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
agent_answer = obs.content
|
||||
@ -149,7 +149,7 @@ async def complete_runtime(
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
host_script_path = os.path.join(tmpdir, script_name)
|
||||
create_sh_file(host_script_path, get_ground_truth_cmd)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
host_script_path,
|
||||
'/workspace',
|
||||
)
|
||||
@ -160,7 +160,7 @@ async def complete_runtime(
|
||||
keep_prompt=False,
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
final_ans = obs.content
|
||||
|
||||
@ -171,7 +171,7 @@ async def complete_runtime(
|
||||
}
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -209,16 +209,18 @@ async def process_instance(
|
||||
# create sandbox and run the agent
|
||||
# =============================================
|
||||
|
||||
runtime: Runtime = await create_runtime(config, sid=instance.instance_id)
|
||||
runtime: Runtime = create_runtime(config, sid=instance.instance_id)
|
||||
|
||||
await initialize_runtime(runtime, instance=instance)
|
||||
initialize_runtime(runtime, instance=instance)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
|
||||
)
|
||||
)
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
@ -227,7 +229,7 @@ async def process_instance(
|
||||
# result evaluation
|
||||
# =============================================
|
||||
|
||||
return_val = await complete_runtime(runtime, instance)
|
||||
return_val = complete_runtime(runtime, instance)
|
||||
agent_answer = return_val['agent_answer']
|
||||
final_ans = return_val['final_ans']
|
||||
|
||||
@ -313,8 +315,6 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(agent_bench_tests, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
|
||||
@ -62,7 +62,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series,
|
||||
):
|
||||
@ -76,19 +76,19 @@ async def initialize_runtime(
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
file_path = os.path.join(tmpdir, f'{instance.instance_name}.py')
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(instance.signature)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
file_path,
|
||||
'/workspace',
|
||||
)
|
||||
@ -96,14 +96,14 @@ async def initialize_runtime(
|
||||
file_path = os.path.join(tmpdir, f'{instance.instance_name}_test.py')
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(instance.test)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
file_path,
|
||||
'/workspace',
|
||||
)
|
||||
logger.info(f"\n{'-' * 50} END Runtime Initialization Fn {'-' * 50}\n")
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series,
|
||||
) -> dict[str, Any]:
|
||||
@ -122,7 +122,7 @@ async def complete_runtime(
|
||||
file_path = os.path.join(tmpdir, script_name)
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(instance.test)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
file_path,
|
||||
'/workspace',
|
||||
)
|
||||
@ -133,7 +133,7 @@ async def complete_runtime(
|
||||
keep_prompt=False,
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
exit_code = 1
|
||||
@ -142,7 +142,7 @@ async def complete_runtime(
|
||||
|
||||
logger.info(f"\n{'-' * 50} END Runtime Completion Fn {'-' * 50}\n")
|
||||
|
||||
await runtime.close()
|
||||
runtime.close()
|
||||
|
||||
return {
|
||||
'test_output': obs.content,
|
||||
@ -150,7 +150,7 @@ async def complete_runtime(
|
||||
}
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -193,16 +193,18 @@ async def process_instance(
|
||||
# create sandbox and run the agent
|
||||
# =============================================
|
||||
|
||||
runtime: Runtime = await create_runtime(config, sid=str(instance.instance_id))
|
||||
runtime: Runtime = create_runtime(config, sid=str(instance.instance_id))
|
||||
|
||||
await initialize_runtime(runtime, instance=instance)
|
||||
initialize_runtime(runtime, instance=instance)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
|
||||
)
|
||||
)
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
@ -211,7 +213,7 @@ async def process_instance(
|
||||
# # result evaluation
|
||||
# # =============================================
|
||||
|
||||
return_val = await complete_runtime(runtime, instance)
|
||||
return_val = complete_runtime(runtime, instance)
|
||||
exit_code = return_val['exit_code']
|
||||
test_output = return_val['test_output']
|
||||
|
||||
@ -286,12 +288,10 @@ if __name__ == '__main__':
|
||||
skip_num=SKIP_NUM,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
|
||||
@ -74,7 +74,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: BiocoderData, # this argument is not required
|
||||
):
|
||||
@ -89,19 +89,19 @@ async def initialize_runtime(
|
||||
|
||||
action = CmdRunAction(command='mkdir -p /workspace && mkdir -p /testing_files')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
context_path = os.path.join(tmpdir, 'context.' + file_ext)
|
||||
with open(context_path, 'w') as f:
|
||||
f.write(instance.contextCode)
|
||||
await runtime.copy_to(context_path, '/testing_files')
|
||||
runtime.copy_to(context_path, '/testing_files')
|
||||
|
||||
golden_path = os.path.join(tmpdir, 'golden.' + file_ext)
|
||||
with open(golden_path, 'w') as f:
|
||||
f.write(instance.goldenCode)
|
||||
await runtime.copy_to(golden_path, '/testing_files')
|
||||
runtime.copy_to(golden_path, '/testing_files')
|
||||
|
||||
testcase_json = {
|
||||
'test_case_id': instance.test_case_id,
|
||||
@ -112,36 +112,36 @@ async def initialize_runtime(
|
||||
with open(testcase_path, 'w') as f:
|
||||
f.write(json.dumps(testcase_json, indent=4))
|
||||
|
||||
await runtime.copy_to(testcase_path, '/testing_files')
|
||||
runtime.copy_to(testcase_path, '/testing_files')
|
||||
|
||||
# setup paths
|
||||
remove_code_script = os.path.join(
|
||||
os.path.dirname(__file__), 'scripts', 'setup', 'remove_code.py'
|
||||
)
|
||||
await runtime.copy_to(remove_code_script, '/testing_files')
|
||||
runtime.copy_to(remove_code_script, '/testing_files')
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
# download repository archive
|
||||
repository_url = f"https://biocoder.lilbillbiscuit.com/repos/{instance.repository.split('/')[1]}.zip"
|
||||
action = CmdRunAction(command='wget -O repo.zip ' + repository_url)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0, f'Failed to download the repository: {obs.content}'
|
||||
|
||||
# unzip the repository
|
||||
action = CmdRunAction(command='unzip -o -q repo.zip && rm repo.zip')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0, f'Failed to unzip the repository: {obs.content}'
|
||||
|
||||
# chmod 777
|
||||
action = CmdRunAction(command='chmod -R 777 /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0, f'Failed to chmod the files: {obs.content}'
|
||||
|
||||
# remove code for evaluation instance
|
||||
@ -155,13 +155,13 @@ async def initialize_runtime(
|
||||
command=f'python3 /testing_files/remove_code.py --target_filepath {target_filepath} --line_start {line_start} --line_end {line_end} --language {language}'
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0, f'Failed to remove the code: {obs.content}'
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
|
||||
) -> dict[str, Any]:
|
||||
@ -179,7 +179,7 @@ async def complete_runtime(
|
||||
copy_changed_code_script = os.path.join(
|
||||
os.path.dirname(__file__), 'scripts', 'setup', 'copy_changed_code.py'
|
||||
)
|
||||
await runtime.copy_to(copy_changed_code_script, '/testing_files')
|
||||
runtime.copy_to(copy_changed_code_script, '/testing_files')
|
||||
|
||||
file_ext = FILE_EXT_MAP[instance.language.lower()]
|
||||
target_filepath = os.path.join(
|
||||
@ -191,13 +191,13 @@ async def complete_runtime(
|
||||
command=f'python3 /testing_files/copy_changed_code.py --target_filepath {target_filepath} --generated_code_filepath {generated_path} --line_start {instance.lineStart} --include_signature'
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
if obs.exit_code == 0:
|
||||
test_result['metadata']['1_copy_change_success'] = True
|
||||
|
||||
action = CmdRunAction(command=f'cat {generated_path}', keep_prompt=False)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
code = obs.content
|
||||
@ -208,14 +208,14 @@ async def complete_runtime(
|
||||
|
||||
action = CmdRunAction(command='cd /testing_files')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(
|
||||
command='/home/openhands/mambaforge/bin/mamba run -n test python3 /testing/start_test_openhands.py'
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
@ -223,7 +223,7 @@ async def complete_runtime(
|
||||
command='cat /testing_files/results_biocoder.json', keep_prompt=False
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
if obs.exit_code == 0:
|
||||
test_result['metadata']['2_run_test_success'] = True
|
||||
test_result['metadata']['2_run_test_result'] = str(obs.content)
|
||||
@ -237,7 +237,7 @@ async def complete_runtime(
|
||||
return test_result
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -277,22 +277,26 @@ async def process_instance(
|
||||
# use a session id for concurrent evaluation
|
||||
sid = instance.instance_id.replace('/', '__')
|
||||
|
||||
runtime = await create_runtime(config, sid=sid)
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
|
||||
await initialize_runtime(runtime, instance)
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
test_result = await complete_runtime(runtime, instance)
|
||||
test_result = complete_runtime(runtime, instance)
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
@ -340,8 +344,6 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(biocoder_tests, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
|
||||
@ -242,7 +242,7 @@ def load_bird():
|
||||
return bird_dataset
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
@ -261,14 +261,14 @@ async def initialize_runtime(
|
||||
instance.db_id,
|
||||
f'{instance.db_id}.sqlite',
|
||||
)
|
||||
await runtime.copy_to(db_file, '/workspace')
|
||||
runtime.copy_to(db_file, '/workspace')
|
||||
|
||||
# Check the database is copied
|
||||
action = CmdRunAction(
|
||||
command='cd /workspace && ls -l',
|
||||
keep_prompt=False,
|
||||
)
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert f'{instance.db_id}.sqlite' in obs.content
|
||||
@ -276,7 +276,7 @@ async def initialize_runtime(
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
|
||||
) -> dict[str, Any]:
|
||||
@ -300,7 +300,7 @@ async def complete_runtime(
|
||||
command=f'cat {path}',
|
||||
keep_prompt=False,
|
||||
)
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
if obs.exit_code != 0:
|
||||
@ -350,7 +350,7 @@ async def complete_runtime(
|
||||
return test_result
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -402,19 +402,23 @@ async def process_instance(
|
||||
# NOTE: You can actually set slightly different instruction for different agents
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
|
||||
|
||||
runtime = await create_runtime(config, sid=instance_id)
|
||||
await initialize_runtime(runtime, instance)
|
||||
runtime = create_runtime(config, sid=instance_id)
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
|
||||
runtime=runtime,
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
runtime=runtime,
|
||||
)
|
||||
)
|
||||
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
test_result = await complete_runtime(runtime, instance)
|
||||
test_result = complete_runtime(runtime, instance)
|
||||
|
||||
# If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
# You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
|
||||
@ -463,8 +467,6 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
|
||||
@ -51,7 +51,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -71,12 +71,14 @@ async def process_instance(
|
||||
f'NOTE: You should copy the "query" as is into the <execute_browse> tag. DO NOT change ANYTHING in the query.'
|
||||
)
|
||||
|
||||
runtime = await create_runtime(config, sid=instance.instance_id)
|
||||
runtime = create_runtime(config, sid=instance.instance_id)
|
||||
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
)
|
||||
)
|
||||
|
||||
if state is None:
|
||||
@ -158,12 +160,10 @@ if __name__ == '__main__':
|
||||
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
|
||||
@ -63,7 +63,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
@ -76,7 +76,7 @@ async def initialize_runtime(
|
||||
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
if instance['file_name'] != '':
|
||||
@ -87,7 +87,7 @@ async def initialize_runtime(
|
||||
)
|
||||
assert os.path.exists(src_file)
|
||||
dest_file = os.path.join('/workspace', instance['file_name'])
|
||||
await runtime.copy_to(src_file, dest_file)
|
||||
runtime.copy_to(src_file, dest_file)
|
||||
|
||||
# rename to file.extension_name
|
||||
extension_name = instance['file_name'].split('.')[-1]
|
||||
@ -95,18 +95,18 @@ async def initialize_runtime(
|
||||
command=f'mv /workspace/{instance["file_name"]} /workspace/file.{extension_name}'
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -141,15 +141,19 @@ async def process_instance(
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
|
||||
logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
runtime = await create_runtime(config, sid=instance['instance_id'])
|
||||
await initialize_runtime(runtime, instance)
|
||||
runtime = create_runtime(config, sid=instance['instance_id'])
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
# If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
@ -257,12 +261,10 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
prepared_dataset = prepare_dataset(gaia_tests, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
dataset=prepared_dataset,
|
||||
metadata=metadata,
|
||||
output_file=output_file,
|
||||
num_workers=args.eval_num_workers,
|
||||
process_instance_func=process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
dataset=prepared_dataset,
|
||||
metadata=metadata,
|
||||
output_file=output_file,
|
||||
num_workers=args.eval_num_workers,
|
||||
process_instance_func=process_instance,
|
||||
)
|
||||
|
||||
@ -55,7 +55,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -79,14 +79,16 @@ async def process_instance(
|
||||
# logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
runtime = await create_runtime(config, sid=instance_id)
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
runtime = create_runtime(config, sid=instance_id)
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
)
|
||||
)
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
# If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
@ -179,14 +181,12 @@ if __name__ == '__main__':
|
||||
else:
|
||||
print('File already exists, skipping download.')
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
dataset=dataset,
|
||||
metadata=metadata,
|
||||
output_file=output_file,
|
||||
num_workers=args.eval_num_workers,
|
||||
process_instance_func=process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
dataset=dataset,
|
||||
metadata=metadata,
|
||||
output_file=output_file,
|
||||
num_workers=args.eval_num_workers,
|
||||
process_instance_func=process_instance,
|
||||
)
|
||||
|
||||
# Read the output file and calculate the accuracy
|
||||
|
||||
@ -169,7 +169,7 @@ def convert_instance_dict(instance):
|
||||
return out_instance_dict
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -214,15 +214,17 @@ Again do not quit without reporting the answer first.
|
||||
Ok now its time to start solving the question. Good luck!
|
||||
"""
|
||||
|
||||
runtime = await create_runtime(config, sid=f'gptq_{str(instance.instance_id)}')
|
||||
runtime = create_runtime(config, sid=f'gptq_{str(instance.instance_id)}')
|
||||
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
)
|
||||
)
|
||||
assert state is not None, 'State should not be None.'
|
||||
|
||||
@ -355,12 +357,10 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
prepared_dataset = prepare_dataset(gpqa_dataset, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
dataset=prepared_dataset,
|
||||
metadata=metadata,
|
||||
output_file=output_file,
|
||||
num_workers=args.eval_num_workers,
|
||||
process_instance_func=process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
dataset=prepared_dataset,
|
||||
metadata=metadata,
|
||||
output_file=output_file,
|
||||
num_workers=args.eval_num_workers,
|
||||
process_instance_func=process_instance,
|
||||
)
|
||||
|
||||
@ -102,7 +102,7 @@ def _get_instance_id(instance: pd.Series) -> str:
|
||||
return instance.task_id.replace('/', '__')
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
@ -115,12 +115,12 @@ async def initialize_runtime(
|
||||
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
problem_statement = (
|
||||
@ -131,20 +131,20 @@ async def initialize_runtime(
|
||||
host_script_path = os.path.join(tmpdir, filename)
|
||||
with open(host_script_path, 'w') as f:
|
||||
f.write(problem_statement)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
host_script_path,
|
||||
'/workspace',
|
||||
)
|
||||
|
||||
# check file exists
|
||||
action = CmdRunAction(command=f'ls /workspace/{_get_instance_id(instance)}.py')
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
|
||||
) -> dict[str, Any]:
|
||||
@ -170,7 +170,7 @@ async def complete_runtime(
|
||||
action = CmdRunAction(
|
||||
command=f'cat /workspace/{_get_instance_id(instance)}.py', keep_prompt=False
|
||||
)
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
function = obs.content.replace('\r\n', '\n')
|
||||
@ -194,7 +194,7 @@ async def complete_runtime(
|
||||
return test_result
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -232,21 +232,23 @@ async def process_instance(
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
runtime = await create_runtime(config, sid=sid)
|
||||
await initialize_runtime(runtime, instance)
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
initialize_runtime(runtime, instance)
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
test_result = await complete_runtime(runtime, instance)
|
||||
test_result = complete_runtime(runtime, instance)
|
||||
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
@ -294,12 +296,10 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(hefix_tests, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
|
||||
@ -128,7 +128,7 @@ def get_test_result(
|
||||
CUR_EVAL_DIR = os.path.dirname(__file__)
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
@ -142,33 +142,31 @@ async def initialize_runtime(
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
# copy logic_inference.py to /workspace
|
||||
await runtime.copy_to(
|
||||
os.path.join(CUR_EVAL_DIR, 'logic_inference.py'), '/workspace'
|
||||
)
|
||||
runtime.copy_to(os.path.join(CUR_EVAL_DIR, 'logic_inference.py'), '/workspace')
|
||||
# check if the file exists
|
||||
obs = await runtime.run_action(CmdRunAction(command='ls /workspace'))
|
||||
obs = runtime.run_action(CmdRunAction(command='ls /workspace'))
|
||||
assert obs.exit_code == 0
|
||||
assert 'logic_inference.py' in obs.content
|
||||
|
||||
await runtime.add_env_vars({'DATASET_NAME': metadata.dataset})
|
||||
runtime.add_env_vars({'DATASET_NAME': metadata.dataset})
|
||||
|
||||
action = CmdRunAction(command='mkdir -p /workspace/.cache_program')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = IPythonRunCellAction(code='%pip install scitools-pyke')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
ipynb_obs = await runtime.run_action(action)
|
||||
ipynb_obs = runtime.run_action(action)
|
||||
logger.info(ipynb_obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
@ -179,7 +177,7 @@ with open(os.path.join(CUR_EVAL_DIR, 'instruction.txt'), 'r') as f:
|
||||
INSTRUCTION_TEMPLATE = f.read()
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -206,8 +204,8 @@ async def process_instance(
|
||||
# use a session id for concurrent evaluation
|
||||
sid = instance['instance_id']
|
||||
|
||||
runtime = await create_runtime(config, sid=sid)
|
||||
await initialize_runtime(runtime, instance)
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = asyncio.run(
|
||||
@ -303,8 +301,6 @@ if __name__ == '__main__':
|
||||
)
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(dataset_df, output_file, args.eval_n_limit)
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
|
||||
@ -62,7 +62,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
) -> str:
|
||||
"""Initialize the runtime for the agent.
|
||||
@ -75,12 +75,12 @@ async def initialize_runtime(
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = BrowseInteractiveAction(browser_actions=BROWSER_EVAL_GET_GOAL_ACTION)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
goal = obs.content
|
||||
|
||||
@ -88,7 +88,7 @@ async def initialize_runtime(
|
||||
return goal
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any]:
|
||||
"""Complete the runtime for the agent.
|
||||
@ -102,7 +102,7 @@ async def complete_runtime(
|
||||
|
||||
action = BrowseInteractiveAction(browser_actions=BROWSER_EVAL_GET_REWARDS_ACTION)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Completion Fn {'-' * 50}")
|
||||
@ -111,7 +111,7 @@ async def complete_runtime(
|
||||
}
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -126,8 +126,8 @@ async def process_instance(
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {env_id}.')
|
||||
|
||||
runtime = await create_runtime(config, sid=env_id)
|
||||
task_str = await initialize_runtime(runtime)
|
||||
runtime = create_runtime(config, sid=env_id)
|
||||
task_str = initialize_runtime(runtime)
|
||||
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
@ -154,7 +154,7 @@ async def process_instance(
|
||||
instruction = event.content
|
||||
break
|
||||
|
||||
return_val = await complete_runtime(runtime)
|
||||
return_val = complete_runtime(runtime)
|
||||
logger.info(f'Return value from complete_runtime: {return_val}')
|
||||
reward = max(return_val['rewards'])
|
||||
|
||||
@ -208,8 +208,6 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
from typing import Any
|
||||
@ -114,7 +115,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(runtime: Runtime):
|
||||
def initialize_runtime(runtime: Runtime):
|
||||
"""Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
@ -125,18 +126,18 @@ async def initialize_runtime(runtime: Runtime):
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: Any,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -173,14 +174,16 @@ async def process_instance(
|
||||
},
|
||||
)
|
||||
|
||||
runtime = await create_runtime(config, sid=instance.instance_id)
|
||||
await initialize_runtime(runtime)
|
||||
runtime = create_runtime(config, sid=instance.instance_id)
|
||||
initialize_runtime(runtime)
|
||||
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=fake_user_response_fn,
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=fake_user_response_fn,
|
||||
)
|
||||
)
|
||||
|
||||
if state is None:
|
||||
|
||||
@ -13,6 +13,7 @@ TODOs:
|
||||
- Clean up the code and docker image used for evaluation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
@ -92,7 +93,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
@ -106,38 +107,38 @@ async def initialize_runtime(
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
# Set up the task environment
|
||||
action = CmdRunAction(command=f'conda activate {ID2CONDA[instance["github_id"]]}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
repo_url = instance['github']
|
||||
repo_name = repo_url.split('/')[-1]
|
||||
action = CmdRunAction(command=f'git clone {repo_url} /workspace/{repo_name}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command=f'chmod -R 777 /workspace/{repo_name}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
# Navigate to the task's code path
|
||||
task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
|
||||
action = CmdRunAction(command=f'cd {task_path}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
|
||||
) -> dict[str, Any]:
|
||||
@ -160,7 +161,7 @@ async def complete_runtime(
|
||||
|
||||
action = CmdRunAction(command=f'cat {eval_script}', keep_prompt=False)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
if obs.exit_code == 0:
|
||||
eval_script_content = obs.content
|
||||
else:
|
||||
@ -172,7 +173,7 @@ async def complete_runtime(
|
||||
timeout=600,
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
if obs.exit_code == 0:
|
||||
eval_output = obs.content
|
||||
else:
|
||||
@ -200,9 +201,7 @@ async def complete_runtime(
|
||||
return outputs
|
||||
|
||||
|
||||
async def process_instance(
|
||||
instance: Any, metadata: EvalMetadata, reset_logger: bool = True
|
||||
):
|
||||
def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = True):
|
||||
config = get_config(metadata)
|
||||
|
||||
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
|
||||
@ -236,22 +235,24 @@ async def process_instance(
|
||||
)
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
|
||||
|
||||
runtime = await create_runtime(config, sid=sid)
|
||||
await initialize_runtime(runtime, instance)
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
# Run the agent
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
),
|
||||
)
|
||||
)
|
||||
assert state is not None
|
||||
metrics = state.metrics.get() if state.metrics else {}
|
||||
|
||||
test_result = await complete_runtime(runtime)
|
||||
test_result = complete_runtime(runtime)
|
||||
|
||||
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
|
||||
# for compatibility with the existing output format, we can remake the pairs here
|
||||
|
||||
@ -2,6 +2,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
@ -141,7 +142,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
):
|
||||
@ -160,13 +161,13 @@ async def initialize_runtime(
|
||||
command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc"""
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
@ -177,7 +178,7 @@ async def initialize_runtime(
|
||||
# inject the instance info
|
||||
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
obs.exit_code == 0
|
||||
@ -195,35 +196,35 @@ async def initialize_runtime(
|
||||
json.dump([instance], f)
|
||||
|
||||
# Copy the file to the desired location
|
||||
await runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')
|
||||
runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')
|
||||
|
||||
# inject the instance swe entry
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
|
||||
'/swe_util/',
|
||||
)
|
||||
action = CmdRunAction(command='cat ~/.bashrc')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='source ~/.bashrc')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
else:
|
||||
action = CmdRunAction(command='source /swe_util/swe_entry.sh')
|
||||
action.timeout = 1800
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
obs.exit_code == 0
|
||||
@ -231,13 +232,13 @@ async def initialize_runtime(
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='git reset --hard')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
@ -245,7 +246,7 @@ async def initialize_runtime(
|
||||
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
@ -254,7 +255,7 @@ async def initialize_runtime(
|
||||
logger.info('-' * 30)
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
|
||||
) -> dict[str, Any]:
|
||||
@ -272,19 +273,19 @@ async def complete_runtime(
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='git config --global core.pager ""')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='git add -A')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
@ -297,7 +298,7 @@ async def complete_runtime(
|
||||
)
|
||||
action.timeout = 600 + 100 * n_retries
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
n_retries += 1
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
@ -306,10 +307,10 @@ async def complete_runtime(
|
||||
break
|
||||
else:
|
||||
logger.info('Failed to get git diff, retrying...')
|
||||
await asyncio.sleep(10)
|
||||
time.sleep(10)
|
||||
elif isinstance(obs, ErrorObservation):
|
||||
logger.error(f'Error occurred: {obs.content}. Retrying...')
|
||||
await asyncio.sleep(10)
|
||||
time.sleep(10)
|
||||
else:
|
||||
raise ValueError(f'Unexpected observation type: {type(obs)}')
|
||||
|
||||
@ -319,7 +320,7 @@ async def complete_runtime(
|
||||
return {'git_patch': git_patch}
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -333,28 +334,32 @@ async def process_instance(
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
|
||||
|
||||
runtime = await create_runtime(config, sid=instance.instance_id)
|
||||
await initialize_runtime(runtime, instance)
|
||||
runtime = create_runtime(config, sid=instance.instance_id)
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
instruction = get_instruction(instance, metadata)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# ======= THIS IS SWE-Bench specific =======
|
||||
# Get git patch
|
||||
return_val = await complete_runtime(runtime, instance)
|
||||
return_val = complete_runtime(runtime, instance)
|
||||
git_patch = return_val['git_patch']
|
||||
logger.info(
|
||||
f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
|
||||
)
|
||||
|
||||
await runtime.close()
|
||||
runtime.close()
|
||||
# ==========================================
|
||||
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
@ -440,8 +445,6 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
|
||||
@ -57,7 +57,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(runtime: Runtime):
|
||||
def initialize_runtime(runtime: Runtime):
|
||||
"""Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
@ -68,22 +68,20 @@ async def initialize_runtime(runtime: Runtime):
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cd /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
await runtime.add_env_vars({'WOLFRAM_ALPHA_APPID': args.wolfram_alpha_appid})
|
||||
runtime.add_env_vars({'WOLFRAM_ALPHA_APPID': args.wolfram_alpha_appid})
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
|
||||
|
||||
|
||||
async def process_instance(
|
||||
instance: Any, metadata: EvalMetadata, reset_logger: bool = True
|
||||
):
|
||||
def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = True):
|
||||
config = get_config(metadata)
|
||||
|
||||
qid = instance.qid
|
||||
@ -104,15 +102,19 @@ async def process_instance(
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
|
||||
logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
runtime = await create_runtime(config, sid=qid)
|
||||
await initialize_runtime(runtime)
|
||||
runtime = create_runtime(config, sid=qid)
|
||||
initialize_runtime(runtime)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
# If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
@ -210,8 +212,6 @@ if __name__ == '__main__':
|
||||
)
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(toolqa_test, output_file, args.eval_n_limit)
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
run_evaluation(
|
||||
instances, metadata, output_file, args.eval_num_workers, process_instance
|
||||
)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
@ -227,7 +226,7 @@ def prepare_dataset(
|
||||
return pd.DataFrame(new_dataset)
|
||||
|
||||
|
||||
async def run_evaluation(
|
||||
def run_evaluation(
|
||||
dataset: pd.DataFrame,
|
||||
metadata: EvalMetadata,
|
||||
output_file: str,
|
||||
@ -244,14 +243,14 @@ async def run_evaluation(
|
||||
pbar = tqdm(total=len(dataset))
|
||||
output_fp = open(output_file, 'a')
|
||||
|
||||
async def update_progress(future):
|
||||
def update_progress(future):
|
||||
pbar.update(1)
|
||||
output: EvalOutput = await future if use_multiprocessing else future
|
||||
output: EvalOutput = future.result() if use_multiprocessing else future
|
||||
|
||||
pbar.set_description(f'Instance {output.instance_id}')
|
||||
pbar.set_postfix_str(f'Test Result: {output.test_result}')
|
||||
logger.info(
|
||||
f'Finished evaluation for instance {output.instance_id}: {output.test_result}\n'
|
||||
f'Finished evaluation for instance {output.instance_id}: {str(output.test_result)[:300]}...\n'
|
||||
)
|
||||
output_fp.write(json.dumps(output.model_dump()) + '\n')
|
||||
output_fp.flush()
|
||||
@ -259,25 +258,24 @@ async def run_evaluation(
|
||||
try:
|
||||
if use_multiprocessing:
|
||||
with ProcessPoolExecutor(num_workers) as executor:
|
||||
loop = asyncio.get_event_loop()
|
||||
futures = []
|
||||
for _, instance in dataset.iterrows():
|
||||
future = loop.run_in_executor(
|
||||
executor,
|
||||
future = executor.submit(
|
||||
process_instance_func,
|
||||
instance,
|
||||
metadata,
|
||||
bool(num_workers > 1),
|
||||
)
|
||||
futures.append(update_progress(future))
|
||||
|
||||
await asyncio.gather(*futures)
|
||||
future.add_done_callback(update_progress)
|
||||
futures.append(future)
|
||||
for future in futures:
|
||||
future.result()
|
||||
# Use plain for loop for single process for easier debugging
|
||||
else:
|
||||
assert num_workers == 1
|
||||
for _, instance in dataset.iterrows():
|
||||
output = await process_instance_func(instance, metadata, False)
|
||||
await update_progress(output)
|
||||
output = process_instance_func(instance, metadata, False)
|
||||
update_progress(output)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('\nKeyboardInterrupt received. Cleaning up...\n')
|
||||
|
||||
@ -78,7 +78,7 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
async def initialize_runtime(
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
) -> dict:
|
||||
"""Initialize the runtime for the agent.
|
||||
@ -91,12 +91,12 @@ async def initialize_runtime(
|
||||
# Set instance id
|
||||
action = CmdRunAction(command='mkdir -p /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = BrowseInteractiveAction(browser_actions=BROWSER_EVAL_GET_GOAL_ACTION)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
goal = obs.content
|
||||
|
||||
@ -104,7 +104,7 @@ async def initialize_runtime(
|
||||
return goal
|
||||
|
||||
|
||||
async def complete_runtime(
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any]:
|
||||
"""Complete the runtime for the agent.
|
||||
@ -118,7 +118,7 @@ async def complete_runtime(
|
||||
|
||||
action = BrowseInteractiveAction(browser_actions=BROWSER_EVAL_GET_REWARDS_ACTION)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
logger.info(f"{'-' * 50} END Runtime Completion Fn {'-' * 50}")
|
||||
@ -127,7 +127,7 @@ async def complete_runtime(
|
||||
}
|
||||
|
||||
|
||||
async def process_instance(
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@ -142,15 +142,16 @@ async def process_instance(
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {env_id}.')
|
||||
|
||||
runtime = await create_runtime(config, sid=env_id)
|
||||
task_str = await initialize_runtime(runtime)
|
||||
runtime = create_runtime(config, sid=env_id)
|
||||
task_str = initialize_runtime(runtime)
|
||||
|
||||
state: State | None = await run_controller(
|
||||
config=config,
|
||||
task_str=task_str,
|
||||
runtime=runtime,
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=task_str,
|
||||
runtime=runtime,
|
||||
)
|
||||
)
|
||||
|
||||
# ======= Attempt to evaluate the agent's environment impact =======
|
||||
|
||||
# If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
@ -168,7 +169,7 @@ async def process_instance(
|
||||
instruction = event.content
|
||||
break
|
||||
|
||||
return_val = await complete_runtime(runtime)
|
||||
return_val = complete_runtime(runtime)
|
||||
logger.info(f'Return value from complete_runtime: {return_val}')
|
||||
reward = max(return_val['rewards'])
|
||||
|
||||
@ -222,12 +223,10 @@ if __name__ == '__main__':
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
|
||||
|
||||
asyncio.run(
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
)
|
||||
|
||||
@ -79,13 +79,12 @@ async def main():
|
||||
event_stream = EventStream(sid, file_store)
|
||||
|
||||
runtime_cls = get_runtime_cls(config.runtime)
|
||||
runtime: Runtime = runtime_cls(
|
||||
runtime: Runtime = runtime_cls( # noqa: F841
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
plugins=agent_cls.sandbox_plugins,
|
||||
)
|
||||
await runtime.ainit()
|
||||
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
|
||||
@ -47,10 +47,9 @@ def read_task_from_stdin() -> str:
|
||||
return sys.stdin.read()
|
||||
|
||||
|
||||
async def create_runtime(
|
||||
def create_runtime(
|
||||
config: AppConfig,
|
||||
sid: str | None = None,
|
||||
runtime_tools_config: dict | None = None,
|
||||
) -> Runtime:
|
||||
"""Create a runtime for the agent to run on.
|
||||
|
||||
@ -79,7 +78,6 @@ async def create_runtime(
|
||||
sid=session_id,
|
||||
plugins=agent_cls.sandbox_plugins,
|
||||
)
|
||||
await runtime.ainit()
|
||||
|
||||
return runtime
|
||||
|
||||
@ -121,7 +119,7 @@ async def run_controller(
|
||||
sid = sid or generate_sid(config)
|
||||
|
||||
if runtime is None:
|
||||
runtime = await create_runtime(config, sid=sid)
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
# restore cli session if enabled
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from zipfile import ZipFile
|
||||
|
||||
import aiohttp
|
||||
import docker
|
||||
import requests
|
||||
import tenacity
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
@ -104,25 +104,23 @@ class EventStreamRuntime(Runtime):
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
config, event_stream, sid, plugins
|
||||
) # will initialize the event stream
|
||||
self.config = config
|
||||
self._port = find_available_tcp_port()
|
||||
self.api_url = f'http://{self.config.sandbox.api_hostname}:{self._port}'
|
||||
self.session: aiohttp.ClientSession | None = None
|
||||
self.session = requests.Session()
|
||||
|
||||
self.instance_id = (
|
||||
sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
|
||||
)
|
||||
# TODO: We can switch to aiodocker when `build_sandbox_image` is updated to use aiodocker
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.base_container_image = self.config.sandbox.base_container_image
|
||||
self.runtime_container_image = self.config.sandbox.runtime_container_image
|
||||
self.container_name = self.container_name_prefix + self.instance_id
|
||||
|
||||
self.container = None
|
||||
self.action_semaphore = asyncio.Semaphore(1) # Ensure one action at a time
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
|
||||
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
|
||||
logger.debug(f'EventStreamRuntime `{sid}` config:\n{self.config}')
|
||||
@ -131,7 +129,6 @@ class EventStreamRuntime(Runtime):
|
||||
self.log_buffer: LogBuffer | None = None
|
||||
self.startup_done = False
|
||||
|
||||
async def ainit(self, env_vars: dict[str, str] | None = None):
|
||||
if self.config.sandbox.runtime_extra_deps:
|
||||
logger.info(
|
||||
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
|
||||
@ -147,14 +144,13 @@ class EventStreamRuntime(Runtime):
|
||||
self.runtime_builder,
|
||||
extra_deps=self.config.sandbox.runtime_extra_deps,
|
||||
)
|
||||
self.container = await self._init_container(
|
||||
self.container = self._init_container(
|
||||
self.sandbox_workspace_dir,
|
||||
mount_dir=self.config.workspace_mount_path,
|
||||
plugins=self.plugins,
|
||||
plugins=plugins,
|
||||
)
|
||||
# MUST call super().ainit() to initialize both default env vars
|
||||
# AND the ones in env vars!
|
||||
await super().ainit(env_vars)
|
||||
# will initialize both the event stream and the env vars
|
||||
super().__init__(config, event_stream, sid, plugins, env_vars)
|
||||
|
||||
logger.info(
|
||||
f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'
|
||||
@ -175,7 +171,7 @@ class EventStreamRuntime(Runtime):
|
||||
stop=tenacity.stop_after_attempt(5),
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||
)
|
||||
async def _init_container(
|
||||
def _init_container(
|
||||
self,
|
||||
sandbox_workspace_dir: str,
|
||||
mount_dir: str | None = None,
|
||||
@ -242,19 +238,14 @@ class EventStreamRuntime(Runtime):
|
||||
except Exception as e:
|
||||
logger.error('Failed to start container')
|
||||
logger.exception(e)
|
||||
await self.close(close_client=False)
|
||||
self.close(close_client=False)
|
||||
raise e
|
||||
|
||||
async def _ensure_session(self):
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self.session
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(10),
|
||||
wait=tenacity.wait_exponential(multiplier=2, min=10, max=60),
|
||||
)
|
||||
async def _wait_until_alive(self):
|
||||
def _wait_until_alive(self):
|
||||
init_msg = 'Runtime client initialized.'
|
||||
logger.debug('Getting container logs...')
|
||||
|
||||
@ -284,7 +275,7 @@ class EventStreamRuntime(Runtime):
|
||||
attempts = 0
|
||||
while not self.startup_done and attempts < 10:
|
||||
attempts += 1
|
||||
await asyncio.sleep(1)
|
||||
time.sleep(1)
|
||||
logs = self.log_buffer.get_and_clear()
|
||||
if logs:
|
||||
formatted_logs = '\n'.join([f' |{log}' for log in logs])
|
||||
@ -301,25 +292,24 @@ class EventStreamRuntime(Runtime):
|
||||
self.startup_done = True
|
||||
break
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f'{self.api_url}/alive') as response:
|
||||
if response.status == 200:
|
||||
return
|
||||
else:
|
||||
msg = f'Action execution API is not alive. Response: {response}'
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
response = self.session.get(f'{self.api_url}/alive')
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
msg = f'Action execution API is not alive. Response: {response}'
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@property
|
||||
def sandbox_workspace_dir(self):
|
||||
return self.config.workspace_mount_path_in_sandbox
|
||||
|
||||
async def close(self, close_client: bool = True):
|
||||
def close(self, close_client: bool = True):
|
||||
if self.log_buffer:
|
||||
self.log_buffer.close()
|
||||
|
||||
if self.session is not None and not self.session.closed:
|
||||
await self.session.close()
|
||||
if self.session:
|
||||
self.session.close()
|
||||
|
||||
containers = self.docker_client.containers.list(all=True)
|
||||
for container in containers:
|
||||
@ -335,12 +325,12 @@ class EventStreamRuntime(Runtime):
|
||||
if close_client:
|
||||
self.docker_client.close()
|
||||
|
||||
async def run_action(self, action: Action) -> Observation:
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
# set timeout to default if not set
|
||||
if action.timeout is None:
|
||||
action.timeout = self.config.sandbox.timeout
|
||||
|
||||
async with self.action_semaphore:
|
||||
with self.action_semaphore:
|
||||
if not action.runnable:
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
@ -352,30 +342,26 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
|
||||
logger.info('Awaiting session')
|
||||
session = await self._ensure_session()
|
||||
await self._wait_until_alive()
|
||||
self._wait_until_alive()
|
||||
|
||||
assert action.timeout is not None
|
||||
|
||||
try:
|
||||
logger.info('Executing command')
|
||||
async with session.post(
|
||||
response = self.session.post(
|
||||
f'{self.api_url}/execute_action',
|
||||
json={'action': event_to_dict(action)},
|
||||
timeout=action.timeout,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
output = await response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
return obs
|
||||
else:
|
||||
error_message = await response.text()
|
||||
logger.error(f'Error from server: {error_message}')
|
||||
obs = ErrorObservation(
|
||||
f'Command execution failed: {error_message}'
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
)
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
return obs
|
||||
else:
|
||||
error_message = response.text
|
||||
logger.error(f'Error from server: {error_message}')
|
||||
obs = ErrorObservation(f'Command execution failed: {error_message}')
|
||||
except requests.Timeout:
|
||||
logger.error('No response received within the timeout period.')
|
||||
obs = ErrorObservation('Command execution timed out')
|
||||
except Exception as e:
|
||||
@ -383,36 +369,35 @@ class EventStreamRuntime(Runtime):
|
||||
obs = ErrorObservation(f'Command execution failed: {str(e)}')
|
||||
return obs
|
||||
|
||||
async def run(self, action: CmdRunAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def run(self, action: CmdRunAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def read(self, action: FileReadAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def write(self, action: FileWriteAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def browse(self, action: BrowseURLAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
# ====================================================================
|
||||
# Implement these methods (for file operations) in the subclass
|
||||
# ====================================================================
|
||||
|
||||
async def copy_to(
|
||||
def copy_to(
|
||||
self, host_src: str, sandbox_dest: str, recursive: bool = False
|
||||
) -> None:
|
||||
if not os.path.exists(host_src):
|
||||
raise FileNotFoundError(f'Source file {host_src} does not exist')
|
||||
|
||||
session = await self._ensure_session()
|
||||
await self._wait_until_alive()
|
||||
self._wait_until_alive()
|
||||
try:
|
||||
if recursive:
|
||||
# For recursive copy, create a zip file
|
||||
@ -437,16 +422,16 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
async with session.post(
|
||||
f'{self.api_url}/upload_file', data=upload_data, params=params
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return
|
||||
else:
|
||||
error_message = await response.text()
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
response = self.session.post(
|
||||
f'{self.api_url}/upload_file', files=upload_data, params=params
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Copy operation failed: {str(e)}')
|
||||
@ -455,29 +440,26 @@ class EventStreamRuntime(Runtime):
|
||||
os.unlink(temp_zip_path)
|
||||
logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
|
||||
|
||||
async def list_files(self, path: str | None = None) -> list[str]:
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List files in the sandbox.
|
||||
|
||||
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
await self._wait_until_alive()
|
||||
self._wait_until_alive()
|
||||
try:
|
||||
data = {}
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
async with session.post(
|
||||
f'{self.api_url}/list_files', json=data
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
response_json = await response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = await response.text()
|
||||
raise Exception(f'List files operation failed: {error_message}')
|
||||
except asyncio.TimeoutError:
|
||||
response = self.session.post(f'{self.api_url}/list_files', json=data)
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'List files operation failed: {error_message}')
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('List files operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'List files operation failed: {str(e)}')
|
||||
|
||||
@ -33,13 +33,13 @@ class E2BRuntime(Runtime):
|
||||
raise ValueError('E2BRuntime requires an E2BSandbox')
|
||||
self.file_store = E2BFileStore(self.sandbox.filesystem)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
def read(self, action: FileReadAction) -> Observation:
|
||||
content = self.file_store.read(action.path)
|
||||
lines = read_lines(content.split('\n'), action.start, action.end)
|
||||
code_view = ''.join(lines)
|
||||
return FileReadObservation(code_view, path=action.path)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
def write(self, action: FileWriteAction) -> Observation:
|
||||
if action.start == 0 and action.end == -1:
|
||||
self.file_store.write(action.path, action.content)
|
||||
return FileWriteObservation(content='', path=action.path)
|
||||
|
||||
@ -1,14 +1,19 @@
|
||||
import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any, Type
|
||||
from zipfile import ZipFile
|
||||
|
||||
import aiohttp
|
||||
import aiohttp.client_exceptions
|
||||
import tenacity
|
||||
import requests
|
||||
from requests.exceptions import HTTPError, RequestException, Timeout
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -36,11 +41,9 @@ from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
|
||||
DEFAULT_RETRY_EXCEPTIONS = [
|
||||
ssl.SSLCertVerificationError,
|
||||
aiohttp.ClientError,
|
||||
aiohttp.client_exceptions.ContentTypeError,
|
||||
aiohttp.client_exceptions.ClientConnectorCertificateError,
|
||||
ssl.SSLCertVerificationError,
|
||||
asyncio.TimeoutError,
|
||||
RequestException,
|
||||
HTTPError,
|
||||
Timeout,
|
||||
]
|
||||
|
||||
|
||||
@ -55,8 +58,9 @@ class RemoteRuntime(Runtime):
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
):
|
||||
super().__init__(config, event_stream, sid, plugins)
|
||||
self.config = config
|
||||
if self.config.sandbox.api_hostname == 'localhost':
|
||||
self.config.sandbox.api_hostname = 'api.all-hands.dev/v0/runtime'
|
||||
logger.warning(
|
||||
@ -65,20 +69,20 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
self.api_url = f'https://{self.config.sandbox.api_hostname.rstrip("/")}'
|
||||
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
self.action_semaphore = asyncio.Semaphore(1) # Ensure one action at a time
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
'API key is required to use the remote runtime. '
|
||||
'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
|
||||
)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
|
||||
self.action_semaphore = threading.Semaphore(1)
|
||||
|
||||
if self.config.workspace_base is not None:
|
||||
logger.warning(
|
||||
'Setting workspace_base is not supported in the remote runtime.'
|
||||
)
|
||||
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
'API key is required to use the remote runtime. '
|
||||
'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
|
||||
)
|
||||
self.runtime_builder = RemoteRuntimeBuilder(
|
||||
self.api_url, self.config.sandbox.api_key
|
||||
)
|
||||
@ -95,48 +99,8 @@ class RemoteRuntime(Runtime):
|
||||
self.container_image: str = self.config.sandbox.base_container_image
|
||||
self.container_name = 'od-remote-runtime-' + self.instance_id
|
||||
logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}')
|
||||
|
||||
async def _send_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
retry_exceptions: list[Type[Exception]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> aiohttp.ClientResponse:
|
||||
if retry_exceptions is None:
|
||||
retry_exceptions = DEFAULT_RETRY_EXCEPTIONS
|
||||
|
||||
session = await self._ensure_session()
|
||||
|
||||
def log_retry(retry_state):
|
||||
exception = retry_state.outcome.exception()
|
||||
logger.warning(
|
||||
f'Retry attempt {retry_state.attempt_number} failed with exception: {exception}'
|
||||
)
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(10),
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=tenacity.retry_if_exception_type(tuple(retry_exceptions)),
|
||||
reraise=True,
|
||||
after=log_retry,
|
||||
)
|
||||
async def _send_request_with_retry():
|
||||
async with session.request(method, url, **kwargs) as response:
|
||||
await response.read()
|
||||
return response
|
||||
|
||||
return await _send_request_with_retry()
|
||||
|
||||
async def ainit(self, env_vars: dict[str, str] | None = None):
|
||||
# Check if the container image exists
|
||||
# Use the /registry_prefix endpoint to get the registry prefix
|
||||
response = await self._send_request('GET', f'{self.api_url}/registry_prefix')
|
||||
if response.status != 200:
|
||||
raise RuntimeError(
|
||||
f'Failed to get registry prefix: {await response.text()}'
|
||||
)
|
||||
response_json = await response.json()
|
||||
response = self._send_request('GET', f'{self.api_url}/registry_prefix')
|
||||
response_json = response.json()
|
||||
registry_prefix = response_json['registry_prefix']
|
||||
os.environ['OD_RUNTIME_RUNTIME_IMAGE_REPO'] = (
|
||||
registry_prefix.rstrip('/') + '/runtime'
|
||||
@ -158,26 +122,23 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
|
||||
# Use the /image_exists endpoint to check if the image exists
|
||||
response = await self._send_request(
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.api_url}/image_exists',
|
||||
params={'image': self.container_image},
|
||||
)
|
||||
if response.status != 200 or not (await response.json())['exists']:
|
||||
if response.status_code != 200 or not response.json()['exists']:
|
||||
raise RuntimeError(f'Container image {self.container_image} does not exist')
|
||||
|
||||
# Prepare the request body for the /start endpoint
|
||||
plugin_arg = ''
|
||||
if self.plugins is not None and len(self.plugins) > 0:
|
||||
plugin_arg = (
|
||||
f'--plugins {" ".join([plugin.name for plugin in self.plugins])} '
|
||||
)
|
||||
if self.config.sandbox.browsergym_eval_env is not None:
|
||||
browsergym_arg = (
|
||||
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
|
||||
)
|
||||
else:
|
||||
browsergym_arg = ''
|
||||
if plugins is not None and len(plugins) > 0:
|
||||
plugin_arg = f'--plugins {" ".join([plugin.name for plugin in plugins])} '
|
||||
browsergym_arg = (
|
||||
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
|
||||
if self.config.sandbox.browsergym_eval_env is not None
|
||||
else ''
|
||||
)
|
||||
start_request = {
|
||||
'image': self.container_image,
|
||||
'command': (
|
||||
@ -196,12 +157,12 @@ class RemoteRuntime(Runtime):
|
||||
}
|
||||
|
||||
# Start the sandbox using the /start endpoint
|
||||
response = await self._send_request(
|
||||
response = self._send_request(
|
||||
'POST', f'{self.api_url}/start', json=start_request
|
||||
)
|
||||
if response.status != 201:
|
||||
raise RuntimeError(f'Failed to start sandbox: {await response.text()}')
|
||||
start_response = await response.json()
|
||||
if response.status_code != 201:
|
||||
raise RuntimeError(f'Failed to start sandbox: {response.text}')
|
||||
start_response = response.json()
|
||||
self.runtime_id = start_response['runtime_id']
|
||||
self.runtime_url = start_response['url']
|
||||
|
||||
@ -209,8 +170,8 @@ class RemoteRuntime(Runtime):
|
||||
f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}'
|
||||
)
|
||||
|
||||
# Initialize environment variables
|
||||
await super().ainit(env_vars)
|
||||
# Initialize the eventstream and env vars
|
||||
super().__init__(config, event_stream, sid, plugins, env_vars)
|
||||
|
||||
logger.info(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.plugins]}'
|
||||
@ -223,26 +184,40 @@ class RemoteRuntime(Runtime):
|
||||
self.runtime_url is not None
|
||||
), 'Runtime URL is not set. This should never happen.'
|
||||
|
||||
async def _ensure_session(self):
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers={'X-API-Key': self.config.sandbox.api_key}
|
||||
)
|
||||
return self.session
|
||||
def _send_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
retry_exceptions: list[Type[Exception]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
if retry_exceptions is None:
|
||||
retry_exceptions = DEFAULT_RETRY_EXCEPTIONS
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(10),
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=tenacity.retry_if_exception_type(RuntimeError),
|
||||
@retry(
|
||||
stop=stop_after_attempt(10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(tuple(retry_exceptions)),
|
||||
reraise=True,
|
||||
)
|
||||
def _send_request_with_retry():
|
||||
response = self.session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
return _send_request_with_retry()
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(RuntimeError),
|
||||
reraise=True,
|
||||
)
|
||||
async def _wait_until_alive(self):
|
||||
def _wait_until_alive(self):
|
||||
logger.info('Waiting for sandbox to be alive...')
|
||||
response = await self._send_request('GET', f'{self.runtime_url}/alive')
|
||||
if response.status == 200:
|
||||
return
|
||||
else:
|
||||
msg = f'Runtime is not alive (id={self.runtime_id}). Status: {response.status}.'
|
||||
response = self._send_request('GET', f'{self.runtime_url}/alive')
|
||||
if response.status_code != 200:
|
||||
msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
|
||||
logger.warning(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@ -250,28 +225,25 @@ class RemoteRuntime(Runtime):
|
||||
def sandbox_workspace_dir(self):
|
||||
return self.config.workspace_mount_path_in_sandbox
|
||||
|
||||
async def close(self):
|
||||
def close(self):
|
||||
if self.runtime_id:
|
||||
try:
|
||||
response = await self._send_request(
|
||||
response = self._send_request(
|
||||
'POST', f'{self.api_url}/stop', json={'runtime_id': self.runtime_id}
|
||||
)
|
||||
if response.status != 200:
|
||||
logger.error(f'Failed to stop sandbox: {await response.text()}')
|
||||
if response.status_code != 200:
|
||||
logger.error(f'Failed to stop sandbox: {response.text}')
|
||||
else:
|
||||
logger.info(f'Sandbox stopped. Runtime ID: {self.runtime_id}')
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
if self.session is not None:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
self.session.close()
|
||||
|
||||
async def run_action(self, action: Action) -> Observation:
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
if action.timeout is None:
|
||||
action.timeout = self.config.sandbox.timeout
|
||||
|
||||
async with self.action_semaphore:
|
||||
with self.action_semaphore:
|
||||
if not action.runnable:
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
@ -282,7 +254,7 @@ class RemoteRuntime(Runtime):
|
||||
f'Action {action_type} is not supported in the current runtime.'
|
||||
)
|
||||
|
||||
await self._wait_until_alive()
|
||||
self._wait_until_alive()
|
||||
|
||||
assert action.timeout is not None
|
||||
|
||||
@ -290,28 +262,25 @@ class RemoteRuntime(Runtime):
|
||||
logger.info('Executing action')
|
||||
request_body = {'action': event_to_dict(action)}
|
||||
logger.debug(f'Request body: {request_body}')
|
||||
response = await self._send_request(
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/execute_action',
|
||||
json=request_body,
|
||||
timeout=action.timeout,
|
||||
retry_exceptions=list(
|
||||
filter(
|
||||
lambda e: e != asyncio.TimeoutError,
|
||||
DEFAULT_RETRY_EXCEPTIONS,
|
||||
)
|
||||
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
|
||||
),
|
||||
)
|
||||
if response.status == 200:
|
||||
output = await response.json()
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
return obs
|
||||
else:
|
||||
error_message = await response.text()
|
||||
error_message = response.text
|
||||
logger.error(f'Error from server: {error_message}')
|
||||
obs = ErrorObservation(f'Action execution failed: {error_message}')
|
||||
except asyncio.TimeoutError:
|
||||
except Timeout:
|
||||
logger.error('No response received within the timeout period.')
|
||||
obs = ErrorObservation('Action execution timed out')
|
||||
except Exception as e:
|
||||
@ -319,31 +288,31 @@ class RemoteRuntime(Runtime):
|
||||
obs = ErrorObservation(f'Action execution failed: {str(e)}')
|
||||
return obs
|
||||
|
||||
async def run(self, action: CmdRunAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def run(self, action: CmdRunAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def read(self, action: FileReadAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def write(self, action: FileWriteAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def browse(self, action: BrowseURLAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
async def copy_to(
|
||||
def copy_to(
|
||||
self, host_src: str, sandbox_dest: str, recursive: bool = False
|
||||
) -> None:
|
||||
if not os.path.exists(host_src):
|
||||
raise FileNotFoundError(f'Source file {host_src} does not exist')
|
||||
|
||||
await self._wait_until_alive()
|
||||
self._wait_until_alive()
|
||||
try:
|
||||
if recursive:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
@ -366,26 +335,24 @@ class RemoteRuntime(Runtime):
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
response = await self._send_request(
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/upload_file',
|
||||
data=upload_data,
|
||||
files=upload_data,
|
||||
params=params,
|
||||
retry_exceptions=list(
|
||||
filter(
|
||||
lambda e: e != asyncio.TimeoutError, DEFAULT_RETRY_EXCEPTIONS
|
||||
)
|
||||
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
|
||||
),
|
||||
)
|
||||
if response.status == 200:
|
||||
if response.status_code == 200:
|
||||
logger.info(
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {await response.text()}'
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}'
|
||||
)
|
||||
return
|
||||
else:
|
||||
error_message = await response.text()
|
||||
error_message = response.text
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Copy operation failed: {str(e)}')
|
||||
@ -394,31 +361,29 @@ class RemoteRuntime(Runtime):
|
||||
os.unlink(temp_zip_path)
|
||||
logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
|
||||
|
||||
async def list_files(self, path: str | None = None) -> list[str]:
|
||||
await self._wait_until_alive()
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
self._wait_until_alive()
|
||||
try:
|
||||
data = {}
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
response = await self._send_request(
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/list_files',
|
||||
json=data,
|
||||
retry_exceptions=list(
|
||||
filter(
|
||||
lambda e: e != asyncio.TimeoutError, DEFAULT_RETRY_EXCEPTIONS
|
||||
)
|
||||
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
|
||||
),
|
||||
)
|
||||
if response.status == 200:
|
||||
response_json = await response.json()
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = await response.text()
|
||||
error_message = response.text
|
||||
raise Exception(f'List files operation failed: {error_message}')
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
raise TimeoutError('List files operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'List files operation failed: {str(e)}')
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import copy
|
||||
import json
|
||||
@ -58,6 +57,7 @@ class Runtime:
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
@ -66,41 +66,22 @@ class Runtime:
|
||||
|
||||
self.config = copy.deepcopy(config)
|
||||
self.DEFAULT_ENV_VARS = _default_env_vars(config.sandbox)
|
||||
atexit.register(self.close_sync)
|
||||
atexit.register(self.close)
|
||||
logger.debug(f'Runtime `{sid}` config:\n{self.config}')
|
||||
|
||||
async def ainit(self, env_vars: dict[str, str] | None = None) -> None:
|
||||
"""
|
||||
Initialize the runtime (asynchronously).
|
||||
|
||||
This method should be called after the runtime's constructor.
|
||||
"""
|
||||
if self.DEFAULT_ENV_VARS:
|
||||
logger.debug(f'Adding default env vars: {self.DEFAULT_ENV_VARS}')
|
||||
await self.add_env_vars(self.DEFAULT_ENV_VARS)
|
||||
self.add_env_vars(self.DEFAULT_ENV_VARS)
|
||||
if env_vars is not None:
|
||||
logger.debug(f'Adding provided env vars: {env_vars}')
|
||||
await self.add_env_vars(env_vars)
|
||||
self.add_env_vars(env_vars)
|
||||
|
||||
async def close(self) -> None:
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def close_sync(self) -> None:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
return
|
||||
if loop.is_running():
|
||||
loop.create_task(self.close())
|
||||
else:
|
||||
loop.run_until_complete(self.close())
|
||||
except RuntimeError:
|
||||
# Event loop is already closed, nothing to do
|
||||
pass
|
||||
|
||||
# ====================================================================
|
||||
|
||||
async def add_env_vars(self, env_vars: dict[str, str]) -> None:
|
||||
def add_env_vars(self, env_vars: dict[str, str]) -> None:
|
||||
# Add env vars to the IPython shell (if Jupyter is used)
|
||||
if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
|
||||
code = 'import os\n'
|
||||
@ -108,7 +89,7 @@ class Runtime:
|
||||
# Note: json.dumps gives us nice escaping for free
|
||||
code += f'os.environ["{key}"] = {json.dumps(value)}\n'
|
||||
code += '\n'
|
||||
obs = await self.run_ipython(IPythonRunCellAction(code))
|
||||
obs = self.run_ipython(IPythonRunCellAction(code))
|
||||
logger.info(f'Added env vars to IPython: code={code}, obs={obs}')
|
||||
|
||||
# Add env vars to the Bash shell
|
||||
@ -120,7 +101,7 @@ class Runtime:
|
||||
return
|
||||
cmd = cmd.strip()
|
||||
logger.debug(f'Adding env var: {cmd}')
|
||||
obs = await self.run(CmdRunAction(cmd))
|
||||
obs = self.run(CmdRunAction(cmd))
|
||||
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
|
||||
@ -132,12 +113,12 @@ class Runtime:
|
||||
if event.timeout is None:
|
||||
event.timeout = self.config.sandbox.timeout
|
||||
assert event.timeout is not None
|
||||
observation = await self.run_action(event)
|
||||
observation = self.run_action(event)
|
||||
observation._cause = event.id # type: ignore[attr-defined]
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
async def run_action(self, action: Action) -> Observation:
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
"""Run an action and return the resulting observation.
|
||||
If the action is not runnable in any runtime, a NullObservation is returned.
|
||||
If the action is not supported by the current runtime, an ErrorObservation is returned.
|
||||
@ -163,35 +144,45 @@ class Runtime:
|
||||
return UserRejectObservation(
|
||||
'Action has been rejected by the user! Waiting for further user input.'
|
||||
)
|
||||
observation = await getattr(self, action_type)(action)
|
||||
observation = getattr(self, action_type)(action)
|
||||
return observation
|
||||
|
||||
# ====================================================================
|
||||
# Context manager
|
||||
# ====================================================================
|
||||
|
||||
def __enter__(self) -> 'Runtime':
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
self.close()
|
||||
|
||||
# ====================================================================
|
||||
# Action execution
|
||||
# ====================================================================
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, action: CmdRunAction) -> Observation:
|
||||
def run(self, action: CmdRunAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
def read(self, action: FileReadAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
def write(self, action: FileWriteAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
def browse(self, action: BrowseURLAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
pass
|
||||
|
||||
# ====================================================================
|
||||
@ -199,11 +190,11 @@ class Runtime:
|
||||
# ====================================================================
|
||||
|
||||
@abstractmethod
|
||||
async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
|
||||
def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
|
||||
raise NotImplementedError('This method is not implemented in the base class.')
|
||||
|
||||
@abstractmethod
|
||||
async def list_files(self, path: str | None = None) -> list[str]:
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List files in the sandbox.
|
||||
|
||||
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
|
||||
|
||||
@ -406,7 +406,7 @@ async def list_files(request: Request, path: str | None = None):
|
||||
content={'error': 'Runtime not yet initialized'},
|
||||
)
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
file_list = await runtime.list_files(path)
|
||||
file_list = runtime.list_files(path)
|
||||
return file_list
|
||||
|
||||
|
||||
@ -440,7 +440,7 @@ async def select_file(file: str, request: Request):
|
||||
)
|
||||
|
||||
read_action = FileReadAction(file)
|
||||
observation = await runtime.run_action(read_action)
|
||||
observation = runtime.run_action(read_action)
|
||||
|
||||
if isinstance(observation, FileReadObservation):
|
||||
content = observation.content
|
||||
@ -519,7 +519,7 @@ async def upload_file(request: Request, files: list[UploadFile]):
|
||||
tmp_file.flush()
|
||||
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
|
||||
)
|
||||
uploaded_files.append(safe_filename)
|
||||
@ -686,7 +686,7 @@ async def save_file(request: Request):
|
||||
# Save the file to the agent's runtime file store
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
write_action = FileWriteAction(file_path, content)
|
||||
observation = await runtime.run_action(write_action)
|
||||
observation = runtime.run_action(write_action)
|
||||
|
||||
if isinstance(observation, FileWriteObservation):
|
||||
return JSONResponse(
|
||||
|
||||
@ -69,7 +69,7 @@ class AgentSession:
|
||||
end_state.save_to_session(self.sid, self.file_store)
|
||||
await self.controller.close()
|
||||
if self.runtime is not None:
|
||||
await self.runtime.close()
|
||||
self.runtime.close()
|
||||
if self.security_analyzer is not None:
|
||||
await self.security_analyzer.close()
|
||||
self._closed = True
|
||||
@ -95,7 +95,6 @@ class AgentSession:
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
)
|
||||
await self.runtime.ainit()
|
||||
|
||||
async def _create_controller(
|
||||
self,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
@ -91,14 +90,14 @@ def base_container_image(request):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def runtime(temp_dir, box_class, run_as_openhands):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
def runtime(temp_dir, box_class, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
yield runtime
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
async def _load_runtime(
|
||||
def _load_runtime(
|
||||
temp_dir,
|
||||
box_class,
|
||||
run_as_openhands: bool = True,
|
||||
@ -135,8 +134,7 @@ async def _load_runtime(
|
||||
sid=sid,
|
||||
plugins=plugins,
|
||||
)
|
||||
await runtime.ainit()
|
||||
await asyncio.sleep(1)
|
||||
time.sleep(1)
|
||||
return runtime
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
"""Bash-related tests for the EventStreamRuntime, which connects to the RuntimeClient running in the sandbox."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from conftest import _load_runtime
|
||||
@ -16,15 +16,14 @@ from openhands.events.observation import CmdOutputObservation
|
||||
# ============================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bash_command_pexcept(temp_dir, box_class, run_as_openhands):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
def test_bash_command_pexcept(temp_dir, box_class, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
# We set env var PS1="\u@\h:\w $"
|
||||
# and construct the PEXCEPT prompt base on it.
|
||||
# When run `env`, bad implementation of CmdRunAction will be pexcepted by this
|
||||
# and failed to pexcept the right content, causing it fail to get error code.
|
||||
obs = await runtime.run_action(CmdRunAction(command='env'))
|
||||
obs = runtime.run_action(CmdRunAction(command='env'))
|
||||
|
||||
# For example:
|
||||
# 02:16:13 - openhands:DEBUG: client.py:78 - Executing command: env
|
||||
@ -43,58 +42,54 @@ async def test_bash_command_pexcept(temp_dir, box_class, run_as_openhands):
|
||||
), 'The observation should be a CmdOutputObservation.'
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_multiline_command(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_single_multiline_command(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
action = CmdRunAction(command='echo \\\n -e "foo"')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
assert 'foo' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiline_echo(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_multiline_echo(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
action = CmdRunAction(command='echo -e "hello\nworld"')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
assert 'hello\r\nworld' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_whitespace(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_runtime_whitespace(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
action = CmdRunAction(command='echo -e "\\n\\n\\n"')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
assert '\r\n\r\n\r\n' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_multiline_commands(temp_dir, box_class, run_as_openhands):
|
||||
def test_multiple_multiline_commands(temp_dir, box_class, run_as_openhands):
|
||||
cmds = [
|
||||
'ls -l',
|
||||
'echo -e "hello\nworld"',
|
||||
@ -124,11 +119,11 @@ world "
|
||||
]
|
||||
joined_cmds = '\n'.join(cmds)
|
||||
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
action = CmdRunAction(command=joined_cmds)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
@ -142,32 +137,30 @@ world "
|
||||
assert 'hello\r\nworld\r\nare\r\nyou\r\n\r\nthere?' in obs.content
|
||||
assert 'hello\r\nworld "\r\n' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ps2_in_output(temp_dir, box_class, run_as_openhands):
|
||||
def test_no_ps2_in_output(temp_dir, box_class, run_as_openhands):
|
||||
"""Test that the PS2 sign is not added to the output of a multiline command."""
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
action = CmdRunAction(command='echo -e "hello\nworld"')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert 'hello\r\nworld' in obs.content
|
||||
assert '>' not in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiline_command_loop(temp_dir, box_class):
|
||||
def test_multiline_command_loop(temp_dir, box_class):
|
||||
# https://github.com/All-Hands-AI/OpenHands/issues/3143
|
||||
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
init_cmd = """
|
||||
mkdir -p _modules && \
|
||||
@ -180,7 +173,7 @@ echo "created files"
|
||||
"""
|
||||
action = CmdRunAction(command=init_cmd)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
@ -196,24 +189,23 @@ echo "success"
|
||||
"""
|
||||
action = CmdRunAction(command=follow_up_cmd)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
assert 'success' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cmd_run(temp_dir, box_class, run_as_openhands):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
def test_cmd_run(temp_dir, box_class, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
action = CmdRunAction(command='ls -l')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -221,14 +213,14 @@ async def test_cmd_run(temp_dir, box_class, run_as_openhands):
|
||||
|
||||
action = CmdRunAction(command='mkdir test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='ls -l')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -240,14 +232,14 @@ async def test_cmd_run(temp_dir, box_class, run_as_openhands):
|
||||
|
||||
action = CmdRunAction(command='touch test/foo.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='ls -l test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -258,22 +250,21 @@ async def test_cmd_run(temp_dir, box_class, run_as_openhands):
|
||||
# owned by root
|
||||
action = CmdRunAction(command='rm -rf test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_as_user_correct_home_dir(temp_dir, box_class, run_as_openhands):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
def test_run_as_user_correct_home_dir(temp_dir, box_class, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
action = CmdRunAction(command='cd ~ && pwd')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -282,70 +273,67 @@ async def test_run_as_user_correct_home_dir(temp_dir, box_class, run_as_openhand
|
||||
else:
|
||||
assert '/root' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_cmd_run_in_single_line(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_multi_cmd_run_in_single_line(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
action = CmdRunAction(command='pwd && ls -l')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert '/workspace' in obs.content
|
||||
assert 'total 0' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stateful_cmd(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_stateful_cmd(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
action = CmdRunAction(command='mkdir test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
|
||||
action = CmdRunAction(command='cd test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
|
||||
action = CmdRunAction(command='pwd')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
assert '/workspace/test' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_cmd(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_failed_cmd(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
action = CmdRunAction(command='non_existing_command')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code != 0, 'The exit code should not be 0 for a failed command.'
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _create_test_file(host_temp_dir):
|
||||
@ -354,19 +342,16 @@ def _create_test_file(host_temp_dir):
|
||||
f.write('Hello, World!')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_single_file(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_copy_single_file(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as host_temp_dir:
|
||||
_create_test_file(host_temp_dir)
|
||||
await runtime.copy_to(
|
||||
os.path.join(host_temp_dir, 'test_file.txt'), '/workspace'
|
||||
)
|
||||
runtime.copy_to(os.path.join(host_temp_dir, 'test_file.txt'), '/workspace')
|
||||
|
||||
action = CmdRunAction(command='ls -alh /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -374,14 +359,14 @@ async def test_copy_single_file(temp_dir, box_class):
|
||||
|
||||
action = CmdRunAction(command='cat /workspace/test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert 'Hello, World!' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _create_test_dir_with_files(host_temp_dir):
|
||||
@ -392,20 +377,19 @@ def _create_test_dir_with_files(host_temp_dir):
|
||||
f.write('File 2 content')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_directory_recursively(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_copy_directory_recursively(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as host_temp_dir:
|
||||
# We need a separate directory, since temp_dir is mounted to /workspace
|
||||
_create_test_dir_with_files(host_temp_dir)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
os.path.join(host_temp_dir, 'test_dir'), '/workspace', recursive=True
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='ls -alh /workspace')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -415,7 +399,7 @@ async def test_copy_directory_recursively(temp_dir, box_class):
|
||||
|
||||
action = CmdRunAction(command='ls -alh /workspace/test_dir')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -424,53 +408,51 @@ async def test_copy_directory_recursively(temp_dir, box_class):
|
||||
|
||||
action = CmdRunAction(command='cat /workspace/test_dir/file1.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert 'File 1 content' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_to_non_existent_directory(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_copy_to_non_existent_directory(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as host_temp_dir:
|
||||
_create_test_file(host_temp_dir)
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
os.path.join(host_temp_dir, 'test_file.txt'), '/workspace/new_dir'
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='cat /workspace/new_dir/test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert 'Hello, World!' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overwrite_existing_file(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_overwrite_existing_file(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
# touch a file in /workspace
|
||||
action = CmdRunAction(command='touch /workspace/test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='cat /workspace/test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -478,46 +460,42 @@ async def test_overwrite_existing_file(temp_dir, box_class):
|
||||
|
||||
with tempfile.TemporaryDirectory() as host_temp_dir:
|
||||
_create_test_file(host_temp_dir)
|
||||
await runtime.copy_to(
|
||||
os.path.join(host_temp_dir, 'test_file.txt'), '/workspace'
|
||||
)
|
||||
runtime.copy_to(os.path.join(host_temp_dir, 'test_file.txt'), '/workspace')
|
||||
|
||||
action = CmdRunAction(command='cat /workspace/test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert 'Hello, World!' in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_non_existent_file(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_copy_non_existent_file(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await runtime.copy_to(
|
||||
runtime.copy_to(
|
||||
os.path.join(temp_dir, 'non_existent_file.txt'),
|
||||
'/workspace/should_not_exist.txt',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='ls /workspace/should_not_exist.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code != 0 # File should not exist
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keep_prompt(box_class, temp_dir):
|
||||
runtime = await _load_runtime(
|
||||
def test_keep_prompt(box_class, temp_dir):
|
||||
runtime = _load_runtime(
|
||||
temp_dir,
|
||||
box_class=box_class,
|
||||
run_as_openhands=False,
|
||||
@ -525,7 +503,7 @@ async def test_keep_prompt(box_class, temp_dir):
|
||||
|
||||
action = CmdRunAction(command='touch /workspace/test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -533,22 +511,21 @@ async def test_keep_prompt(box_class, temp_dir):
|
||||
|
||||
action = CmdRunAction(command='cat /workspace/test_file.txt', keep_prompt=False)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert 'root@' not in obs.content
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_operation(box_class):
|
||||
def test_git_operation(box_class):
|
||||
# do not mount workspace, since workspace mount by tests will be owned by root
|
||||
# while the user_id we get via os.getuid() is different from root
|
||||
# which causes permission issues
|
||||
runtime = await _load_runtime(
|
||||
runtime = _load_runtime(
|
||||
temp_dir=None,
|
||||
box_class=box_class,
|
||||
# Need to use non-root user to expose issues
|
||||
@ -561,7 +538,7 @@ async def test_git_operation(box_class):
|
||||
# check the ownership of the current directory
|
||||
action = CmdRunAction(command='ls -alh .')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -581,7 +558,7 @@ async def test_git_operation(box_class):
|
||||
# make sure all git operations are allowed
|
||||
action = CmdRunAction(command='git init')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -589,7 +566,7 @@ async def test_git_operation(box_class):
|
||||
# create a file
|
||||
action = CmdRunAction(command='echo "hello" > test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -597,7 +574,7 @@ async def test_git_operation(box_class):
|
||||
# git add
|
||||
action = CmdRunAction(command='git add test_file.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -605,7 +582,7 @@ async def test_git_operation(box_class):
|
||||
# git diff
|
||||
action = CmdRunAction(command='git diff')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -613,12 +590,12 @@ async def test_git_operation(box_class):
|
||||
# git commit
|
||||
action = CmdRunAction(command='git commit -m "test commit"')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
await runtime.close()
|
||||
runtime.close()
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
"""Browsing-related tests for the EventStreamRuntime, which connects to the RuntimeClient running in the sandbox."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from conftest import _load_runtime
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -24,16 +23,15 @@ from openhands.events.observation import (
|
||||
PY3_FOR_TESTING = '/openhands/miniforge3/bin/mamba run -n base python3'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_browse(temp_dir, box_class, run_as_openhands):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
def test_simple_browse(temp_dir, box_class, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
# Test browse
|
||||
action_cmd = CmdRunAction(
|
||||
command=f'{PY3_FOR_TESTING} -m http.server 8000 > server.log 2>&1 &'
|
||||
)
|
||||
logger.info(action_cmd, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_cmd)
|
||||
obs = runtime.run_action(action_cmd)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
@ -42,13 +40,13 @@ async def test_simple_browse(temp_dir, box_class, run_as_openhands):
|
||||
|
||||
action_cmd = CmdRunAction(command='sleep 5 && cat server.log')
|
||||
logger.info(action_cmd, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_cmd)
|
||||
obs = runtime.run_action(action_cmd)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action_browse = BrowseURLAction(url='http://localhost:8000')
|
||||
logger.info(action_browse, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_browse)
|
||||
obs = runtime.run_action(action_browse)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, BrowserOutputObservation)
|
||||
@ -64,17 +62,16 @@ async def test_simple_browse(temp_dir, box_class, run_as_openhands):
|
||||
# clean up
|
||||
action = CmdRunAction(command='rm -rf server.log')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_browsergym_eval_env(box_class, temp_dir):
|
||||
runtime = await _load_runtime(
|
||||
def test_browsergym_eval_env(box_class, temp_dir):
|
||||
runtime = _load_runtime(
|
||||
temp_dir,
|
||||
box_class=box_class,
|
||||
run_as_openhands=False, # need root permission to access file
|
||||
@ -89,7 +86,7 @@ async def test_browsergym_eval_env(box_class, temp_dir):
|
||||
# Test browse
|
||||
action = BrowseInteractiveAction(browser_actions=BROWSER_EVAL_GET_GOAL_ACTION)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, BrowserOutputObservation)
|
||||
@ -100,7 +97,7 @@ async def test_browsergym_eval_env(box_class, temp_dir):
|
||||
# Make sure the browser can produce observation in eval
|
||||
action = BrowseInteractiveAction(browser_actions='noop()')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
obs.url.strip()
|
||||
@ -110,9 +107,9 @@ async def test_browsergym_eval_env(box_class, temp_dir):
|
||||
# Make sure the rewards are working
|
||||
action = BrowseInteractiveAction(browser_actions=BROWSER_EVAL_GET_REWARDS_ACTION)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert json.loads(obs.content) == [0.0]
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
"""Env vars related tests for the EventStreamRuntime, which connects to the RuntimeClient running in the sandbox."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from conftest import _load_runtime
|
||||
|
||||
from openhands.events.action import CmdRunAction
|
||||
@ -15,17 +14,14 @@ from openhands.events.observation import CmdOutputObservation
|
||||
# ============================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_env_vars_os_environ(temp_dir, box_class, run_as_openhands):
|
||||
def test_env_vars_os_environ(temp_dir, box_class, run_as_openhands):
|
||||
with patch.dict(os.environ, {'SANDBOX_ENV_FOOBAR': 'BAZ'}):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
obs: CmdOutputObservation = await runtime.run_action(
|
||||
CmdRunAction(command='env')
|
||||
)
|
||||
obs: CmdOutputObservation = runtime.run_action(CmdRunAction(command='env'))
|
||||
print(obs)
|
||||
|
||||
obs: CmdOutputObservation = await runtime.run_action(
|
||||
obs: CmdOutputObservation = runtime.run_action(
|
||||
CmdRunAction(command='echo $FOOBAR')
|
||||
)
|
||||
print(obs)
|
||||
@ -34,55 +30,50 @@ async def test_env_vars_os_environ(temp_dir, box_class, run_as_openhands):
|
||||
obs.content.strip().split('\n\r')[0].strip() == 'BAZ'
|
||||
), f'Output: [{obs.content}] for {box_class}'
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_env_vars_runtime_add_env_vars(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
await runtime.add_env_vars({'QUUX': 'abc"def'})
|
||||
def test_env_vars_runtime_add_env_vars(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
runtime.add_env_vars({'QUUX': 'abc"def'})
|
||||
|
||||
obs: CmdOutputObservation = await runtime.run_action(
|
||||
CmdRunAction(command='echo $QUUX')
|
||||
)
|
||||
obs: CmdOutputObservation = runtime.run_action(CmdRunAction(command='echo $QUUX'))
|
||||
print(obs)
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
assert (
|
||||
obs.content.strip().split('\r\n')[0].strip() == 'abc"def'
|
||||
), f'Output: [{obs.content}] for {box_class}'
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_env_vars_runtime_add_empty_dict(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_env_vars_runtime_add_empty_dict(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
prev_obs = await runtime.run_action(CmdRunAction(command='env'))
|
||||
prev_obs = runtime.run_action(CmdRunAction(command='env'))
|
||||
assert prev_obs.exit_code == 0, 'The exit code should be 0.'
|
||||
print(prev_obs)
|
||||
|
||||
await runtime.add_env_vars({})
|
||||
runtime.add_env_vars({})
|
||||
|
||||
obs = await runtime.run_action(CmdRunAction(command='env'))
|
||||
obs = runtime.run_action(CmdRunAction(command='env'))
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
print(obs)
|
||||
assert (
|
||||
obs.content == prev_obs.content
|
||||
), 'The env var content should be the same after adding an empty dict.'
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_env_vars_runtime_add_multiple_env_vars(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
await runtime.add_env_vars({'QUUX': 'abc"def', 'FOOBAR': 'xyz'})
|
||||
def test_env_vars_runtime_add_multiple_env_vars(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
runtime.add_env_vars({'QUUX': 'abc"def', 'FOOBAR': 'xyz'})
|
||||
|
||||
obs: CmdOutputObservation = await runtime.run_action(
|
||||
obs: CmdOutputObservation = runtime.run_action(
|
||||
CmdRunAction(command='echo $QUUX $FOOBAR')
|
||||
)
|
||||
print(obs)
|
||||
@ -91,17 +82,16 @@ async def test_env_vars_runtime_add_multiple_env_vars(temp_dir, box_class):
|
||||
obs.content.strip().split('\r\n')[0].strip() == 'abc"def xyz'
|
||||
), f'Output: [{obs.content}] for {box_class}'
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_env_vars_runtime_add_env_vars_overwrite(temp_dir, box_class):
|
||||
def test_env_vars_runtime_add_env_vars_overwrite(temp_dir, box_class):
|
||||
with patch.dict(os.environ, {'SANDBOX_ENV_FOOBAR': 'BAZ'}):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
await runtime.add_env_vars({'FOOBAR': 'xyz'})
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
runtime.add_env_vars({'FOOBAR': 'xyz'})
|
||||
|
||||
obs: CmdOutputObservation = await runtime.run_action(
|
||||
obs: CmdOutputObservation = runtime.run_action(
|
||||
CmdRunAction(command='echo $FOOBAR')
|
||||
)
|
||||
print(obs)
|
||||
@ -110,5 +100,5 @@ async def test_env_vars_runtime_add_env_vars_overwrite(temp_dir, box_class):
|
||||
obs.content.strip().split('\r\n')[0].strip() == 'xyz'
|
||||
), f'Output: [{obs.content}] for {box_class}'
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Image-related tests for the EventStreamRuntime, which connects to the RuntimeClient running in the sandbox."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from conftest import _load_runtime
|
||||
@ -13,8 +13,7 @@ from openhands.events.action import CmdRunAction
|
||||
# ============================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bash_python_version(temp_dir, box_class, base_container_image):
|
||||
def test_bash_python_version(temp_dir, box_class, base_container_image):
|
||||
"""Make sure Python is available in bash."""
|
||||
if base_container_image not in [
|
||||
'python:3.11-bookworm',
|
||||
@ -22,36 +21,35 @@ async def test_bash_python_version(temp_dir, box_class, base_container_image):
|
||||
]:
|
||||
pytest.skip('This test is only for python-related images')
|
||||
|
||||
runtime = await _load_runtime(
|
||||
runtime = _load_runtime(
|
||||
temp_dir, box_class, base_container_image=base_container_image
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='which python')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='python --version')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert 'Python 3.11' in obs.content # Check for specific version
|
||||
|
||||
action = CmdRunAction(command='pip --version')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert 'pip' in obs.content # Check that pip is available
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nodejs_22_version(temp_dir, box_class, base_container_image):
|
||||
def test_nodejs_22_version(temp_dir, box_class, base_container_image):
|
||||
"""Make sure Node.js is available in bash."""
|
||||
if base_container_image not in [
|
||||
'node:22-bookworm',
|
||||
@ -59,39 +57,38 @@ async def test_nodejs_22_version(temp_dir, box_class, base_container_image):
|
||||
]:
|
||||
pytest.skip('This test is only for nodejs-related images')
|
||||
|
||||
runtime = await _load_runtime(
|
||||
runtime = _load_runtime(
|
||||
temp_dir, box_class, base_container_image=base_container_image
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='node --version')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert 'v22' in obs.content # Check for specific version
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_go_version(temp_dir, box_class, base_container_image):
|
||||
def test_go_version(temp_dir, box_class, base_container_image):
|
||||
"""Make sure Go is available in bash."""
|
||||
if base_container_image not in [
|
||||
'golang:1.23-bookworm',
|
||||
]:
|
||||
pytest.skip('This test is only for go-related images')
|
||||
|
||||
runtime = await _load_runtime(
|
||||
runtime = _load_runtime(
|
||||
temp_dir, box_class, base_container_image=base_container_image
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='go version')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert 'go1.23' in obs.content # Check for specific version
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
"""Test the EventStreamRuntime, which connects to the RuntimeClient running in the sandbox."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from conftest import _load_runtime
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -26,14 +25,13 @@ from openhands.runtime.client.runtime import EventStreamRuntime
|
||||
# ============================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhands):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
# Test run command
|
||||
action_cmd = CmdRunAction(command='ls -l')
|
||||
logger.info(action_cmd, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_cmd)
|
||||
obs = runtime.run_action(action_cmd)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
@ -44,7 +42,7 @@ async def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhan
|
||||
test_code = "print('Hello, `World`!\\n')"
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_ipython)
|
||||
obs = runtime.run_action(action_ipython)
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
@ -57,7 +55,7 @@ async def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhan
|
||||
# Test read file (file should not exist)
|
||||
action_read = FileReadAction(path='hello.sh')
|
||||
logger.info(action_read, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_read)
|
||||
obs = runtime.run_action(action_read)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, ErrorObservation)
|
||||
assert 'File not found' in obs.content
|
||||
@ -65,7 +63,7 @@ async def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhan
|
||||
# Test write file
|
||||
action_write = FileWriteAction(content='echo "Hello, World!"', path='hello.sh')
|
||||
logger.info(action_write, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_write)
|
||||
obs = runtime.run_action(action_write)
|
||||
assert isinstance(obs, FileWriteObservation)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
@ -76,7 +74,7 @@ async def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhan
|
||||
# Test read file (file should exist)
|
||||
action_read = FileReadAction(path='hello.sh')
|
||||
logger.info(action_read, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_read)
|
||||
obs = runtime.run_action(action_read)
|
||||
assert isinstance(
|
||||
obs, FileReadObservation
|
||||
), 'The observation should be a FileReadObservation.'
|
||||
@ -88,24 +86,23 @@ async def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhan
|
||||
# clean up
|
||||
action = CmdRunAction(command='rm -rf hello.sh')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
# Test run ipython
|
||||
# get username
|
||||
test_code = "import os; print(os.environ['USER'])"
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_ipython)
|
||||
obs = runtime.run_action(action_ipython)
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
@ -118,7 +115,7 @@ async def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
|
||||
test_code = 'import os; print(os.getcwd())'
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_ipython)
|
||||
obs = runtime.run_action(action_ipython)
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
@ -134,7 +131,7 @@ async def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
|
||||
test_code = "with open('test.txt', 'w') as f: f.write('Hello, world!')"
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_ipython)
|
||||
obs = runtime.run_action(action_ipython)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert (
|
||||
@ -149,7 +146,7 @@ async def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
|
||||
# check file owner via bash
|
||||
action = CmdRunAction(command='ls -alh test.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
if run_as_openhands:
|
||||
@ -163,24 +160,23 @@ async def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
|
||||
# clean up
|
||||
action = CmdRunAction(command='rm -rf test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipython_simple(temp_dir, box_class):
|
||||
runtime = await _load_runtime(temp_dir, box_class)
|
||||
def test_ipython_simple(temp_dir, box_class):
|
||||
runtime = _load_runtime(temp_dir, box_class)
|
||||
|
||||
# Test run ipython
|
||||
# get username
|
||||
test_code = 'print(1)'
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action_ipython)
|
||||
obs = runtime.run_action(action_ipython)
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
@ -192,30 +188,30 @@ async def test_ipython_simple(temp_dir, box_class):
|
||||
).strip()
|
||||
)
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
async def _test_ipython_agentskills_fileop_pwd_impl(
|
||||
def _test_ipython_agentskills_fileop_pwd_impl(
|
||||
runtime: EventStreamRuntime, enable_auto_lint: bool
|
||||
):
|
||||
# remove everything in /workspace
|
||||
action = CmdRunAction(command='rm -rf /workspace/*')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='mkdir test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
@ -230,7 +226,7 @@ async def _test_ipython_agentskills_fileop_pwd_impl(
|
||||
|
||||
action = CmdRunAction(command='cd test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -239,7 +235,7 @@ async def _test_ipython_agentskills_fileop_pwd_impl(
|
||||
# i.e., /workspace/test/hello.py instead of /workspace/hello.py
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
@ -258,7 +254,7 @@ async def _test_ipython_agentskills_fileop_pwd_impl(
|
||||
code="insert_content_at_line('hello.py', 1, ' print(\"hello world\")')"
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
@ -292,7 +288,7 @@ DO NOT re-run the same failed edit command. Running it again will lead to the sa
|
||||
code="insert_content_at_line('hello.py', 1, 'print(\"hello world\")')"
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
@ -309,38 +305,36 @@ DO NOT re-run the same failed edit command. Running it again will lead to the sa
|
||||
|
||||
action = CmdRunAction(command='rm -rf /workspace/*')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipython_agentskills_fileop_pwd(
|
||||
def test_ipython_agentskills_fileop_pwd(
|
||||
temp_dir, box_class, run_as_openhands, enable_auto_lint
|
||||
):
|
||||
"""Make sure that cd in bash also update the current working directory in ipython."""
|
||||
|
||||
runtime = await _load_runtime(
|
||||
runtime = _load_runtime(
|
||||
temp_dir, box_class, run_as_openhands, enable_auto_lint=enable_auto_lint
|
||||
)
|
||||
await _test_ipython_agentskills_fileop_pwd_impl(runtime, enable_auto_lint)
|
||||
_test_ipython_agentskills_fileop_pwd_impl(runtime, enable_auto_lint)
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipython_agentskills_fileop_pwd_with_userdir(temp_dir, box_class):
|
||||
def test_ipython_agentskills_fileop_pwd_with_userdir(temp_dir, box_class):
|
||||
"""Make sure that cd in bash also update the current working directory in ipython.
|
||||
|
||||
Handle special case where the pwd is provided as "~", which should be expanded using os.path.expanduser
|
||||
on the client side.
|
||||
"""
|
||||
|
||||
runtime = await _load_runtime(
|
||||
runtime = _load_runtime(
|
||||
temp_dir,
|
||||
box_class,
|
||||
run_as_openhands=False,
|
||||
@ -348,20 +342,20 @@ async def test_ipython_agentskills_fileop_pwd_with_userdir(temp_dir, box_class):
|
||||
|
||||
action = CmdRunAction(command='cd ~')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='mkdir test && ls -la')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
@ -376,7 +370,7 @@ async def test_ipython_agentskills_fileop_pwd_with_userdir(temp_dir, box_class):
|
||||
|
||||
action = CmdRunAction(command='cd test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
@ -385,7 +379,7 @@ async def test_ipython_agentskills_fileop_pwd_with_userdir(temp_dir, box_class):
|
||||
# i.e., /workspace/test/hello.py instead of /workspace/hello.py
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
@ -398,26 +392,25 @@ async def test_ipython_agentskills_fileop_pwd_with_userdir(temp_dir, box_class):
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.11/bin/python]'
|
||||
).strip().split('\n')
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipython_package_install(temp_dir, box_class, run_as_openhands):
|
||||
def test_ipython_package_install(temp_dir, box_class, run_as_openhands):
|
||||
"""Make sure that cd in bash also update the current working directory in ipython."""
|
||||
runtime = await _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
# It should error out since pymsgbox is not installed
|
||||
action = IPythonRunCellAction(code='import pymsgbox')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert "ModuleNotFoundError: No module named 'pymsgbox'" in obs.content
|
||||
|
||||
# Install pymsgbox in Jupyter
|
||||
action = IPythonRunCellAction(code='%pip install pymsgbox==1.0.9')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
'Successfully installed pymsgbox-1.0.9' in obs.content
|
||||
@ -426,7 +419,7 @@ async def test_ipython_package_install(temp_dir, box_class, run_as_openhands):
|
||||
|
||||
action = IPythonRunCellAction(code='import pymsgbox')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = await runtime.run_action(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
# import should not error out
|
||||
assert obs.content.strip() == (
|
||||
@ -435,5 +428,5 @@ async def test_ipython_package_install(temp_dir, box_class, run_as_openhands):
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.11/bin/python]'
|
||||
)
|
||||
|
||||
await runtime.close()
|
||||
await asyncio.sleep(1)
|
||||
runtime.close()
|
||||
time.sleep(1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user