mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[agent] LLM-based editing (#3985)
Co-authored-by: Tim O'Farrell <tofarr@gmail.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: Robert Brennan <accounts@rbren.io> Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
parent
6fe5482b20
commit
da548d308c
@ -172,11 +172,9 @@ model = "gpt-4o"
|
||||
#disable_vision = true
|
||||
|
||||
[llm.gpt4o-mini]
|
||||
# API key to use
|
||||
api_key = "your-api-key"
|
||||
model = "gpt-4o"
|
||||
|
||||
# Model to use
|
||||
model = "gpt-4o-mini"
|
||||
|
||||
#################################### Agent ###################################
|
||||
# Configuration for agents (group name starts with 'agent')
|
||||
|
||||
@ -48,13 +48,14 @@ def get_config(
|
||||
config = AppConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
runtime='eventstream',
|
||||
runtime=os.environ.get('RUNTIME', 'eventstream'),
|
||||
max_iterations=metadata.max_iterations,
|
||||
sandbox=SandboxConfig(
|
||||
base_container_image='python:3.11-bookworm',
|
||||
enable_auto_lint=True,
|
||||
use_host_network=False,
|
||||
timeout=100,
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
@ -186,7 +187,9 @@ def process_instance(
|
||||
signature_file=f'{instance.instance_name}.py',
|
||||
)
|
||||
if USE_UNIT_TESTS:
|
||||
print(f'\nInstruction to run test_file: {instance.instance_name}_test.py\n')
|
||||
logger.info(
|
||||
f'\nInstruction to run test_file: {instance.instance_name}_test.py\n'
|
||||
)
|
||||
instruction += (
|
||||
f'Use `python -m unittest {instance.instance_name}_test.py` to run the test_file '
|
||||
'and verify the correctness of your solution. DO NOT EDIT the test file.\n\n'
|
||||
|
||||
@ -81,8 +81,10 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
|
||||
instruction += f'# Hints\n{instance.hints_text}\n\n'
|
||||
instruction += (
|
||||
'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
|
||||
'You should NOT modify any existing test case files. If needed, you can add new test cases in a NEW file to reproduce the issue.\n'
|
||||
'You SHOULD INCLUDE PROPER INDENTATION in your edit commands.\n'
|
||||
'You should NOT modify any existing test case files. You SHOULD add new test in a NEW file to reproduce the issue.\n'
|
||||
'You should verify that the issue is resolved and any new tests you create pass successfully.\n'
|
||||
'You should NEVER use web browsing or any other web-based tools.\n'
|
||||
'You should ALWAYS use the default Python interpreter available in the <execute_bash> environment to run code related to the provided issue and/or repository.\n'
|
||||
)
|
||||
|
||||
# NOTE: You can actually set slightly different instruction for different agents
|
||||
@ -123,7 +125,6 @@ def get_config(
|
||||
config = AppConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
max_budget_per_task=4,
|
||||
max_iterations=metadata.max_iterations,
|
||||
runtime=os.environ.get('RUNTIME', 'eventstream'),
|
||||
sandbox=SandboxConfig(
|
||||
@ -169,7 +170,7 @@ def initialize_runtime(
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0, f'Failed to export SWE_INSTANCE_ID: {obs.content}'
|
||||
obs.exit_code == 0, f'Failed to export SWE_INSTANCE_ID: {str(obs)}'
|
||||
)
|
||||
|
||||
action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
|
||||
@ -177,7 +178,7 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {obs.content}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}')
|
||||
|
||||
if USE_INSTANCE_IMAGE:
|
||||
# inject the init script
|
||||
@ -191,7 +192,7 @@ def initialize_runtime(
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to create /swe_util/eval_data/instances: {obs.content}',
|
||||
f'Failed to create /swe_util/eval_data/instances: {str(obs)}',
|
||||
)
|
||||
|
||||
swe_instance_json_name = 'swe-bench-instance.json'
|
||||
@ -218,16 +219,16 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {obs.content}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(command='source ~/.bashrc')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0, f'Failed to source ~/.bashrc: {obs.content}'
|
||||
)
|
||||
if isinstance(obs, ErrorObservation):
|
||||
logger.error(f'Failed to source ~/.bashrc: {str(obs)}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
|
||||
action.timeout = 3600
|
||||
@ -236,7 +237,7 @@ def initialize_runtime(
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to source /swe_util/instance_swe_entry.sh: {obs.content}',
|
||||
f'Failed to source /swe_util/instance_swe_entry.sh: {str(obs)}',
|
||||
)
|
||||
else:
|
||||
action = CmdRunAction(command='source /swe_util/swe_entry.sh')
|
||||
@ -246,7 +247,7 @@ def initialize_runtime(
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to source /swe_util/swe_entry.sh: {obs.content}',
|
||||
f'Failed to source /swe_util/swe_entry.sh: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
@ -256,7 +257,7 @@ def initialize_runtime(
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}',
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git reset --hard')
|
||||
@ -264,7 +265,7 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {obs.content}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(
|
||||
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
|
||||
@ -273,7 +274,7 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {obs.content}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {str(obs)}')
|
||||
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Initialization Fn')
|
||||
@ -303,7 +304,7 @@ def complete_runtime(
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}',
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git config --global core.pager ""')
|
||||
@ -313,7 +314,7 @@ def complete_runtime(
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to git config --global core.pager "": {obs.content}',
|
||||
f'Failed to git config --global core.pager "": {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git add -A')
|
||||
@ -321,7 +322,7 @@ def complete_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {obs.content}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {str(obs)}')
|
||||
|
||||
n_retries = 0
|
||||
git_patch = None
|
||||
@ -346,7 +347,9 @@ def complete_runtime(
|
||||
logger.error(f'Error occurred: {obs.content}. Retrying...')
|
||||
sleep_if_should_continue(10)
|
||||
else:
|
||||
assert_and_raise(False, f'Unexpected observation type: {type(obs)}')
|
||||
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
|
||||
|
||||
assert_and_raise(git_patch is not None, 'Failed to get git diff (None)')
|
||||
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Completion Fn')
|
||||
@ -482,10 +485,6 @@ if __name__ == '__main__':
|
||||
|
||||
details = {}
|
||||
_agent_cls = openhands.agenthub.Agent.get_cls(args.agent_cls)
|
||||
if hasattr(_agent_cls, 'system_message'):
|
||||
details['system_message'] = _agent_cls.system_message
|
||||
if hasattr(_agent_cls, 'in_context_example'):
|
||||
details['in_context_example'] = _agent_cls.in_context_example
|
||||
|
||||
dataset_descrption = (
|
||||
args.dataset.replace('/', '__') + '-' + args.split.replace('/', '__')
|
||||
|
||||
67
evaluation/swe_bench/scripts/eval/compare_outputs.py
Executable file
67
evaluation/swe_bench/scripts/eval/compare_outputs.py
Executable file
@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# Command-line interface: two positional paths, each a swe_bench output JSONL file.
parser = argparse.ArgumentParser(
    description='Compare two swe_bench output JSONL files and print the resolved diff'
)
parser.add_argument('input_file_1', type=str)
parser.add_argument('input_file_2', type=str)
args = parser.parse_args()

# Each input is read as line-delimited JSON records (one record per line).
df1 = pd.read_json(args.input_file_1, orient='records', lines=True)
df2 = pd.read_json(args.input_file_2, orient='records', lines=True)

# Inner join on instance_id: only instances present in BOTH files are kept.
# Columns that share a name get pandas' default `_x` / `_y` suffixes.
df = df1.merge(df2, on='instance_id', how='inner')
|
||||
|
||||
|
||||
def _get_resolved(report):
|
||||
if report is None:
|
||||
return False
|
||||
if isinstance(report, float):
|
||||
return False
|
||||
else:
|
||||
return report.get('resolved', False)
|
||||
|
||||
|
||||
# Derive a resolved flag for each side of the merge (x = file 1, y = file 2).
df['resolved_x'] = df['report_x'].apply(_get_resolved)
df['resolved_y'] = df['report_y'].apply(_get_resolved)
# Column-wise comparison instead of a row-wise `df.apply(..., axis=1)`: the
# row-wise form does not produce a boolean Series when the merged frame is
# empty, which breaks the column assignment below.
df['diff'] = df['resolved_x'] != df['resolved_y']

df_diff = df[df['diff']].sort_values(
    by=['resolved_x', 'resolved_y'], ascending=[False, False]
)
# skip if any of the resolved is nan, which means one of the eval is not finished yet
df_diff = df_diff[df_diff['resolved_x'].notna() & df_diff['resolved_y'].notna()]

print(f'X={args.input_file_1}')
print(f'Y={args.input_file_2}')
print(f'# diff={df_diff.shape[0]}')
df_diff = df_diff[['instance_id', 'resolved_x', 'resolved_y', 'report_x', 'report_y']]

# x resolved but y not
print('-' * 100)
df_diff_x_only = df_diff[df_diff['resolved_x'] & ~df_diff['resolved_y']].sort_values(
    by='instance_id'
)
print(f'# x resolved but y not={df_diff_x_only.shape[0]}')
print(df_diff_x_only[['instance_id', 'report_x', 'report_y']])

# y resolved but x not
print('-' * 100)
df_diff_y_only = df_diff[~df_diff['resolved_x'] & df_diff['resolved_y']].sort_values(
    by='instance_id'
)
print(f'# y resolved but x not={df_diff_y_only.shape[0]}')
print(df_diff_y_only[['instance_id', 'report_x', 'report_y']])

# Finally, print the plain instance_id lists for both groups.
print('-' * 100)
print('Instances that x resolved but y not:')
print(df_diff_x_only['instance_id'].tolist())

print('-' * 100)
print('Instances that y resolved but x not:')
print(df_diff_y_only['instance_id'].tolist())
|
||||
@ -3,6 +3,9 @@ import argparse
|
||||
import json
|
||||
from collections import Counter
|
||||
|
||||
from openhands.events.serialization import event_from_dict
|
||||
from openhands.events.utils import get_pairs_from_events
|
||||
|
||||
ERROR_KEYWORDS = [
|
||||
'Agent encountered an error while processing the last action',
|
||||
'APIError',
|
||||
@ -26,8 +29,37 @@ if __name__ == '__main__':
|
||||
|
||||
error_counter = Counter()
|
||||
|
||||
main_agent_cost = []
|
||||
editor_cost = []
|
||||
num_turns = []
|
||||
|
||||
for line in lines:
|
||||
_d = json.loads(line)
|
||||
|
||||
# Cost
|
||||
costs = _d['metrics'].get('costs', [])
|
||||
_cur_main_agent_cost = 0
|
||||
_cur_editor_cost = 0
|
||||
for cost in costs:
|
||||
if isinstance(cost, float):
|
||||
# backward compatible
|
||||
_cur_main_agent_cost += cost
|
||||
else:
|
||||
if 'draft_editor' in cost['model']:
|
||||
_cur_editor_cost += cost['cost']
|
||||
else:
|
||||
_cur_main_agent_cost += cost['cost']
|
||||
|
||||
main_agent_cost.append(_cur_main_agent_cost)
|
||||
editor_cost.append(_cur_editor_cost)
|
||||
|
||||
# Turn status
|
||||
history = _d.get('history', [])
|
||||
events = [event_from_dict(event) for event in history]
|
||||
pairs = get_pairs_from_events(events)
|
||||
num_turns.append(len(pairs))
|
||||
|
||||
# Patch & resolve status
|
||||
patch = _d.get('test_result', {}).get('git_patch', '')
|
||||
if patch == '':
|
||||
num_empty_patch += 1
|
||||
@ -38,6 +70,7 @@ if __name__ == '__main__':
|
||||
if resolved:
|
||||
num_resolved += 1
|
||||
|
||||
# Error
|
||||
error = _d.get('error', None)
|
||||
|
||||
if error is not None and isinstance(error, str):
|
||||
@ -70,7 +103,17 @@ if __name__ == '__main__':
|
||||
print(
|
||||
f'# of loop: {num_agent_stuck_in_loop} / {num_lines} ({num_agent_stuck_in_loop / num_lines * 100:.2f}%)'
|
||||
)
|
||||
assert len(num_turns) == num_lines
|
||||
assert len(main_agent_cost) == num_lines
|
||||
assert len(editor_cost) == num_lines
|
||||
print(f'Avg. num of turns per instance: {sum(num_turns) / num_lines:.2f}')
|
||||
print(f'Avg. agent cost per instance: {sum(main_agent_cost) / num_lines:.2f} USD')
|
||||
print(f'Avg. editor cost per instance: {sum(editor_cost) / num_lines:.2f} USD')
|
||||
print(
|
||||
f'Avg. total cost per instance: {(sum(main_agent_cost) + sum(editor_cost)) / num_lines:.2f} USD'
|
||||
)
|
||||
print('-' * 100)
|
||||
print('Detailed error breakdown:')
|
||||
for error, count in error_counter.items():
|
||||
print(f'{error}: {count} ({count / num_lines * 100:.2f}%)')
|
||||
print('-' * 100)
|
||||
|
||||
@ -25,8 +25,8 @@ if [ -z "$AGENT" ]; then
|
||||
fi
|
||||
|
||||
if [ -z "$MAX_ITER" ]; then
|
||||
echo "MAX_ITER not specified, use default 30"
|
||||
MAX_ITER=30
|
||||
echo "MAX_ITER not specified, use default 100"
|
||||
MAX_ITER=100
|
||||
fi
|
||||
|
||||
if [ -z "$USE_INSTANCE_IMAGE" ]; then
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
import re
|
||||
|
||||
from openhands.controller.action_parser import ActionParser, ResponseParser
|
||||
from openhands.controller.action_parser import (
|
||||
ActionParser,
|
||||
ResponseParser,
|
||||
)
|
||||
from openhands.core.exceptions import LLMMalformedActionError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
@ -14,6 +20,7 @@ from openhands.events.action import (
|
||||
class CodeActResponseParser(ResponseParser):
|
||||
"""Parser action:
|
||||
- CmdRunAction(command) - bash command to run
|
||||
- FileEditAction(path, content) - edit a file
|
||||
- IPythonRunCellAction(code) - IPython code to run
|
||||
- AgentDelegateAction(agent, inputs) - delegate action for (sub)task
|
||||
- MessageAction(content) - Message action to run (e.g. ask for clarification)
|
||||
@ -25,6 +32,7 @@ class CodeActResponseParser(ResponseParser):
|
||||
super().__init__()
|
||||
self.action_parsers = [
|
||||
CodeActActionParserFinish(),
|
||||
CodeActActionParserFileEdit(),
|
||||
CodeActActionParserCmdRun(),
|
||||
CodeActActionParserIPythonRunCell(),
|
||||
CodeActActionParserAgentDelegate(),
|
||||
@ -46,6 +54,8 @@ class CodeActResponseParser(ResponseParser):
|
||||
|
||||
if f'<execute_{lang}>' in action and f'</execute_{lang}>' not in action:
|
||||
action += f'</execute_{lang}>'
|
||||
if '<file_edit' in action and '</file_edit>' not in action:
|
||||
action += '</file_edit>'
|
||||
return action
|
||||
|
||||
def parse_action(self, action_str: str) -> Action:
|
||||
@ -186,3 +196,87 @@ class CodeActActionParserMessage(ActionParser):
|
||||
|
||||
def parse(self, action_str: str) -> Action:
|
||||
return MessageAction(content=action_str, wait_for_response=True)
|
||||
|
||||
|
||||
class CodeActActionParserFileEdit(ActionParser):
    """Parser action:
    - FileEditAction(path, content) - edit a file

    Recognizes ``<file_edit path=... [start=...] [end=...]>...</file_edit>``
    in the model output. ``check_condition`` must be called first: it stores
    the regex match that ``parse`` later turns into a ``FileEditAction``.
    """

    # `path` is required; `start` and `end` are optional. Every attribute value
    # may be unquoted or wrapped in matching single/double quotes.
    _FILE_EDIT_PATTERN = re.compile(
        r'<file_edit\s+path=(["\']?)(.*?)\1(?:\s+start=(["\']?)(.*?)\3)?(?:\s+end=(["\']?)(.*?)\5)?\s*>(.*?)</file_edit>',
        re.DOTALL,
    )

    def __init__(self):
        self.file_edit_match: re.Match | None = None

    @staticmethod
    def _parse_line_number(raw: 'str | None', name: str) -> 'int | None':
        """Convert an optional ``start``/``end`` attribute value to an int.

        Returns ``None`` when the attribute is absent or empty.

        Raises:
            LLMMalformedActionError: if the value is present but not an integer.
        """
        if not raw:
            return None
        try:
            return int(raw)
        except ValueError as err:
            raise LLMMalformedActionError(
                f'FileEditAction detected but `{name}` is not a valid integer: {raw}'
            ) from err

    def check_condition(self, action_str: str) -> bool:
        """Return True if ``action_str`` contains a well-formed ``<file_edit>``.

        Returns False when no ``<file_edit`` tag is present at all.

        Raises:
            LLMMalformedActionError: if a ``<file_edit`` tag is present but does
                not match the expected format, has no path, or has a
                non-integer ``start``/``end``.
        """
        # Drop any match left over from a previous call so `parse` can never
        # act on stale data.
        self.file_edit_match = None

        if '<file_edit' not in action_str:
            return False

        self.file_edit_match = self._FILE_EDIT_PATTERN.search(action_str)

        if self.file_edit_match is None:
            logger.error(
                f'FileEditAction detected but the format is incorrect. Unable to match for <file_edit> in:\n{"-" * 80}\n{action_str}\n{"-" * 80}'
            )
            raise LLMMalformedActionError(
                'FileEditAction detected but the format is incorrect. Usage:\n'
                '<file_edit path="[path]" start=[start_line] end=[end_line]>\n'
                '[content_to_edit]\n'
                '</file_edit>\n'
            )

        path = self.file_edit_match.group(2)
        # `parse` strips the path, so a whitespace-only path is as good as missing.
        if not path or not path.strip():
            raise LLMMalformedActionError(
                'FileEditAction detected but no `path` specified. You should specify the path of the file to edit.'
            )

        # Validate eagerly so malformed line numbers are reported here rather
        # than later in `parse`.
        self._parse_line_number(self.file_edit_match.group(4), 'start')
        self._parse_line_number(self.file_edit_match.group(6), 'end')

        return True

    def parse(self, action_str: str) -> Action:
        """Build a ``FileEditAction`` from the match stored by ``check_condition``.

        The text of ``action_str`` outside the matched ``<file_edit>`` element
        becomes the action's ``thought``. ``start`` and ``end`` are only set on
        the action when they were given in the tag.
        """
        assert (
            self.file_edit_match is not None
        ), 'self.file_edit_match should not be None when parse is called'
        match = self.file_edit_match

        file_path = match.group(2).strip()
        start_line = self._parse_line_number(match.group(4), 'start')
        end_line = self._parse_line_number(match.group(6), 'end')
        content = match.group(7)
        thought = action_str.replace(match.group(0), '').strip()

        action = FileEditAction(path=file_path, content=content, thought=thought)
        if start_line is not None:
            action.start = start_line
        if end_line is not None:
            action.end = end_line
        return action
|
||||
|
||||
@ -11,12 +11,14 @@ from openhands.events.action import (
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
IPythonRunCellObservation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
@ -34,7 +36,7 @@ from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
class CodeActAgent(Agent):
|
||||
VERSION = '1.9'
|
||||
VERSION = '2.0'
|
||||
"""
|
||||
The Code Act Agent is a minimalist agent.
|
||||
The agent works by passing the model a list of action-observation pairs and prompting the model to take the next step.
|
||||
@ -102,6 +104,8 @@ class CodeActAgent(Agent):
|
||||
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
|
||||
elif isinstance(action, AgentDelegateAction):
|
||||
return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
|
||||
elif isinstance(action, FileEditAction):
|
||||
return f'{action.thought}\n<file_edit path={action.path}>\n{action.content}\n</file_edit>'
|
||||
elif isinstance(action, MessageAction):
|
||||
return action.content
|
||||
elif isinstance(action, AgentFinishAction) and action.source == 'agent':
|
||||
@ -114,6 +118,7 @@ class CodeActAgent(Agent):
|
||||
or isinstance(action, CmdRunAction)
|
||||
or isinstance(action, IPythonRunCellAction)
|
||||
or isinstance(action, MessageAction)
|
||||
or isinstance(action, FileEditAction)
|
||||
or (isinstance(action, AgentFinishAction) and action.source == 'agent')
|
||||
):
|
||||
content = [TextContent(text=self.action_to_str(action))]
|
||||
@ -151,6 +156,9 @@ class CodeActAgent(Agent):
|
||||
text = '\n'.join(splitted)
|
||||
text = truncate_content(text, max_message_chars)
|
||||
return Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, FileEditObservation):
|
||||
text = obs_prefix + truncate_content(str(obs), max_message_chars)
|
||||
return Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, AgentDelegateObservation):
|
||||
text = obs_prefix + truncate_content(
|
||||
obs.outputs['content'] if 'content' in obs.outputs else '',
|
||||
@ -201,6 +209,7 @@ class CodeActAgent(Agent):
|
||||
'</execute_ipython>',
|
||||
'</execute_bash>',
|
||||
'</execute_browse>',
|
||||
'</file_edit>',
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
{% set MINIMAL_SYSTEM_PREFIX %}
|
||||
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed answers to the user's questions.
|
||||
The assistant can use a Python environment with <execute_ipython>, e.g.:
|
||||
|
||||
[1] The assistant can use a Python environment with <execute_ipython>, e.g.:
|
||||
<execute_ipython>
|
||||
print("Hello World!")
|
||||
</execute_ipython>
|
||||
The assistant can execute bash commands wrapped with <execute_bash>, e.g. <execute_bash> ls </execute_bash>.
|
||||
|
||||
[2] The assistant can execute bash commands wrapped with <execute_bash>, e.g. <execute_bash> ls </execute_bash>.
|
||||
If a bash command returns exit code `-1`, this means the process is not yet finished.
|
||||
The assistant must then send a second <execute_bash>. The second <execute_bash> can be empty
|
||||
(which will retrieve any additional logs), or it can contain text to be sent to STDIN of the running process,
|
||||
@ -14,6 +16,126 @@ For commands that may run indefinitely, the output should be redirected to a fil
|
||||
in the background, e.g. <execute_bash> python3 app.py > server.log 2>&1 & </execute_bash>
|
||||
If a command execution result says "Command timed out. Sending SIGINT to the process",
|
||||
the assistant should retry running the command in the background.
|
||||
|
||||
[3] The assistant can edit files using <file_edit> by setting the file path and providing a draft of the new file content. The draft file content does not need to be exactly the same as the existing file content; the assistant may skip some lines and only include the parts that need to be changed.
|
||||
|
||||
IMPORTANT: When editing a large file (e.g., > 300 lines), the assistant MUST SPECIFY the range of lines to be edited by setting `start` and `end` (1-indexed, both inclusive). For example, `<file_edit path="/path/to/file.txt" start=1 end=-1>` means the assistant will edit the whole file (from line 1 to the end of the file). `start=1` and `end=-1` are the default values, so the assistant can omit them if they are the same as the default values.
|
||||
BEFORE you start editing, you MUST view the ENTIRE body of the part you want to edit and get the correct begin and end line numbers.
|
||||
|
||||
When editing files, the assistant should include comments indicating where the code will not change. For example, use comments like `# no changes before` or `# no changes here` to clearly mark sections of the code that remain unchanged. This helps to provide context and ensure clarity in the edits being made.
|
||||
|
||||
Possible cases:
|
||||
- File too long: When the file to be edited is too long, the assistant should set `start` and `end` (1-indexed, both inclusive) to specify the range of lines to be edited. For example, `<file_edit path="/path/to/file.txt" start=100 end=200>` means the assistant will only edit lines 100 to 200 of `/path/to/file.txt`.
|
||||
- Append to file: If the assistant wants to append to a file, it should set both `start` and `end` to `-1`.
|
||||
- File does not exist: If `<file_edit>` is pointing to a file that does not exist, a new file with the exact content will be created.
|
||||
|
||||
Important: because line numbers are useful, the assistant should always use the provided functions to search (e.g., `search_dir`) or view the file content (e.g., `open_file`) along with the line numbers. DO NOT use other methods (e.g., `cat`) to view the file content.
|
||||
|
||||
**Example 1 (general edit for short files)**
|
||||
For example, given an existing file `/path/to/file.py` that looks like this:
|
||||
|
||||
(this is the beginning of the file)
|
||||
1|class MyClass:
|
||||
2| def __init__(self):
|
||||
3| self.x = 1
|
||||
4| self.y = 2
|
||||
5| self.z = 3
|
||||
6|
|
||||
7|print(MyClass().z)
|
||||
8|print(MyClass().x)
|
||||
(this is the end of the file)
|
||||
|
||||
|
||||
The assistant wants to edit the file to look like this:
|
||||
|
||||
(this is the beginning of the file)
|
||||
1|class MyClass:
|
||||
2| def __init__(self):
|
||||
3| self.x = 1
|
||||
4| self.y = 2
|
||||
5|
|
||||
6|print(MyClass().y)
|
||||
(this is the end of the file)
|
||||
|
||||
|
||||
The assistant may produce an edit action like this:
|
||||
<file_edit path="/path/to/file.txt" start=1 end=-1>
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
# no changes before
|
||||
self.y = 2
|
||||
# self.z is removed
|
||||
|
||||
# MyClass().z is removed
|
||||
print(MyClass().y)
|
||||
</file_edit>
|
||||
|
||||
**Example 2 (append to file for short files)**
|
||||
|
||||
For example, given an existing file `/path/to/file.py` that looks like this:
|
||||
|
||||
(this is the beginning of the file)
|
||||
1|class MyClass:
|
||||
2| def __init__(self):
|
||||
3| self.x = 1
|
||||
4| self.y = 2
|
||||
5| self.z = 3
|
||||
6|
|
||||
7|print(MyClass().z)
|
||||
8|print(MyClass().x)
|
||||
(this is the end of the file)
|
||||
|
||||
To append the following lines to the file:
|
||||
```python
|
||||
print(MyClass().y)
|
||||
```
|
||||
|
||||
The assistant may produce an edit action like this:
|
||||
<file_edit path="/path/to/file.txt" start=-1 end=-1>
|
||||
print(MyClass().y)
|
||||
</file_edit>
|
||||
|
||||
**Example 3 (edit for long files)**
|
||||
|
||||
Given an existing file `/path/to/file.py` that looks like this:
|
||||
|
||||
(1000 more lines above)
|
||||
1001|class MyClass:
|
||||
1002| def __init__(self):
|
||||
1003| self.x = 1
|
||||
1004| self.y = 2
|
||||
1005| self.z = 3
|
||||
1006|
|
||||
1007|print(MyClass().z)
|
||||
1008|print(MyClass().x)
|
||||
(2000 more lines below)
|
||||
|
||||
|
||||
The assistant wants to edit the file to look like this:
|
||||
|
||||
(1000 more lines above)
|
||||
1001|class MyClass:
|
||||
1002| def __init__(self):
|
||||
1003| self.x = 1
|
||||
1004| self.y = 2
|
||||
1005|
|
||||
1006|print(MyClass().y)
|
||||
(2000 more lines below)
|
||||
|
||||
The assistant may produce an edit action like this:
|
||||
|
||||
<file_edit path="/path/to/file.txt" start=1001 end=1008>
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
# no changes before
|
||||
self.y = 2
|
||||
# self.z is removed
|
||||
|
||||
# MyClass().z is removed
|
||||
print(MyClass().y)
|
||||
</file_edit>
|
||||
|
||||
|
||||
{% endset %}
|
||||
{% set BROWSING_PREFIX %}
|
||||
The assistant can browse the Internet with <execute_browse> and </execute_browse>.
|
||||
@ -29,12 +151,8 @@ Apart from the standard Python library, the assistant can also use the following
|
||||
{{ agent_skills_docs }}
|
||||
IMPORTANT:
|
||||
- `open_file` only returns the first 100 lines of the file by default! The assistant MUST use `scroll_down` repeatedly to read the full file BEFORE making edits!
|
||||
- The assistant shall adhere to THE `edit_file_by_replace`, `append_file` and `insert_content_at_line` FUNCTIONS REQUIRING PROPER INDENTATION. If the assistant would like to add the line ' print(x)', it must fully write the line out, with all leading spaces before the code!
|
||||
- Indentation is important and code that is not indented correctly will fail and require fixing before it can be run.
|
||||
- Any code issued should be less than 50 lines to avoid context being cut off!
|
||||
- After EVERY `create_file` the method `append_file` shall be used to write the FIRST content!
|
||||
- For `edit_file_by_replace` NEVER provide empty parameters!
|
||||
- For `edit_file_by_replace` the file must be read fully before any replacements!
|
||||
{% endset %}
|
||||
{% set SYSTEM_SUFFIX %}
|
||||
Responses should be concise.
|
||||
|
||||
@ -4,23 +4,19 @@
|
||||
USER: Create a list of numbers from 1 to 10, and display them in a web page at port 5000.
|
||||
|
||||
ASSISTANT:
|
||||
Sure! Let me create a Python file `app.py`:
|
||||
<execute_ipython>
|
||||
create_file('app.py')
|
||||
</execute_ipython>
|
||||
Sure! Let me first check the current directory:
|
||||
<execute_bash>
|
||||
ls
|
||||
</execute_bash>
|
||||
|
||||
USER:
|
||||
OBSERVATION:
|
||||
[File: /workspace/app.py (1 lines total)]
|
||||
(this is the beginning of the file)
|
||||
1|
|
||||
(this is the end of the file)
|
||||
[File app.py created.]
|
||||
openhands@runtime:~/workspace$
|
||||
|
||||
ASSISTANT:
|
||||
Now I will write the Python code for starting a web server and save it to the file `app.py`:
|
||||
<execute_ipython>
|
||||
EDITED_CODE="""from flask import Flask
|
||||
There is no `app.py` file in the current directory. Let me create a Python file `app.py`:
|
||||
<file_edit path="/workspace/app.py" start=-1 end=-1>
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route('/')
|
||||
@ -29,30 +25,27 @@ def index():
|
||||
return str(numbers)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(port=5000)"""
|
||||
|
||||
insert_content_at_line(
|
||||
'app.py',
|
||||
1,
|
||||
EDITED_CODE,
|
||||
)
|
||||
</execute_ipython>
|
||||
app.run(port=5000)
|
||||
</file_edit>
|
||||
|
||||
USER:
|
||||
OBSERVATION:
|
||||
(this is the beginning of the file)
|
||||
1|from flask import Flask
|
||||
2|app = Flask(__name__)
|
||||
3|
|
||||
4|@app.route('/')
|
||||
5|def index():
|
||||
6| numbers = list(range(1, 11))
|
||||
7| return str(numbers)
|
||||
8|
|
||||
9|if __name__ == '__main__':
|
||||
10| app.run(port=5000)
|
||||
(this is the end of the file)
|
||||
[File updated (edited at line 1). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]
|
||||
[New file /workspace/app.py is created.]
|
||||
(begin of changes)
|
||||
--- /workspace/app.py
|
||||
+++ /workspace/app.py
|
||||
@@ -1 +1,10 @@
|
||||
+from flask import Flask
|
||||
+app = Flask(__name__)
|
||||
|
||||
+@app.route('/')
|
||||
+def index():
|
||||
+ numbers = list(range(1, 11))
|
||||
+ return str(numbers)
|
||||
+
|
||||
+if __name__ == '__main__':
|
||||
+ app.run(port=5000)
|
||||
(end of changes)
|
||||
|
||||
ASSISTANT:
|
||||
I have created a Python file `app.py` that will display a list of numbers from 1 to 10 when you run it. Let me run the Python file for you:
|
||||
@ -177,16 +170,17 @@ USER:
|
||||
|
||||
ASSISTANT:
|
||||
I should edit the file to display the numbers in a table format. I should include correct indentation. Let me update the file:
|
||||
<execute_ipython>
|
||||
edit_file_by_replace(
|
||||
'app.py',
|
||||
to_replace=" return str(numbers)",
|
||||
new_content=" return '<table>' + ''.join([f'<tr><td>{i}</td></tr>' for i in numbers]) + '</table>'",
|
||||
)
|
||||
</execute_ipython>
|
||||
<file_edit path="/workspace/app.py" start=4 end=7>
|
||||
@app.route('/')
|
||||
def index():
|
||||
numbers = list(range(1, 11))
|
||||
ret = '<table>' + ''.join([f'<tr><td>{i}</td></tr>' for i in numbers]) + '</table>'
|
||||
return ret
|
||||
</file_edit>
|
||||
|
||||
USER:
|
||||
Observation:
|
||||
[Edited existing file /workspace/app.py]
|
||||
[File: /workspace/app.py (10 lines total after edit)]
|
||||
(this is the beginning of the file)
|
||||
1|from flask import Flask
|
||||
@ -195,10 +189,11 @@ Observation:
|
||||
4|@app.route('/')
|
||||
5|def index():
|
||||
6| numbers = list(range(1, 11))
|
||||
7| return '<table>' + ''.join([f'<tr><td>{i}</td></tr>' for i in numbers]) + '</table>'
|
||||
8|
|
||||
9|if __name__ == '__main__':
|
||||
10| app.run(port=5000)
|
||||
7| ret = '<table>' + ''.join([f'<tr><td>{i}</td></tr>' for i in numbers]) + '</table>'
|
||||
8| return ret
|
||||
9|
|
||||
10|if __name__ == '__main__':
|
||||
11| app.run(port=5000)
|
||||
(this is the end of the file)
|
||||
[File updated (edited at line 7). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]
|
||||
|
||||
|
||||
@ -3,6 +3,16 @@ from abc import ABC, abstractmethod
|
||||
from openhands.events.action import Action
|
||||
|
||||
|
||||
class ActionParseError(Exception):
    """Signals that an LLM response could not be turned into an action.

    The text given at construction time is kept on ``error`` and is exactly
    what ``str()`` of the exception yields.
    """

    def __init__(self, error: str) -> None:
        # Keep the raw description so callers can read it back directly.
        self.error = error

    def __str__(self) -> str:
        return self.error
|
||||
|
||||
|
||||
class ResponseParser(ABC):
|
||||
"""This abstract base class is a general interface for an response parser dedicated to
|
||||
parsing the action from the response from the LLM.
|
||||
|
||||
@ -6,12 +6,12 @@ from typing import Any
|
||||
|
||||
from openhands.controller.state.task import RootTask
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.metrics import Metrics
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import AgentFinishAction
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
@ -19,12 +19,14 @@ from openhands.events.action import (
|
||||
Action,
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
@ -50,6 +52,10 @@ def display_command_output(output: str):
|
||||
print('\n')
|
||||
|
||||
|
||||
def display_file_edit(event: FileEditAction | FileEditObservation):
    """Print the string form of a file-edit action or observation in green."""
    rendered = colored(str(event), 'green')
    print(rendered)
|
||||
|
||||
|
||||
def display_event(event: Event):
|
||||
if isinstance(event, Action):
|
||||
if hasattr(event, 'thought'):
|
||||
@ -61,6 +67,10 @@ def display_event(event: Event):
|
||||
display_command(event.command)
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
display_command_output(event.content)
|
||||
if isinstance(event, FileEditAction):
|
||||
display_file_edit(event)
|
||||
if isinstance(event, FileEditObservation):
|
||||
display_file_edit(event)
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
from openhands.core.config.config_utils import get_field_info
|
||||
|
||||
@ -39,6 +40,7 @@ class LLMConfig:
|
||||
disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
|
||||
caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider.
|
||||
log_completions: Whether to log LLM completions to the state.
|
||||
draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
|
||||
"""
|
||||
|
||||
model: str = 'gpt-4o'
|
||||
@ -71,6 +73,7 @@ class LLMConfig:
|
||||
disable_vision: bool | None = None
|
||||
caching_prompt: bool = True
|
||||
log_completions: bool = False
|
||||
draft_editor: Optional['LLMConfig'] = None
|
||||
|
||||
def defaults_to_dict(self) -> dict:
|
||||
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
|
||||
@ -113,4 +116,19 @@ class LLMConfig:
|
||||
for k, v in ret.items():
|
||||
if k in LLM_SENSITIVE_FIELDS:
|
||||
ret[k] = '******' if v else None
|
||||
elif isinstance(v, LLMConfig):
|
||||
ret[k] = v.to_safe_dict()
|
||||
return ret
|
||||
|
||||
@classmethod
def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
    """Build an LLMConfig from a plain dictionary.

    Entries whose value is itself a dict are left out of the constructor
    arguments. The single nested entry that is honoured is 'draft_editor':
    when that key is present, its value is expanded into its own LLMConfig
    and handed over as the `draft_editor` argument.
    """
    args = {}
    for key, value in llm_config_dict.items():
        # Nested tables are not plain field values; skip them here.
        if isinstance(value, dict):
            continue
        args[key] = value

    if 'draft_editor' in llm_config_dict:
        args['draft_editor'] = LLMConfig(**llm_config_dict['draft_editor'])

    return cls(**args)
|
||||
|
||||
@ -136,17 +136,14 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
|
||||
logger.openhands_logger.debug(
|
||||
'Attempt to load default LLM config from config toml'
|
||||
)
|
||||
non_dict_fields = {
|
||||
k: v for k, v in value.items() if not isinstance(v, dict)
|
||||
}
|
||||
llm_config = LLMConfig(**non_dict_fields)
|
||||
llm_config = LLMConfig.from_dict(value)
|
||||
cfg.set_llm_config(llm_config, 'llm')
|
||||
for nested_key, nested_value in value.items():
|
||||
if isinstance(nested_value, dict):
|
||||
logger.openhands_logger.debug(
|
||||
f'Attempt to load group {nested_key} from config toml as llm config'
|
||||
)
|
||||
llm_config = LLMConfig(**nested_value)
|
||||
llm_config = LLMConfig.from_dict(nested_value)
|
||||
cfg.set_llm_config(llm_config, nested_key)
|
||||
elif not key.startswith('sandbox') and key.lower() != 'core':
|
||||
logger.openhands_logger.warning(
|
||||
@ -272,7 +269,7 @@ def get_llm_config_arg(
|
||||
|
||||
# update the llm config with the specified section
|
||||
if 'llm' in toml_config and llm_config_arg in toml_config['llm']:
|
||||
return LLMConfig(**toml_config['llm'][llm_config_arg])
|
||||
return LLMConfig.from_dict(toml_config['llm'][llm_config_arg])
|
||||
logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}')
|
||||
return None
|
||||
|
||||
|
||||
@ -52,8 +52,12 @@ class BrowserUnavailableException(Exception):
|
||||
# It might be malformed JSON
|
||||
class LLMMalformedActionError(Exception):
    """Raised for a malformed LLM response.

    The message is stored on ``message`` and is what ``str()`` returns.
    """

    def __init__(self, message='Malformed response'):
        super().__init__(message)
        # Kept as an attribute so __str__ can return it unchanged.
        self.message = message

    def __str__(self):
        return self.message
|
||||
|
||||
|
||||
# This exception gets sent back to the LLM
|
||||
# For some reason, the agent did not return an action
|
||||
|
||||
@ -119,11 +119,14 @@ class SensitiveDataFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
def get_console_handler(log_level=logging.INFO):
|
||||
def get_console_handler(log_level=logging.INFO, extra_info: str | None = None):
|
||||
"""Returns a console handler for logging."""
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(log_level)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
formatter_str = '%(asctime)s - %(levelname)s - %(message)s'
|
||||
if extra_info:
|
||||
formatter_str = f'{extra_info} - ' + formatter_str
|
||||
console_handler.setFormatter(logging.Formatter(formatter_str))
|
||||
return console_handler
|
||||
|
||||
|
||||
|
||||
@ -24,6 +24,10 @@ class ActionTypeSchema(BaseModel):
|
||||
"""Writes the content to a file.
|
||||
"""
|
||||
|
||||
EDIT: str = Field(default='edit')
|
||||
"""Edits a file by providing a draft.
|
||||
"""
|
||||
|
||||
RUN: str = Field(default='run')
|
||||
"""Runs a command.
|
||||
"""
|
||||
|
||||
@ -10,6 +10,8 @@ class ObservationTypeSchema(BaseModel):
|
||||
|
||||
WRITE: str = Field(default='write')
|
||||
|
||||
EDIT: str = Field(default='edit')
|
||||
|
||||
BROWSE: str = Field(default='browse')
|
||||
"""The HTML content of a URL
|
||||
"""
|
||||
|
||||
@ -6,6 +6,7 @@ from json_repair import repair_json
|
||||
from openhands.core.exceptions import LLMResponseError
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm.metrics import Metrics
|
||||
|
||||
|
||||
def my_default_encoder(obj):
|
||||
@ -14,6 +15,8 @@ def my_default_encoder(obj):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, Event):
|
||||
return event_to_dict(obj)
|
||||
if isinstance(obj, Metrics):
|
||||
return obj.get()
|
||||
return json.JSONEncoder().default(obj)
|
||||
|
||||
|
||||
|
||||
@ -9,7 +9,11 @@ from openhands.events.action.agent import (
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.files import FileReadAction, FileWriteAction
|
||||
from openhands.events.action.files import (
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.tasks import AddTaskAction, ModifyTaskAction
|
||||
|
||||
@ -21,6 +25,7 @@ __all__ = [
|
||||
'BrowseInteractiveAction',
|
||||
'FileReadAction',
|
||||
'FileWriteAction',
|
||||
'FileEditAction',
|
||||
'AgentFinishAction',
|
||||
'AgentRejectAction',
|
||||
'AgentDelegateAction',
|
||||
|
||||
@ -27,6 +27,11 @@ class FileReadAction(Action):
|
||||
|
||||
@dataclass
|
||||
class FileWriteAction(Action):
|
||||
"""Writes a file to a given path.
|
||||
Can be set to write specific lines using start and end
|
||||
Default lines 0:-1 (whole file)
|
||||
"""
|
||||
|
||||
path: str
|
||||
content: str
|
||||
start: int = 0
|
||||
@ -39,3 +44,31 @@ class FileWriteAction(Action):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Writing file: {self.path}'
|
||||
|
||||
|
||||
@dataclass
class FileEditAction(Action):
    """Edits the file at a given path from a provided draft.

    The edit can be limited to a line range through `start` and `end`
    (1-indexed, inclusive), which is useful when the file is too long.
    The defaults, 1 and -1, denote the whole file.

    Setting `start` to -1 means the content is simply appended to the file.
    """

    path: str
    content: str
    start: int = 1
    end: int = -1
    thought: str = ''
    action: str = ActionType.EDIT
    runnable: ClassVar[bool] = True
    security_risk: ActionSecurityRisk | None = None

    def __repr__(self) -> str:
        # Human-readable, multi-line rendering: header, thought, range,
        # path, then the draft content inside a fenced block.
        sections = [
            '**FileEditAction**\n',
            f'Thought: {self.thought}\n',
            f'Range: [L{self.start}:L{self.end}]\n',
            f'Path: [{self.path}]\n',
            f'Content:\n```\n{self.content}\n```\n',
        ]
        return ''.join(sections)
|
||||
|
||||
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from openhands.core.metrics import Metrics
|
||||
from openhands.llm.metrics import Metrics
|
||||
|
||||
|
||||
class EventSource(str, Enum):
|
||||
|
||||
@ -8,6 +8,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import ErrorObservation, FatalErrorObservation
|
||||
from openhands.events.observation.files import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
@ -23,6 +24,7 @@ __all__ = [
|
||||
'BrowserOutputObservation',
|
||||
'FileReadObservation',
|
||||
'FileWriteObservation',
|
||||
'FileEditObservation',
|
||||
'ErrorObservation',
|
||||
'FatalErrorObservation',
|
||||
'AgentStateChangedObservation',
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
@ -26,3 +27,107 @@ class FileWriteObservation(Observation):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'I wrote to the file {self.path}.'
|
||||
|
||||
|
||||
@dataclass
class FileEditObservation(Observation):
    """This data class represents a file edit operation.

    It carries the file path, whether the file existed before the edit, and
    the full text before and after the edit. From those it can render a
    hunk-by-hunk, line-numbered before/after view of the change.
    """

    # content: str will be a unified diff patch string that includes NO context lines
    path: str
    prev_exist: bool
    old_content: str
    new_content: str
    observation: str = ObservationType.EDIT

    @property
    def message(self) -> str:
        return f'I edited the file {self.path}.'

    def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]:
        """Get the edit groups of the file edit.

        Args:
            n_context_lines: Number of unchanged lines kept around each change.

        Returns:
            One dict per group of nearby changes, with two keys:
            'before_edits' (lines of the old content) and 'after_edits'
            (lines of the new content). Every line is prefixed with its
            right-aligned 1-based line number and '|'; removed lines carry
            a leading '-', added lines a leading '+'.
        """
        old_lines = self.old_content.split('\n')
        new_lines = self.new_content.split('\n')
        # Borrowed from difflib.unified_diff to directly parse into structured format.
        edit_groups: list[dict] = []
        for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes(
            n_context_lines
        ):
            # Opcodes within a group are ordered, so the last one holds the
            # largest end indices. i2/j2 are exclusive 0-based ends, which
            # equal the largest 1-based line numbers rendered for the old
            # and new side. Sizing the gutter from the start index instead
            # lets numbers overflow it when a digit boundary is crossed.
            max_line_number = max(group[-1][2], group[-1][4])
            _indent_pad_size = len(str(max_line_number)) + 1  # +1 for the "+"/"-" prefix
            cur_group: dict[str, list[str]] = {
                'before_edits': [],
                'after_edits': [],
            }
            for tag, i1, i2, j1, j2 in group:
                if tag == 'equal':
                    # Unchanged context appears on both sides, each with
                    # its own side's numbering.
                    for idx, line in enumerate(old_lines[i1:i2]):
                        cur_group['before_edits'].append(
                            f'{i1+idx+1:>{_indent_pad_size}}|{line}'
                        )
                    for idx, line in enumerate(new_lines[j1:j2]):
                        cur_group['after_edits'].append(
                            f'{j1+idx+1:>{_indent_pad_size}}|{line}'
                        )
                    continue
                if tag in {'replace', 'delete'}:
                    for idx, line in enumerate(old_lines[i1:i2]):
                        cur_group['before_edits'].append(
                            f'-{i1+idx+1:>{_indent_pad_size-1}}|{line}'
                        )
                if tag in {'replace', 'insert'}:
                    for idx, line in enumerate(new_lines[j1:j2]):
                        cur_group['after_edits'].append(
                            f'+{j1+idx+1:>{_indent_pad_size-1}}|{line}'
                        )
            edit_groups.append(cur_group)
        return edit_groups

    def visualize_diff(
        self,
        n_context_lines: int = 2,
        change_applied: bool = True,
    ) -> str:
        """Visualize the diff of the file edit.

        Instead of showing the diff line by line, this function
        shows each hunk of changes as a separate entity.

        Args:
            n_context_lines: The number of lines of context to show before and after the changes.
            change_applied: Whether the changes are applied to the file. If true, the file has been modified. If not, the file is not modified (due to linting errors).

        Returns:
            The rendered text: a header line followed by one
            begin/before/after/end section per edit group.
        """
        if change_applied and self.content.strip() == '':
            # diff patch is empty
            return '(no changes detected. Please make sure your edits changes the content of the existing file.)\n'

        edit_groups = self.get_edit_groups(n_context_lines=n_context_lines)

        result = [
            f'[Existing file {self.path} is edited with {len(edit_groups)} changes.]'
            if change_applied
            else f"[Changes are NOT applied to {self.path} - Here's how the file looks like if changes are applied.]"
        ]

        op_type = 'edit' if change_applied else 'ATTEMPTED edit'
        for i, cur_edit_group in enumerate(edit_groups):
            if i != 0:
                # Visual separator between consecutive groups.
                result.append('-------------------------')
            result.append(f'[begin of {op_type} {i+1} / {len(edit_groups)}]')
            result.append(f'(content before {op_type})')
            result.extend(cur_edit_group['before_edits'])
            result.append(f'(content after {op_type})')
            result.extend(cur_edit_group['after_edits'])
            result.append(f'[end of {op_type} {i+1} / {len(edit_groups)}]')
        return '\n'.join(result)

    def __str__(self) -> str:
        ret = ''
        if not self.prev_exist:
            # A newly created file has nothing to diff against.
            assert (
                self.old_content == ''
            ), 'old_content should be empty if the file is new (prev_exist=False).'
            ret += f'[New file {self.path} is created with the provided content.]\n'
            return ret.rstrip() + '\n'
        ret += self.visualize_diff()
        return ret.rstrip() + '\n'
|
||||
|
||||
@ -12,7 +12,11 @@ from openhands.events.action.commands import (
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.files import FileReadAction, FileWriteAction
|
||||
from openhands.events.action.files import (
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.tasks import AddTaskAction, ModifyTaskAction
|
||||
|
||||
@ -24,6 +28,7 @@ actions = (
|
||||
BrowseInteractiveAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
FileEditAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentDelegateAction,
|
||||
|
||||
@ -7,7 +7,11 @@ from openhands.events.observation.commands import (
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.files import FileReadObservation, FileWriteObservation
|
||||
from openhands.events.observation.files import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
@ -19,6 +23,7 @@ observations = (
|
||||
BrowserOutputObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
FileEditObservation,
|
||||
AgentDelegateObservation,
|
||||
SuccessObservation,
|
||||
ErrorObservation,
|
||||
|
||||
@ -24,8 +24,8 @@ from litellm.types.utils import CostPerToken, ModelResponse, Usage
|
||||
from openhands.core.exceptions import CloudFlareBlockageError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message
|
||||
from openhands.core.metrics import Metrics
|
||||
from openhands.llm.debug_mixin import DebugMixin
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.llm.retry_mixin import RetryMixin
|
||||
|
||||
__all__ = ['LLM']
|
||||
@ -73,7 +73,9 @@ class LLM(RetryMixin, DebugMixin):
|
||||
config: The LLM configuration.
|
||||
metrics: The metrics to use.
|
||||
"""
|
||||
self.metrics: Metrics = metrics if metrics is not None else Metrics()
|
||||
self.metrics: Metrics = (
|
||||
metrics if metrics is not None else Metrics(model_name=config.model)
|
||||
)
|
||||
self.cost_metric_supported: bool = True
|
||||
self.config: LLMConfig = copy.deepcopy(config)
|
||||
|
||||
@ -396,7 +398,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
return str(self)
|
||||
|
||||
def reset(self):
|
||||
self.metrics = Metrics()
|
||||
self.metrics.reset()
|
||||
self.llm_completions = []
|
||||
|
||||
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
|
||||
|
||||
@ -1,12 +1,24 @@
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Cost(BaseModel):
|
||||
model: str
|
||||
cost: float
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class Metrics:
|
||||
"""Metrics class can record various metrics during running and evaluation.
|
||||
Currently, we define the following metrics:
|
||||
accumulated_cost: the total cost (USD $) of the current LLM.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, model_name: str = 'default') -> None:
|
||||
self._accumulated_cost: float = 0.0
|
||||
self._costs: list[float] = []
|
||||
self._costs: list[Cost] = []
|
||||
self.model_name = model_name
|
||||
|
||||
@property
|
||||
def accumulated_cost(self) -> float:
|
||||
@ -19,22 +31,29 @@ class Metrics:
|
||||
self._accumulated_cost = value
|
||||
|
||||
@property
|
||||
def costs(self) -> list:
|
||||
def costs(self) -> list[Cost]:
|
||||
return self._costs
|
||||
|
||||
def add_cost(self, value: float) -> None:
|
||||
if value < 0:
|
||||
raise ValueError('Added cost cannot be negative.')
|
||||
self._accumulated_cost += value
|
||||
self._costs.append(value)
|
||||
self._costs.append(Cost(cost=value, model=self.model_name))
|
||||
|
||||
def merge(self, other: 'Metrics') -> None:
|
||||
self._accumulated_cost += other.accumulated_cost
|
||||
self._costs += other._costs
|
||||
|
||||
def get(self):
|
||||
def get(self) -> dict:
|
||||
"""Return the metrics in a dictionary."""
|
||||
return {'accumulated_cost': self._accumulated_cost, 'costs': self._costs}
|
||||
return {
|
||||
'accumulated_cost': self._accumulated_cost,
|
||||
'costs': [cost.model_dump() for cost in self._costs],
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
self._accumulated_cost = 0.0
|
||||
self._costs = []
|
||||
|
||||
def log(self):
|
||||
"""Log the metrics."""
|
||||
@ -98,7 +98,7 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
'EXPIRED',
|
||||
]:
|
||||
error_message = status_data.get(
|
||||
'error', f'Build failed with status: {status}'
|
||||
'error', f'Build failed with status: {status}. Build ID: {build_id}'
|
||||
)
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
@ -297,7 +297,8 @@ class RuntimeClient:
|
||||
self.pwd = os.path.expanduser(working_dir)
|
||||
|
||||
# re-assemble the prompt
|
||||
prompt = f'{other_info.strip()}\n{username}@{hostname}:{working_dir} '
|
||||
# ignore the hostname AND use 'openhands-workspace'
|
||||
prompt = f'{other_info.strip()}\n{username}@openhands-workspace:{working_dir} '
|
||||
if username == 'root':
|
||||
prompt += '#'
|
||||
else:
|
||||
@ -504,7 +505,9 @@ class RuntimeClient:
|
||||
# NOTE: this is part of initialization, so we hard code the timeout
|
||||
result, exit_code = self._execute_bash('pwd', timeout=60, keep_prompt=False)
|
||||
if exit_code != 0:
|
||||
raise RuntimeError('Failed to get working directory')
|
||||
raise RuntimeError(
|
||||
f'Failed to get working directory (exit code: {exit_code}): {result}'
|
||||
)
|
||||
return result.strip()
|
||||
|
||||
def _resolve_path(self, path: str, working_dir: str) -> str:
|
||||
|
||||
@ -17,6 +17,7 @@ from openhands.events.action import (
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
@ -429,6 +430,9 @@ class EventStreamRuntime(Runtime):
|
||||
self.docker_client.close()
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
if isinstance(action, FileEditAction):
|
||||
return self.edit(action)
|
||||
|
||||
# set timeout to default if not set
|
||||
if action.timeout is None:
|
||||
action.timeout = self.config.sandbox.timeout
|
||||
|
||||
@ -7,20 +7,12 @@ Functions:
|
||||
- goto_line(line_number: int): Moves the window to show the specified line number.
|
||||
- scroll_down(): Moves the window down by the number of lines specified in WINDOW.
|
||||
- scroll_up(): Moves the window up by the number of lines specified in WINDOW.
|
||||
- create_file(filename: str): Creates and opens a new file with the given name.
|
||||
- search_dir(search_term: str, dir_path: str = './'): Searches for a term in all files in the specified directory.
|
||||
- search_file(search_term: str, file_path: str | None = None): Searches for a term in the specified file or the currently open file.
|
||||
- find_file(file_name: str, dir_path: str = './'): Finds all files with the given name in the specified directory.
|
||||
- edit_file_by_replace(file_name: str, to_replace: str, new_content: str): Replaces specific content in a file with new content.
|
||||
- insert_content_at_line(file_name: str, line_number: int, content: str): Inserts given content at the specified line number in a file.
|
||||
- append_file(file_name: str, content: str): Appends the given content to the end of the specified file.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
from openhands.linter import DefaultLinter, LintResult
|
||||
|
||||
@ -205,7 +197,7 @@ def open_file(
|
||||
output += _print_window(
|
||||
CURRENT_FILE,
|
||||
CURRENT_LINE,
|
||||
_clamp(context_lines, 1, 300),
|
||||
_clamp(context_lines, 1, 100),
|
||||
return_str=True,
|
||||
ignore_window=False,
|
||||
)
|
||||
@ -276,531 +268,10 @@ def scroll_up() -> None:
|
||||
print(output)
|
||||
|
||||
|
||||
def create_file(filename: str) -> None:
    """Create a new file with the given name and open it.

    If the path already exists, an error is reported and nothing is written.

    Args:
        filename: str: The name of the file to create.
    """
    if not os.path.exists(filename):
        # Seed the new file with a single newline.
        with open(filename, 'w') as new_file:
            new_file.write('\n')

        open_file(filename)
        print(f'[File {filename} created.]')
        return

    _output_error(f"File '{filename}' already exists.")
|
||||
|
||||
|
||||
class LineNumberError(Exception):
    """Raised by the edit helpers for an invalid line number or line range."""
|
||||
|
||||
|
||||
def _append_impl(lines, content):
|
||||
"""Internal method to handle appending to a file.
|
||||
|
||||
Args:
|
||||
lines: list[str]: The lines in the original file.
|
||||
content: str: The content to append to the file.
|
||||
|
||||
Returns:
|
||||
content: str: The new content of the file.
|
||||
n_added_lines: int: The number of lines added to the file.
|
||||
"""
|
||||
content_lines = content.splitlines(keepends=True)
|
||||
n_added_lines = len(content_lines)
|
||||
if lines and not (len(lines) == 1 and lines[0].strip() == ''):
|
||||
# file is not empty
|
||||
if not lines[-1].endswith('\n'):
|
||||
lines[-1] += '\n'
|
||||
new_lines = lines + content_lines
|
||||
content = ''.join(new_lines)
|
||||
else:
|
||||
# file is empty
|
||||
content = ''.join(content_lines)
|
||||
|
||||
return content, n_added_lines
|
||||
|
||||
|
||||
def _insert_impl(lines, start, content):
|
||||
"""Internal method to handle inserting to a file.
|
||||
|
||||
Args:
|
||||
lines: list[str]: The lines in the original file.
|
||||
start: int: The start line number for inserting.
|
||||
content: str: The content to insert to the file.
|
||||
|
||||
Returns:
|
||||
content: str: The new content of the file.
|
||||
n_added_lines: int: The number of lines added to the file.
|
||||
|
||||
Raises:
|
||||
LineNumberError: If the start line number is invalid.
|
||||
"""
|
||||
inserted_lines = [content + '\n' if not content.endswith('\n') else content]
|
||||
if len(lines) == 0:
|
||||
new_lines = inserted_lines
|
||||
elif start is not None:
|
||||
if len(lines) == 1 and lines[0].strip() == '':
|
||||
# if the file with only 1 line and that line is empty
|
||||
lines = []
|
||||
|
||||
if len(lines) == 0:
|
||||
new_lines = inserted_lines
|
||||
else:
|
||||
new_lines = lines[: start - 1] + inserted_lines + lines[start - 1 :]
|
||||
else:
|
||||
raise LineNumberError(
|
||||
f'Invalid line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive).'
|
||||
)
|
||||
|
||||
content = ''.join(new_lines)
|
||||
n_added_lines = len(inserted_lines)
|
||||
return content, n_added_lines
|
||||
|
||||
|
||||
def _edit_impl(lines, start, end, content):
|
||||
"""Internal method to handle editing a file.
|
||||
|
||||
REQUIRES (should be checked by caller):
|
||||
start <= end
|
||||
start and end are between 1 and len(lines) (inclusive)
|
||||
content ends with a newline
|
||||
|
||||
Args:
|
||||
lines: list[str]: The lines in the original file.
|
||||
start: int: The start line number for editing.
|
||||
end: int: The end line number for editing.
|
||||
content: str: The content to replace the lines with.
|
||||
|
||||
Returns:
|
||||
content: str: The new content of the file.
|
||||
n_added_lines: int: The number of lines added to the file.
|
||||
"""
|
||||
# Handle cases where start or end are None
|
||||
if start is None:
|
||||
start = 1 # Default to the beginning
|
||||
if end is None:
|
||||
end = len(lines) # Default to the end
|
||||
# Check arguments
|
||||
if not (1 <= start <= len(lines)):
|
||||
raise LineNumberError(
|
||||
f'Invalid start line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive).'
|
||||
)
|
||||
if not (1 <= end <= len(lines)):
|
||||
raise LineNumberError(
|
||||
f'Invalid end line number: {end}. Line numbers must be between 1 and {len(lines)} (inclusive).'
|
||||
)
|
||||
if start > end:
|
||||
raise LineNumberError(
|
||||
f'Invalid line range: {start}-{end}. Start must be less than or equal to end.'
|
||||
)
|
||||
|
||||
if not content.endswith('\n'):
|
||||
content += '\n'
|
||||
content_lines = content.splitlines(True)
|
||||
n_added_lines = len(content_lines)
|
||||
new_lines = lines[: start - 1] + content_lines + lines[end:]
|
||||
content = ''.join(new_lines)
|
||||
return content, n_added_lines
|
||||
|
||||
|
||||
def _edit_file_impl(
    file_name: str,
    start: int | None = None,
    end: int | None = None,
    content: str = '',
    is_insert: bool = False,
    is_append: bool = False,
) -> str | None:
    """Internal method to handle common logic for edit_/append_file methods.

    The new file content is first written to a scratch file next to the
    original and then moved over it with ``os.replace``. When the
    ``ENABLE_AUTO_LINT`` environment variable is ``true`` the file is linted
    before and after the change; if the change introduces new lint errors the
    original content is restored and an explanatory message is returned.

    Args:
        file_name: str: The name of the file to edit or append to.
        start: int | None = None: The start line number for editing. Ignored if is_append is True.
        end: int | None = None: The end line number for editing. Ignored if is_append is True.
        content: str: The content to replace the lines with or to append.
        is_insert: bool = False: Whether to insert content at the given line number instead of editing.
        is_append: bool = False: Whether to append content to the file instead of editing.

    Returns:
        A message describing the outcome (updated window, line-number error or
        lint failure), or None when the arguments were rejected up front (in
        which case the problem has already been reported via `_output_error`).
    """
    ret_str = ''
    global CURRENT_FILE, CURRENT_LINE, WINDOW

    ERROR_MSG = f'[Error editing file {file_name}. Please confirm the file is correct.]'
    ERROR_MSG_SUFFIX = (
        'Your changes have NOT been applied. Please fix your edit command and try again.\n'
        'You either need to 1) Open the correct file and try again or 2) Specify the correct line number arguments.\n'
        'DO NOT re-run the same failed edit command. Running it again will lead to the same error.'
    )

    if not _is_valid_filename(file_name):
        _output_error('Invalid file name.')
        return None

    if not _is_valid_path(file_name):
        _output_error('Invalid path or file name.')
        return None

    if not _create_paths(file_name):
        _output_error('Could not access or create directories.')
        return None

    if not os.path.isfile(file_name):
        _output_error(f'File {file_name} not found.')
        return None

    if is_insert and is_append:
        _output_error('Cannot insert and append at the same time.')
        return None

    # Use a temporary file to write changes
    content = str(content or '')
    temp_file_path = ''
    # Path of the pre-edit backup used by the lint step; tracked here so the
    # `finally` clause can always remove it.
    original_file_backup_path = ''
    first_error_line = None

    try:
        n_added_lines = None

        # lint the original file
        # NOTE: we need to get env var inside this function
        # because the env var will be set AFTER the agentskills is imported
        enable_auto_lint = os.getenv('ENABLE_AUTO_LINT', 'false').lower() == 'true'
        if enable_auto_lint:
            # Copy the original file to a temporary file (with the same ext) and lint it
            suffix = os.path.splitext(file_name)[1]
            with tempfile.NamedTemporaryFile(suffix=suffix) as orig_file_clone:
                shutil.copy2(file_name, orig_file_clone.name)
                original_lint_error, _ = _lint_file(orig_file_clone.name)

        # Create a temporary file in the same directory as the original file
        original_dir = os.path.dirname(file_name)
        original_ext = os.path.splitext(file_name)[1]
        temp_file_name = f'.temp_{uuid.uuid4().hex}{original_ext}'
        temp_file_path = os.path.join(original_dir, temp_file_name)

        with open(temp_file_path, 'w') as temp_file:
            # Read the original file and check if empty and for a trailing newline
            with open(file_name) as original_file:
                lines = original_file.readlines()

            if is_append:
                content, n_added_lines = _append_impl(lines, content)
            elif is_insert:
                try:
                    content, n_added_lines = _insert_impl(lines, start, content)
                except LineNumberError as e:
                    ret_str += (f'{ERROR_MSG}\n' f'{e}\n' f'{ERROR_MSG_SUFFIX}') + '\n'
                    return ret_str
            else:
                try:
                    content, n_added_lines = _edit_impl(lines, start, end, content)
                except LineNumberError as e:
                    ret_str += (f'{ERROR_MSG}\n' f'{e}\n' f'{ERROR_MSG_SUFFIX}') + '\n'
                    return ret_str

            if not content.endswith('\n'):
                content += '\n'

            # Write the new content to the temporary file
            temp_file.write(content)

        # Replace the original file with the temporary file
        os.replace(temp_file_path, file_name)

        # Handle linting
        if enable_auto_lint:
            # Generate a random temporary file path
            suffix = os.path.splitext(file_name)[1]
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tfile:
                original_file_backup_path = tfile.name

            with open(original_file_backup_path, 'w') as f:
                f.writelines(lines)

            file_name_abs = os.path.abspath(file_name)
            lint_error, first_error_line = _lint_file(file_name_abs)

            # Select the errors caused by the modification
            def extract_last_part(line):
                parts = line.split(':')
                if len(parts) > 1:
                    return parts[-1].strip()
                return line.strip()

            def subtract_strings(str1, str2) -> str:
                lines1 = str1.splitlines()
                lines2 = str2.splitlines()

                last_parts1 = [extract_last_part(line) for line in lines1]

                remaining_lines = [
                    line
                    for line in lines2
                    if extract_last_part(line) not in last_parts1
                ]

                result = '\n'.join(remaining_lines)
                return result

            if original_lint_error and lint_error:
                lint_error = subtract_strings(original_lint_error, lint_error)
                if lint_error == '':
                    lint_error = None
                    first_error_line = None

            if lint_error is not None:
                if first_error_line is not None:
                    show_line = int(first_error_line)
                elif is_append:
                    # original end-of-file
                    show_line = len(lines)
                # insert OR edit WILL provide meaningful line numbers
                elif start is not None and end is not None:
                    show_line = int((start + end) / 2)
                else:
                    raise ValueError('Invalid state. This should never happen.')

                ret_str += LINTER_ERROR_MSG
                ret_str += lint_error + '\n'

                editor_lines = n_added_lines + 20
                sep = '-' * 49 + '\n'
                ret_str += (
                    f'[This is how your edit would have looked if applied]\n{sep}'
                )
                ret_str += (
                    _print_window(file_name, show_line, editor_lines, return_str=True)
                    + '\n'
                )
                ret_str += f'{sep}\n'

                ret_str += '[This is the original code before your edit]\n'
                ret_str += sep
                ret_str += (
                    _print_window(
                        original_file_backup_path,
                        show_line,
                        editor_lines,
                        return_str=True,
                    )
                    + '\n'
                )
                ret_str += sep
                ret_str += (
                    'Your changes have NOT been applied. Please fix your edit command and try again.\n'
                    'You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n'
                    'DO NOT re-run the same failed edit command. Running it again will lead to the same error.'
                )

                # recover the original file
                with open(original_file_backup_path) as fin, open(
                    file_name, 'w'
                ) as fout:
                    fout.write(fin.read())

                # The backup itself is removed by the `finally` clause below.
                return ret_str

    except FileNotFoundError as e:
        ret_str += f'File not found: {e}\n'
    except PermissionError as e:
        ret_str += f'Permission error during file operation: {str(e)}\n'
    except IOError as e:
        ret_str += f'An error occurred while handling the file: {e}\n'
    except ValueError as e:
        ret_str += f'Invalid input: {e}\n'
    except Exception as e:
        print(f'An unexpected error occurred: {e}')
        raise e
    finally:
        # Never leave scratch files behind, whatever path led out of the
        # `try`: an early return on a bad line number, a handled I/O error,
        # an unexpected exception, or plain success. After a successful
        # os.replace the temp path no longer exists, so this is a no-op there.
        for leftover in (temp_file_path, original_file_backup_path):
            if leftover and os.path.exists(leftover):
                os.remove(leftover)

    # Update the file information and print the updated content
    with open(file_name, 'r', encoding='utf-8') as file:
        n_total_lines = max(1, len(file.readlines()))
    if first_error_line is not None and int(first_error_line) > 0:
        CURRENT_LINE = first_error_line
    else:
        if is_append:
            CURRENT_LINE = max(1, len(lines))  # end of original file
        else:
            CURRENT_LINE = start or n_total_lines or 1
    ret_str += f'[File: {os.path.abspath(file_name)} ({n_total_lines} lines total after edit)]\n'
    CURRENT_FILE = file_name
    ret_str += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True) + '\n'
    ret_str += MSG_FILE_UPDATED.format(line_number=CURRENT_LINE)
    return ret_str
|
||||
|
||||
|
||||
def edit_file_by_replace(file_name: str, to_replace: str, new_content: str) -> None:
    """Edit an existing file. This will search for non-empty `to_replace` in the given file and replace it with non-empty `new_content`.
    `to_replace` and `new_content` must be different! Split large edits into multiple smaller edits if necessary!
    Use `append_file` method for writing after `create_file`!

    Every *to_replace* must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc.

    Include enough lines to make code in `to_replace` unique. `to_replace` should NOT be empty.

    For example, given a file "/workspace/example.txt" with the following content:
    ```
    line 1
    line 2
    line 2
    line 3
    ```

    EDITING: If you want to replace the second occurrence of "line 2", you can make `to_replace` unique:

    edit_file_by_replace(
        '/workspace/example.txt',
        to_replace='line 2\nline 3',
        new_content='new line\nline 3',
    )

    This will replace only the second "line 2" with "new line". The first "line 2" will remain unchanged.

    The resulting file will be:
    ```
    line 1
    line 2
    new line
    line 3
    ```

    REMOVAL: If you want to remove "line 2" and "line 3", you can set `new_content` to an empty string:

    edit_file_by_replace(
        '/workspace/example.txt',
        to_replace='line 2\nline 3',
        new_content='',
    )

    Args:
        file_name: str: The name of the file to edit.
        to_replace: str: The content to search for and replace.
        new_content: str: The new content to replace the old content with.
    """
    # NOTE(review): the docstring above is kept verbatim; it appears to be
    # surfaced to the agent as tool documentation - confirm before rewording.
    # FIXME: support replacing *all* occurrences
    if to_replace is None or not to_replace.strip():
        _output_error('`to_replace` must not be empty.')
        return

    if new_content == to_replace:
        _output_error('`to_replace` and `new_content` must be different.')
        return

    if not os.path.isfile(file_name):
        _output_error(f'File {file_name} not found.')
        return None

    with open(file_name, 'r') as file:
        file_content = file.read()

    # An ambiguous target is rejected rather than guessed at.
    if file_content.count(to_replace) > 1:
        _output_error(
            '`to_replace` appears more than once, please include enough lines to make code in `to_replace` unique.'
        )
        return

    # Number of whole lines the edit will cover, whichever search succeeds.
    span = len(to_replace.splitlines())

    exact_pos = file_content.find(to_replace)
    if exact_pos == -1:
        # Fall back to a whitespace-insensitive search: strip every
        # whitespace character except newlines from both texts, so line
        # numbers in the squeezed text still match the real file.
        squeeze = re.compile(r'[^\S\n]+')
        squeezed_target = squeeze.sub('', to_replace)
        squeezed_file = squeeze.sub('', file_content)

        fuzzy_pos = squeezed_file.find(squeezed_target)
        if fuzzy_pos == -1:
            print(
                f'[No exact match found in {file_name} for\n```\n{to_replace}\n```\n]'
            )
            return
        start_line_number = squeezed_file.count('\n', 0, fuzzy_pos) + 1
    else:
        # Convert the character offset into a 1-based line number.
        start_line_number = file_content.count('\n', 0, exact_pos) + 1

    end_line_number = start_line_number + span - 1

    ret_str = _edit_file_impl(
        file_name,
        start=start_line_number,
        end=end_line_number,
        content=new_content,
        is_insert=False,
    )
    # lint_error = bool(LINTER_ERROR_MSG in ret_str)
    # TODO: automatically tries to fix linter error (maybe involve some static analysis tools on the location near the edit to figure out indentation)
    if ret_str is not None:
        print(ret_str)
|
||||
|
||||
|
||||
def insert_content_at_line(file_name: str, line_number: int, content: str) -> None:
    """Insert content at the given line number in a file.
    This will NOT modify the content of the lines before OR after the given line number.

    For example, if the file has the following content:
    ```
    line 1
    line 2
    line 3
    ```
    and you call `insert_content_at_line('file.txt', 2, 'new line')`, the file will be updated to:
    ```
    line 1
    new line
    line 2
    line 3
    ```

    Args:
        file_name: str: The name of the file to edit.
        line_number: int: The line number (starting from 1) to insert the content after.
        content: str: The content to insert.
    """
    # The shared implementation takes a range; an insertion uses the same
    # line for both ends and is flagged with is_insert.
    outcome = _edit_file_impl(
        file_name,
        start=line_number,
        end=line_number,
        content=content,
        is_insert=True,
        is_append=False,
    )
    if outcome is None:
        # Nothing to show when the implementation returned no message.
        return
    print(outcome)
|
||||
|
||||
|
||||
def append_file(file_name: str, content: str) -> None:
    """Append content to the given file.
    It appends text `content` to the end of the specified file, ideal after a `create_file`!

    Args:
        file_name: str: The name of the file to edit.
        content: str: The content to insert.
    """
    # Line bounds are irrelevant for an append, so both are left unset.
    outcome = _edit_file_impl(
        file_name,
        start=None,
        end=None,
        content=content,
        is_insert=False,
        is_append=True,
    )
    if outcome is None:
        # Nothing to show when the implementation returned no message.
        return
    print(outcome)
|
||||
|
||||
|
||||
def search_dir(search_term: str, dir_path: str = './') -> None:
|
||||
"""Searches for search_term in all files in dir. If dir is not provided, searches in the current directory.
|
||||
|
||||
@ -904,10 +375,6 @@ __all__ = [
|
||||
'goto_line',
|
||||
'scroll_down',
|
||||
'scroll_up',
|
||||
'create_file',
|
||||
'edit_file_by_replace',
|
||||
'insert_content_at_line',
|
||||
'append_file',
|
||||
'search_dir',
|
||||
'search_file',
|
||||
'find_file',
|
||||
|
||||
@ -15,6 +15,7 @@ from openhands.events.action import (
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
@ -350,6 +351,8 @@ class RemoteRuntime(Runtime):
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
if action.timeout is None:
|
||||
action.timeout = self.config.sandbox.timeout
|
||||
if isinstance(action, FileEditAction):
|
||||
return self.edit(action)
|
||||
with self.action_semaphore:
|
||||
if not action.runnable:
|
||||
return NullObservation('')
|
||||
|
||||
@ -28,6 +28,7 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
|
||||
from openhands.runtime.utils.edit import FileEditRuntimeMixin
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@ -42,7 +43,7 @@ def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
|
||||
return ret
|
||||
|
||||
|
||||
class Runtime:
|
||||
class Runtime(FileEditRuntimeMixin):
|
||||
"""The runtime is how the agent interacts with the external environment.
|
||||
This includes a bash sandbox, a browser, and filesystem interactions.
|
||||
|
||||
@ -78,6 +79,9 @@ class Runtime:
|
||||
if env_vars is not None:
|
||||
self.initial_env_vars.update(env_vars)
|
||||
|
||||
# Load mixins
|
||||
FileEditRuntimeMixin.__init__(self)
|
||||
|
||||
def setup_initial_env(self) -> None:
|
||||
if self.attach_to_existing:
|
||||
return
|
||||
|
||||
350
openhands/runtime/utils/edit.py
Normal file
350
openhands/runtime/utils/edit.py
Normal file
@ -0,0 +1,350 @@
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FatalErrorObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.linter import DefaultLinter
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
|
||||
from openhands.utils.diff import get_diff
|
||||
|
||||
# System prompt for the draft-editor LLM: it is asked to merge a (possibly
# partial or badly indented) draft into the old text and to return the full
# result inside a ``` fenced block, which `_extract_code` then pulls out.
SYS_MSG = """Your job is to produce a new version of the file based on the old version and the
provided draft of the new version. The provided draft may be incomplete (it may skip lines) and/or incorrectly indented. You should try to apply the changes present in the draft to the old version, and output a new version of the file.
NOTE:
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
- You should output the new version of the file by wrapping the new version of the file content in a ``` block.
- If there's no explicit comment to remove the existing code, we should keep them and append the new code to the end of the file.
- If there's placeholder comments like `# no changes before` or `# no changes here`, we should replace these comments with the original code near the placeholder comments.
"""

# User prompt template; filled via `str.format` with `old_contents` (the text
# being edited) and `draft_changes` (the draft to apply).
USER_MSG = """
HERE IS THE OLD VERSION OF THE FILE:
```
{old_contents}
```

HERE IS THE DRAFT OF THE NEW VERSION OF THE FILE:
```
{draft_changes}
```

GIVE ME THE NEW VERSION OF THE FILE.
IMPORTANT:
- There should be NO placeholder comments like `# no changes before` or `# no changes here`. They should be replaced with the original code near the placeholder comments.
- The output file should be COMPLETE and CORRECTLY INDENTED. Do not omit any lines, and do not change any lines that are not part of the changes.
""".strip()
|
||||
|
||||
|
||||
def _extract_code(string):
|
||||
pattern = r'```(?:\w*\n)?(.*?)```'
|
||||
matches = re.findall(pattern, string, re.DOTALL)
|
||||
if not matches:
|
||||
return None
|
||||
return matches[0]
|
||||
|
||||
|
||||
def get_new_file_contents(
    llm: LLM, old_contents: str, draft_changes: str, num_retries: int = 3
) -> str | None:
    """Ask the LLM to merge ``draft_changes`` into ``old_contents``.

    The request is repeated up to ``num_retries`` times until a reply contains
    a ``` fenced block.

    Args:
        llm: The LLM used to produce the merged text.
        old_contents: The text currently being edited.
        draft_changes: The draft describing the desired new version.
        num_retries: Maximum number of completion requests to make.

    Returns:
        The content of the first fenced block in the reply, or None if no
        attempt produced one.
    """
    attempts_left = num_retries
    while attempts_left > 0:
        user_prompt = USER_MSG.format(
            old_contents=old_contents, draft_changes=draft_changes
        )
        reply = llm.completion(
            messages=[
                {'role': 'system', 'content': SYS_MSG},
                {'role': 'user', 'content': user_prompt},
            ]
        )
        reply_text = reply['choices'][0]['message']['content']
        extracted = _extract_code(reply_text)
        if extracted is not None:
            return extracted
        attempts_left -= 1
    return None
|
||||
|
||||
|
||||
class FileEditRuntimeInterface(ABC):
    """Abstract read/write surface that the file-edit mixin builds on."""

    # Application configuration; expected to be set by the concrete class.
    config: AppConfig

    @abstractmethod
    def read(self, action: FileReadAction) -> Observation:
        """Carry out a file read action and return the resulting observation."""

    @abstractmethod
    def write(self, action: FileWriteAction) -> Observation:
        """Carry out a file write action and return the resulting observation."""
|
||||
|
||||
|
||||
class FileEditRuntimeMixin(FileEditRuntimeInterface):
    """Mixin that implements `edit` on top of a runtime's `read` and `write`.

    An edit is expressed as a draft plus an optional line range. The selected
    range of the existing file and the draft are handed to a dedicated "draft
    editor" LLM, whose output replaces that range. New files are written
    directly and `start == -1` appends without involving the LLM.
    """

    # Most LLMs have output token limit of 4k tokens.
    # This restricts the number of lines we can edit to avoid exceeding the token limit.
    MAX_LINES_TO_EDIT = 300

    def __init__(self, *args, **kwargs):
        """Create the draft editor LLM from the configured LLM settings.

        Reads `self.config`, so the host class must have set it beforehand.
        """
        super().__init__(*args, **kwargs)

        llm_config = self.config.get_llm_config()

        # Without a dedicated draft editor config, reuse a copy of the main one.
        if llm_config.draft_editor is None:
            llm_config.draft_editor = copy.deepcopy(llm_config)

        # manually set the model name for the draft editor LLM to distinguish token costs
        llm_metrics = Metrics(
            model_name='draft_editor:' + llm_config.draft_editor.model
        )
        if llm_config.draft_editor.caching_prompt:
            logger.info(
                'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
                'Automatically setting caching_prompt=false.'
            )
            llm_config.draft_editor.caching_prompt = False

        self.draft_editor_llm = LLM(llm_config.draft_editor, metrics=llm_metrics)
        logger.info(
            f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
        )

    def _validate_range(
        self, start: int, end: int, total_lines: int
    ) -> Observation | None:
        """Check an edit range against the file length.

        Args:
            start: 1-indexed first line, or -1 to append to the end of the file.
            end: 1-indexed inclusive last line, or -1 for "till the end".
            total_lines: Number of lines in the file being edited.

        Returns:
            An ErrorObservation describing the problem, or None if the range
            is acceptable.
        """
        # start and end are 1-indexed and inclusive
        if (
            (start < 1 and start != -1)
            or start > total_lines
            or (start > end and end != -1 and start != -1)
        ):
            return ErrorObservation(
                f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. start must be >= 1 and <={total_lines} (total lines of the edited file), start <= end, or start == -1 (append to the end of the file).'
            )
        if (
            (end < 1 and end != -1)
            or end > total_lines
            or (end < start and start != -1 and end != -1)
        ):
            return ErrorObservation(
                f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. end must be >= 1 and <= {total_lines} (total lines of the edited file), end >= start, or end == -1 (to edit till the end of the file).'
            )
        return None

    def _get_lint_error(
        self,
        suffix: str,
        old_content: str,
        new_content: str,
        filepath: str,
        diff: str,
    ) -> ErrorObservation | None:
        """Lint the change from `old_content` to `new_content`.

        Both versions are written to temporary files carrying `suffix` and
        passed to the linter's diff check.

        Args:
            suffix: File extension used for the temporary files.
            old_content: File content before the edit.
            new_content: File content after the edit.
            filepath: Path of the edited file, used in messages only.
            diff: Diff of the edit, embedded in the error observation.

        Returns:
            An ErrorObservation showing the attempted change and up to five
            lint errors, or None when the linter reports nothing.
        """
        linter = DefaultLinter()
        # Copy the original file to a temporary file (with the same ext) and lint it
        with tempfile.NamedTemporaryFile(
            suffix=suffix, mode='w+', encoding='utf-8'
        ) as original_file_copy, tempfile.NamedTemporaryFile(
            suffix=suffix, mode='w+', encoding='utf-8'
        ) as updated_file_copy:
            # Lint the original file
            original_file_copy.write(old_content)
            original_file_copy.flush()

            # Lint the updated file
            updated_file_copy.write(new_content)
            updated_file_copy.flush()

            updated_lint_error = linter.lint_file_diff(
                original_file_copy.name, updated_file_copy.name
            )

            if len(updated_lint_error) > 0:
                # Built only to render the attempted diff; it is not returned.
                _obs = FileEditObservation(
                    content=diff,
                    path=filepath,
                    prev_exist=True,
                    old_content=old_content,
                    new_content=new_content,
                )
                error_message = (
                    (
                        f'\n[Linting failed for edited file {filepath}. {len(updated_lint_error)} lint errors found.]\n'
                        '[begin attempted changes]\n'
                        f'{_obs.visualize_diff(change_applied=False)}\n'
                        '[end attempted changes]\n'
                    )
                    + '-' * 40
                    + '\n'
                )
                error_message += '-' * 20 + 'First 5 lint errors' + '-' * 20 + '\n'
                for i, lint_error in enumerate(updated_lint_error[:5]):
                    error_message += f'[begin lint error {i}]\n'
                    error_message += lint_error.visualize().strip() + '\n'
                    error_message += f'[end lint error {i}]\n'
                    error_message += '-' * 40 + '\n'
                return ErrorObservation(error_message)
        return None

    def edit(self, action: FileEditAction) -> Observation:
        """Apply a file edit action.

        Behaviour by case:
        - The file does not exist: the draft is written as the new file.
        - `action.start == -1`: the draft is appended to the file.
        - Otherwise: lines `start..end` (1-indexed, inclusive; `end == -1`
          means up to the last line) are rewritten by the draft editor LLM.

        When auto-lint is enabled in the sandbox config, an edit that
        introduces lint errors is rejected and nothing is written.

        Returns:
            A FileEditObservation on success, otherwise an ErrorObservation
            (or FatalErrorObservation for unexpected read/write results).
        """
        obs = self.read(FileReadAction(path=action.path))
        if (
            isinstance(obs, ErrorObservation)
            and 'File not found'.lower() in obs.content.lower()
        ):
            logger.debug(
                f'Agent attempted to edit a file that does not exist. Creating the file. Error msg: {obs.content}'
            )
            # directly write the new content
            # NOTE(review): the stripped draft is written, but the observation
            # below reports the unstripped draft - confirm this is intended.
            obs = self.write(
                FileWriteAction(path=action.path, content=action.content.strip())
            )
            if isinstance(obs, ErrorObservation):
                return obs
            if not isinstance(obs, FileWriteObservation):
                return FatalErrorObservation(
                    f'Fatal Runtime in editing: Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
                )
            return FileEditObservation(
                content=get_diff('', action.content, action.path),
                path=action.path,
                prev_exist=False,
                old_content='',
                new_content=action.content,
            )
        if not isinstance(obs, FileReadObservation):
            return FatalErrorObservation(
                f'Fatal Runtime in editing: Expected FileReadObservation, got {type(obs)}: {str(obs)}'
            )

        original_file_content = obs.content
        old_file_lines = original_file_content.split('\n')
        # NOTE: start and end are 1-indexed
        start = action.start
        end = action.end
        # validate the range
        error = self._validate_range(start, end, len(old_file_lines))
        if error is not None:
            return error

        # append to the end of the file
        if start == -1:
            updated_content = '\n'.join(old_file_lines + action.content.split('\n'))
            diff = get_diff(original_file_content, updated_content, action.path)
            # Lint the updated content
            if self.config.sandbox.enable_auto_lint:
                suffix = os.path.splitext(action.path)[1]

                error_obs = self._get_lint_error(
                    suffix,
                    original_file_content,
                    updated_content,
                    action.path,
                    diff,
                )
                if error_obs is not None:
                    return error_obs

            obs = self.write(FileWriteAction(path=action.path, content=updated_content))
            # Do not claim success if the write itself failed.
            if isinstance(obs, ErrorObservation):
                return obs
            return FileEditObservation(
                content=diff,
                path=action.path,
                prev_exist=True,
                old_content=original_file_content,
                new_content=updated_content,
            )

        # Get the 0-indexed start and end
        start_idx = start - 1
        if end != -1:
            # remove 1 to make it 0-indexed
            # then add 1 since the `end` is inclusive
            end_idx = end - 1 + 1
        else:
            # end == -1 means the user wants to edit till the end of the file
            end_idx = len(old_file_lines)

        # Get the range of lines to edit - reject if too long
        length_of_range = end_idx - start_idx
        if length_of_range > self.MAX_LINES_TO_EDIT + 1:
            error_msg = (
                f'[Edit error: The range of lines to edit is too long.]\n'
                f'[The maximum number of lines allowed to edit at once is {self.MAX_LINES_TO_EDIT}. '
                f'Got (L{start_idx + 1}-L{end_idx}) {length_of_range} lines.]\n'  # [start_idx, end_idx), so no need to + 1
            )
            # search for relevant ranges to hint the agent
            topk_chunks: list[Chunk] = get_top_k_chunk_matches(
                text=original_file_content,
                query=action.content,  # edit draft as query
                k=3,
                max_chunk_size=20,  # lines
            )
            error_msg += (
                'Here are some snippets that maybe relevant to the provided edit.\n'
            )
            for i, chunk in enumerate(topk_chunks):
                error_msg += f'[begin relevant snippet {i+1}. Line range: L{chunk.line_range[0]}-L{chunk.line_range[1]}. Similarity: {chunk.normalized_lcs}]\n'
                error_msg += f'[Browse around it via `open_file("{action.path}", {(chunk.line_range[0] + chunk.line_range[1]) // 2})`]\n'
                error_msg += chunk.visualize() + '\n'
                error_msg += f'[end relevant snippet {i+1}]\n'
                error_msg += '-' * 40 + '\n'

            error_msg += 'Consider using `open_file` to explore around the relevant snippets if needed.\n'
            error_msg += f'**IMPORTANT**: Please REDUCE the range of edits to less than {self.MAX_LINES_TO_EDIT} lines by setting `start` and `end` in the edit action (e.g. `<file_edit path="{action.path}" start=[PUT LINE NUMBER HERE] end=[PUT LINE NUMBER HERE] />`). '

            return ErrorObservation(error_msg)

        content_to_edit = '\n'.join(old_file_lines[start_idx:end_idx])
        self.draft_editor_llm.reset()
        _edited_content = get_new_file_contents(
            self.draft_editor_llm, content_to_edit, action.content
        )
        if _edited_content is None:
            ret_err = ErrorObservation(
                'Failed to get new file contents. '
                'Please try to reduce the number of edits and try again.'
            )
            ret_err.llm_metrics = self.draft_editor_llm.metrics
            return ret_err

        # piece the updated content with the unchanged content
        updated_lines = (
            old_file_lines[:start_idx]
            + _edited_content.split('\n')
            + old_file_lines[end_idx:]
        )
        updated_content = '\n'.join(updated_lines)
        diff = get_diff(original_file_content, updated_content, action.path)

        # Lint the updated content
        if self.config.sandbox.enable_auto_lint:
            suffix = os.path.splitext(action.path)[1]
            error_obs = self._get_lint_error(
                suffix, original_file_content, updated_content, action.path, diff
            )
            if error_obs is not None:
                error_obs.llm_metrics = self.draft_editor_llm.metrics
                return error_obs

        obs = self.write(FileWriteAction(path=action.path, content=updated_content))
        # Do not claim success if the write itself failed; still report the
        # LLM usage that was spent producing the edit.
        if isinstance(obs, ErrorObservation):
            obs.llm_metrics = self.draft_editor_llm.metrics
            return obs
        ret_obs = FileEditObservation(
            content=diff,
            path=action.path,
            prev_exist=True,
            old_content=original_file_content,
            new_content=updated_content,
        )
        ret_obs.llm_metrics = self.draft_editor_llm.metrics
        return ret_obs
|
||||
12
openhands/runtime/utils/tenacity_stop.py
Normal file
12
openhands/runtime/utils/tenacity_stop.py
Normal file
@ -0,0 +1,12 @@
|
||||
|
||||
|
||||
from tenacity import RetryCallState
|
||||
from tenacity.stop import stop_base
|
||||
from openhands.runtime.utils.shutdown_listener import should_exit
|
||||
|
||||
|
||||
class stop_if_should_exit(stop_base):
    """Tenacity stop condition driven by the `should_exit` flag."""

    def __call__(self, retry_state: 'RetryCallState') -> bool:
        # The retry state is not consulted; only the exit flag decides.
        return should_exit()
|
||||
@ -224,13 +224,23 @@ class AgentSession:
|
||||
'Runtime must be initialized before the agent controller'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
msg = (
|
||||
'\n--------------------------------- OpenHands Configuration ---------------------------------\n'
|
||||
f'LLM: {agent.llm.config.model}\n'
|
||||
f'Base URL: {agent.llm.config.base_url}\n'
|
||||
)
|
||||
if agent.llm.config.draft_editor:
|
||||
msg += (
|
||||
f'Draft editor LLM (for file editing): {agent.llm.config.draft_editor.model}\n'
|
||||
f'Draft editor LLM (for file editing) Base URL: {agent.llm.config.draft_editor.base_url}\n'
|
||||
)
|
||||
msg += (
|
||||
f'Agent: {agent.name}\n'
|
||||
f'Runtime: {self.runtime.__class__.__name__}\n'
|
||||
f'Plugins: {agent.sandbox_plugins}\n'
|
||||
'-------------------------------------------------------------------------------------------'
|
||||
)
|
||||
logger.info(msg)
|
||||
|
||||
self.controller = AgentController(
|
||||
sid=self.sid,
|
||||
|
||||
97
openhands/utils/chunk_localizer.py
Normal file
97
openhands/utils/chunk_localizer.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""Chunk localizer to help localize the most relevant chunks in a file.
|
||||
|
||||
This is primarily used to localize the most relevant chunks in a file
|
||||
for a given query (e.g. edit draft produced by the agent).
|
||||
"""
|
||||
|
||||
import pylcs
|
||||
from pydantic import BaseModel
|
||||
from tree_sitter_languages import get_parser
|
||||
|
||||
|
||||
class Chunk(BaseModel):
    """A block of text together with the line span it covers."""

    text: str
    # (start_line, end_line), 1-indexed, inclusive on both ends
    line_range: tuple[int, int]
    # similarity score for this chunk; stays None when no score was computed
    normalized_lcs: float | None = None

    def visualize(self) -> str:
        """Render the chunk with each line prefixed by ``<line number>|``.

        Returns:
            One ``'<n>|<line>'`` row per line of ``text``, each terminated by a
            newline, numbered from ``line_range[0]``.

        Raises:
            AssertionError: if the number of lines in ``text`` does not match
                the span described by ``line_range``.
        """
        first_line, last_line = self.line_range
        rows = self.text.split('\n')
        assert len(rows) == last_line - first_line + 1
        return ''.join(
            f'{lineno}|{row}\n' for lineno, row in enumerate(rows, start=first_line)
        )
|
||||
|
||||
|
||||
def _create_chunks_from_raw_string(content: str, size: int):
    """Cut ``content`` into consecutive chunks of at most ``size`` lines.

    The text is split on newline characters; every chunk records the
    1-indexed, inclusive line span it was taken from. The final chunk may be
    shorter than ``size``.
    """
    all_lines = content.split('\n')
    windows = (
        (offset, all_lines[offset : offset + size])
        for offset in range(0, len(all_lines), size)
    )
    return [
        Chunk(
            text='\n'.join(window),
            # offset is 0-based; line numbers are 1-based and inclusive
            line_range=(offset + 1, offset + len(window)),
        )
        for offset, window in windows
    ]
|
||||
|
||||
|
||||
def create_chunks(
    text: str, size: int = 100, language: str | None = None
) -> list[Chunk]:
    """Split ``text`` into chunks of at most ``size`` lines.

    Args:
        text: The text to split.
        size: Maximum number of lines per chunk.
        language: Optional language name used to look up a tree-sitter parser.

    Returns:
        Line-based chunks when no parser is available for ``language``
        (including when ``language`` is None).

    Raises:
        NotImplementedError: if a parser was found, because parser-based
            chunking is not implemented yet.
    """
    parser = None
    if language is not None:
        try:
            parser = get_parser(language)
        except AttributeError:
            # Treated as "language not supported": fall back to raw string.
            parser = None

    if parser is not None:
        # TODO: implement tree-sitter chunking
        # return _create_chunks_from_tree_sitter(parser.parse(bytes(text, 'utf-8')), max_chunk_lines=size)
        raise NotImplementedError('Tree-sitter chunking not implemented yet.')

    # fallback to raw string
    return _create_chunks_from_raw_string(text, size)
|
||||
|
||||
|
||||
def normalized_lcs(chunk: str, query: str) -> float:
    """Score how much of ``chunk`` is covered by ``query`` (e.g. an edit draft).

    The score is the Longest Common Subsequence (LCS) length of the two
    strings divided by the length of ``chunk``, so it measures coverage of
    the chunk rather than of the query.

    Returns:
        ``0.0`` for an empty ``chunk``; otherwise LCS length / ``len(chunk)``.
    """
    chunk_length = len(chunk)
    if not chunk_length:
        # Nothing to cover, and dividing by zero must be avoided.
        return 0.0
    common_length = pylcs.lcs_sequence_length(chunk, query)
    return common_length / chunk_length
|
||||
|
||||
|
||||
def get_top_k_chunk_matches(
    text: str, query: str, k: int = 3, max_chunk_size: int = 100
) -> list[Chunk]:
    """Get the top k chunks in the text that match the query.

    The query could be a string of draft code edits.

    Args:
        text: The text to search for the query.
        query: The query to search for in the text.
        k: The number of top chunks to return.
        max_chunk_size: The maximum number of lines in a chunk.

    Returns:
        At most ``k`` chunks, each carrying its ``normalized_lcs`` score,
        ordered from highest to lowest score. Empty when ``k`` is not positive.
    """
    if k <= 0:
        # A non-positive k asks for no results. Without this guard a negative
        # k would act as a slice end and silently return all but the last
        # |k| chunks.
        return []

    scored_chunks: list[Chunk] = [
        Chunk(
            text=chunk.text,
            line_range=chunk.line_range,
            normalized_lcs=normalized_lcs(chunk.text, query),
        )
        for chunk in create_chunks(text, max_chunk_size)
    ]
    # Stable sort: chunks with equal scores keep their original file order.
    scored_chunks.sort(
        key=lambda scored: scored.normalized_lcs,  # type: ignore
        reverse=True,
    )
    return scored_chunks[:k]
|
||||
36
poetry.lock
generated
36
poetry.lock
generated
@ -6497,6 +6497,20 @@ files = [
|
||||
[package.dependencies]
|
||||
pyasn1 = ">=0.4.6,<0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pybind11"
|
||||
version = "2.13.6"
|
||||
description = "Seamless operability between C++11 and Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5"},
|
||||
{file = "pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
global = ["pybind11-global (==2.13.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "pycodestyle"
|
||||
version = "2.12.1"
|
||||
@ -6775,6 +6789,26 @@ files = [
|
||||
{file = "pylatexenc-2.10.tar.gz", hash = "sha256:3dd8fd84eb46dc30bee1e23eaab8d8fb5a7f507347b23e5f38ad9675c84f40d3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pylcs"
|
||||
version = "0.1.1"
|
||||
description = "super fast cpp implementation of longest common subsequence"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pylcs-0.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b8adea6b41dff27332c967533ec3c42a5e94171be778d6f01f0c5cee82e7604"},
|
||||
{file = "pylcs-0.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:9ff06e037c54056cb67d6ef5ad946c0360afeff7d43be67ce09e55201ecc15cc"},
|
||||
{file = "pylcs-0.1.1-cp35-cp35m-win_amd64.whl", hash = "sha256:d2ebf340aa180d841939d9ec1168dfd072992dda1d48148ceb07b65b1ab62ffa"},
|
||||
{file = "pylcs-0.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b6c43b63e20048f8fec7e122fbc08c238940a0ee5302bf84a70db22c7f8cc836"},
|
||||
{file = "pylcs-0.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:db52d55cfdf813af974bcc164aedbd29274da83086877bf05778aa7fbf777f7f"},
|
||||
{file = "pylcs-0.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:954495f1c164ccb722b835e7028783f8a38d85ed5f6ff7b9d50143896c6cff9b"},
|
||||
{file = "pylcs-0.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:0f4c82fad8c0429abef9e98fb98904459c4f5f9fb9b6ce20e0df0841a6a48a54"},
|
||||
{file = "pylcs-0.1.1.tar.gz", hash = "sha256:632c69235d77cda0ba524d82796878801d2f46131fc59e730c98767fc4ce1307"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pybind11 = ">=2.2"
|
||||
|
||||
[[package]]
|
||||
name = "pyparsing"
|
||||
version = "3.2.0"
|
||||
@ -10031,4 +10065,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "7fc51225767e3a98147f7b0dacdce4486a1afd83dc3273f06fd9f6cdc35d1860"
|
||||
content-hash = "6198aa7d5c9d6e172d9f6cb3aff58006ab0e38fd1ce34be8a2d7696bf2a28fb9"
|
||||
|
||||
@ -52,6 +52,7 @@ python-pptx = "*"
|
||||
pylatexenc = "*"
|
||||
tornado = "*"
|
||||
python-dotenv = "*"
|
||||
pylcs = "^0.1.1"
|
||||
whatthepatch = "^1.0.6"
|
||||
protobuf = "^4.21.6,<5.0.0" # chromadb currently fails on 5.0+
|
||||
opentelemetry-api = "1.25.0"
|
||||
@ -88,6 +89,7 @@ reportlab = "*"
|
||||
[tool.coverage.run]
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@ -118,6 +120,7 @@ ignore = ["D1"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
417
tests/runtime/test_edit.py
Normal file
417
tests/runtime/test_edit.py
Normal file
@ -0,0 +1,417 @@
|
||||
"""Edit-related tests for the EventStreamRuntime."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from conftest import (
|
||||
TEST_IN_CI,
|
||||
_close_test_runtime,
|
||||
_load_runtime,
|
||||
)
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import FileEditAction, FileReadAction
|
||||
from openhands.events.observation import FileEditObservation
|
||||
from openhands.utils.diff import get_diff
|
||||
|
||||
ORGINAL = """from flask import Flask
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
numbers = list(range(1, 11))
|
||||
return str(numbers)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(port=5000)
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    TEST_IN_CI != 'True',
    reason='This test requires LLM to run.',
)
def test_edit_from_scratch(temp_dir, box_class, run_as_openhands):
    """Run a FileEditAction with ``start=-1`` and read the file back.

    Checks that the edit yields a FileEditObservation and that the file
    content read afterwards equals ``ORGINAL`` (ignoring surrounding
    whitespace).
    """
    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
    target_path = os.path.join('/workspace', 'app.py')
    try:
        edit_action = FileEditAction(content=ORGINAL, start=-1, path=target_path)
        logger.info(edit_action, extra={'msg_type': 'ACTION'})
        edit_obs = runtime.run_action(edit_action)
        logger.info(edit_obs, extra={'msg_type': 'OBSERVATION'})

        assert isinstance(
            edit_obs, FileEditObservation
        ), 'The observation should be a FileEditObservation.'

        read_obs = runtime.run_action(FileReadAction(path=target_path))
        logger.info(read_obs, extra={'msg_type': 'OBSERVATION'})
        assert read_obs.content.strip() == ORGINAL.strip()

    finally:
        # Always release the runtime, even when an assertion fails.
        _close_test_runtime(runtime)
|
||||
|
||||
|
||||
EDIT = """# above stays the same
|
||||
@app.route('/')
|
||||
def index():
|
||||
numbers = list(range(1, 11))
|
||||
return '<table>' + ''.join([f'<tr><td>{i}</td></tr>' for i in numbers]) + '</table>'
|
||||
# below stays the same
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    TEST_IN_CI != 'True',
    reason='This test requires LLM to run.',
)
def test_edit(temp_dir, box_class, run_as_openhands):
    """Edit a file twice and check the diff reported for the second edit.

    The first FileEditAction carries ``ORGINAL`` and is verified by reading
    the file back; the second carries ``EDIT`` and its observation content
    must equal the expected unified diff.
    """
    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
    target_path = os.path.join('/workspace', 'app.py')
    expected_diff = (
        '--- /workspace/app.py\n'
        '+++ /workspace/app.py\n'
        '@@ -4,7 +4,7 @@\n'
        " @app.route('/')\n"
        ' def index():\n'
        ' numbers = list(range(1, 11))\n'
        '- return str(numbers)\n'
        "+ return '<table>' + ''.join([f'<tr><td>{i}</td></tr>' for i in numbers]) + '</table>'\n"
        '\n'
        " if __name__ == '__main__':\n"
        ' app.run(port=5000)\n'
    ).strip()
    try:
        initial_action = FileEditAction(content=ORGINAL, path=target_path)
        logger.info(initial_action, extra={'msg_type': 'ACTION'})
        initial_obs = runtime.run_action(initial_action)
        logger.info(initial_obs, extra={'msg_type': 'OBSERVATION'})

        assert isinstance(
            initial_obs, FileEditObservation
        ), 'The observation should be a FileEditObservation.'

        read_obs = runtime.run_action(FileReadAction(path=target_path))
        logger.info(read_obs, extra={'msg_type': 'OBSERVATION'})
        assert read_obs.content.strip() == ORGINAL.strip()

        edit_obs = runtime.run_action(FileEditAction(content=EDIT, path=target_path))
        logger.info(edit_obs, extra={'msg_type': 'OBSERVATION'})
        assert edit_obs.content.strip() == expected_diff
    finally:
        # Always release the runtime, even when an assertion fails.
        _close_test_runtime(runtime)
|
||||
|
||||
|
||||
ORIGINAL_LONG = '\n'.join([f'This is line {i}' for i in range(1, 1000)])
|
||||
EDIT_LONG = """
|
||||
This is line 100 + 10
|
||||
This is line 101 + 10
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    TEST_IN_CI != 'True',
    reason='This test requires LLM to run.',
)
def test_edit_long_file(temp_dir, box_class, run_as_openhands):
    """Edit a line range of a long file and check the reported diff.

    The file is first written from ``ORIGINAL_LONG`` with ``start=-1`` and
    read back; then ``EDIT_LONG`` is applied with ``start=100, end=200`` and
    the observation content must equal the expected unified diff.
    """
    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
    target_path = os.path.join('/workspace', 'app.py')
    expected_diff = (
        '--- /workspace/app.py\n'
        '+++ /workspace/app.py\n'
        '@@ -97,8 +97,8 @@\n'
        ' This is line 97\n'
        ' This is line 98\n'
        ' This is line 99\n'
        '-This is line 100\n'
        '-This is line 101\n'
        '+This is line 100 + 10\n'
        '+This is line 101 + 10\n'
        ' This is line 102\n'
        ' This is line 103\n'
        ' This is line 104\n'
    ).strip()
    try:
        initial_action = FileEditAction(
            content=ORIGINAL_LONG,
            path=target_path,
            start=-1,
        )
        logger.info(initial_action, extra={'msg_type': 'ACTION'})
        initial_obs = runtime.run_action(initial_action)
        logger.info(initial_obs, extra={'msg_type': 'OBSERVATION'})

        assert isinstance(
            initial_obs, FileEditObservation
        ), 'The observation should be a FileEditObservation.'

        read_obs = runtime.run_action(FileReadAction(path=target_path))
        logger.info(read_obs, extra={'msg_type': 'OBSERVATION'})
        assert read_obs.content.strip() == ORIGINAL_LONG.strip()

        ranged_edit = FileEditAction(
            content=EDIT_LONG,
            path=target_path,
            start=100,
            end=200,
        )
        edit_obs = runtime.run_action(ranged_edit)
        logger.info(edit_obs, extra={'msg_type': 'OBSERVATION'})
        assert edit_obs.content.strip() == expected_diff
    finally:
        # Always release the runtime, even when an assertion fails.
        _close_test_runtime(runtime)
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Test FileEditObservation (things that are displayed to the agent)
|
||||
# ======================================================================================
|
||||
|
||||
|
||||
def test_edit_obs_insert_only():
    """Check the string form of a FileEditObservation for a pure insertion.

    The new content is ``ORIGINAL_LONG`` with ``EDIT_LONG`` spliced in between
    lines 99 and 100, without dropping any original line.
    """
    before_insert = '\n'.join(f'This is line {i}' for i in range(1, 100))
    after_insert = '\n'.join(f'This is line {i}' for i in range(100, 1000))
    inserted_content = before_insert + EDIT_LONG + after_insert

    obs = FileEditObservation(
        content=get_diff(ORIGINAL_LONG, inserted_content, '/workspace/app.py'),
        path='/workspace/app.py',
        prev_exist=True,
        old_content=ORIGINAL_LONG,
        new_content=inserted_content,
    )
    expected = """
[Existing file /workspace/app.py is edited with 1 changes.]
[begin of edit 1 / 1]
(content before edit)
98|This is line 98
99|This is line 99
100|This is line 100
101|This is line 101
(content after edit)
98|This is line 98
99|This is line 99
+100|This is line 100 + 10
+101|This is line 101 + 10
102|This is line 100
103|This is line 101
[end of edit 1 / 1]
""".strip()
    assert str(obs).strip() == expected
|
||||
|
||||
|
||||
def test_edit_obs_replace():
    """Check the string form of a FileEditObservation for a replacement.

    The new content is ``ORIGINAL_LONG`` with lines 100 and 101 replaced by
    ``EDIT_LONG``.
    """
    kept_head = '\n'.join(f'This is line {i}' for i in range(1, 100))
    kept_tail = '\n'.join(f'This is line {i}' for i in range(102, 1000))
    replaced_content = kept_head + EDIT_LONG + kept_tail

    obs = FileEditObservation(
        content=get_diff(ORIGINAL_LONG, replaced_content, '/workspace/app.py'),
        path='/workspace/app.py',
        prev_exist=True,
        old_content=ORIGINAL_LONG,
        new_content=replaced_content,
    )
    print(str(obs))
    expected = """
[Existing file /workspace/app.py is edited with 1 changes.]
[begin of edit 1 / 1]
(content before edit)
98|This is line 98
99|This is line 99
-100|This is line 100
-101|This is line 101
102|This is line 102
103|This is line 103
(content after edit)
98|This is line 98
99|This is line 99
+100|This is line 100 + 10
+101|This is line 101 + 10
102|This is line 102
103|This is line 103
[end of edit 1 / 1]
""".strip()
    assert str(obs).strip() == expected
|
||||
|
||||
|
||||
def test_edit_obs_replace_with_empty_line():
    """Check the observation string when the replacement adds an empty line.

    Same as the plain replacement case, except an extra newline is placed
    before ``EDIT_LONG``, so the edited region gains one blank line.
    """
    kept_head = '\n'.join(f'This is line {i}' for i in range(1, 100))
    kept_tail = '\n'.join(f'This is line {i}' for i in range(102, 1000))
    replaced_content = kept_head + '\n' + EDIT_LONG + kept_tail

    obs = FileEditObservation(
        content=get_diff(ORIGINAL_LONG, replaced_content, '/workspace/app.py'),
        path='/workspace/app.py',
        prev_exist=True,
        old_content=ORIGINAL_LONG,
        new_content=replaced_content,
    )
    print(str(obs))
    expected = """
[Existing file /workspace/app.py is edited with 1 changes.]
[begin of edit 1 / 1]
(content before edit)
98|This is line 98
99|This is line 99
-100|This is line 100
-101|This is line 101
102|This is line 102
103|This is line 103
(content after edit)
98|This is line 98
99|This is line 99
+100|
+101|This is line 100 + 10
+102|This is line 101 + 10
103|This is line 102
104|This is line 103
[end of edit 1 / 1]
""".strip()
    assert str(obs).strip() == expected
|
||||
|
||||
|
||||
def test_edit_obs_multiple_edits():
    """Check the observation string when the diff contains two separate edits.

    The new content inserts a ``balabala`` line after line 49 and replaces
    lines 100 and 101 with ``EDIT_LONG``.
    """
    first_part = '\n'.join(f'This is line {i}' for i in range(1, 50))
    second_part = '\n'.join(f'This is line {i}' for i in range(50, 100))
    last_part = '\n'.join(f'This is line {i}' for i in range(102, 1000))
    edited_content = first_part + '\nbalabala\n' + second_part + EDIT_LONG + last_part

    obs = FileEditObservation(
        content=get_diff(ORIGINAL_LONG, edited_content, '/workspace/app.py'),
        path='/workspace/app.py',
        prev_exist=True,
        old_content=ORIGINAL_LONG,
        new_content=edited_content,
    )
    expected = """
[Existing file /workspace/app.py is edited with 2 changes.]
[begin of edit 1 / 2]
(content before edit)
48|This is line 48
49|This is line 49
50|This is line 50
51|This is line 51
(content after edit)
48|This is line 48
49|This is line 49
+50|balabala
51|This is line 50
52|This is line 51
[end of edit 1 / 2]
-------------------------
[begin of edit 2 / 2]
(content before edit)
98|This is line 98
99|This is line 99
-100|This is line 100
-101|This is line 101
102|This is line 102
103|This is line 103
(content after edit)
99|This is line 98
100|This is line 99
+101|This is line 100 + 10
+102|This is line 101 + 10
103|This is line 102
104|This is line 103
[end of edit 2 / 2]
""".strip()
    assert str(obs).strip() == expected
|
||||
|
||||
|
||||
def test_edit_visualize_failed_edit():
    """Check ``visualize_diff(change_applied=False)`` for a two-edit diff.

    Uses the same content as the multiple-edits case; the rendered text must
    label every edit as ATTEMPTED and state that changes were not applied.
    """
    first_part = '\n'.join(f'This is line {i}' for i in range(1, 50))
    second_part = '\n'.join(f'This is line {i}' for i in range(50, 100))
    last_part = '\n'.join(f'This is line {i}' for i in range(102, 1000))
    edited_content = first_part + '\nbalabala\n' + second_part + EDIT_LONG + last_part

    obs = FileEditObservation(
        content=get_diff(ORIGINAL_LONG, edited_content, '/workspace/app.py'),
        path='/workspace/app.py',
        prev_exist=True,
        old_content=ORIGINAL_LONG,
        new_content=edited_content,
    )
    expected = """
[Changes are NOT applied to /workspace/app.py - Here's how the file looks like if changes are applied.]
[begin of ATTEMPTED edit 1 / 2]
(content before ATTEMPTED edit)
48|This is line 48
49|This is line 49
50|This is line 50
51|This is line 51
(content after ATTEMPTED edit)
48|This is line 48
49|This is line 49
+50|balabala
51|This is line 50
52|This is line 51
[end of ATTEMPTED edit 1 / 2]
-------------------------
[begin of ATTEMPTED edit 2 / 2]
(content before ATTEMPTED edit)
98|This is line 98
99|This is line 99
-100|This is line 100
-101|This is line 101
102|This is line 102
103|This is line 103
(content after ATTEMPTED edit)
99|This is line 98
100|This is line 99
+101|This is line 100 + 10
+102|This is line 101 + 10
103|This is line 102
104|This is line 103
[end of ATTEMPTED edit 2 / 2]
""".strip()
    assert obs.visualize_diff(change_applied=False).strip() == expected
|
||||
@ -22,7 +22,6 @@ from openhands.events.observation import (
|
||||
FileWriteObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from openhands.runtime.client.runtime import EventStreamRuntime
|
||||
|
||||
# ============================================================================================================================
|
||||
# ipython-specific tests
|
||||
@ -199,215 +198,6 @@ def test_ipython_simple(temp_dir, box_class):
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def _test_ipython_agentskills_fileop_pwd_impl(
|
||||
runtime: EventStreamRuntime, enable_auto_lint: bool
|
||||
):
|
||||
sandbox_dir = _get_sandbox_folder(runtime)
|
||||
# remove everything in /workspace
|
||||
action = CmdRunAction(command=f'rm -rf {sandbox_dir}/*')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='mkdir test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
f'[File: {sandbox_dir}/hello.py (1 lines total)]\n'
|
||||
'(this is the beginning of the file)\n'
|
||||
'1|\n'
|
||||
'(this is the end of the file)\n'
|
||||
'[File hello.py created.]\n'
|
||||
f'[Jupyter current working directory: {sandbox_dir}]\n'
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
|
||||
).strip().split('\n')
|
||||
|
||||
action = CmdRunAction(command='cd test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
# This should create a file in the current working directory
|
||||
# i.e., /workspace/test/hello.py instead of /workspace/hello.py
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
f'[File: {sandbox_dir}/test/hello.py (1 lines total)]\n'
|
||||
'(this is the beginning of the file)\n'
|
||||
'1|\n'
|
||||
'(this is the end of the file)\n'
|
||||
'[File hello.py created.]\n'
|
||||
f'[Jupyter current working directory: {sandbox_dir}/test]\n'
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
|
||||
).strip().split('\n')
|
||||
|
||||
if enable_auto_lint:
|
||||
# edit file, but make a mistake in indentation
|
||||
action = IPythonRunCellAction(
|
||||
code="insert_content_at_line('hello.py', 1, ' print(\"hello world\")')"
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
f"""
|
||||
[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]
|
||||
ERRORS:
|
||||
{sandbox_dir}/test/hello.py:1:3: E999 IndentationError: unexpected indent
|
||||
[This is how your edit would have looked if applied]
|
||||
-------------------------------------------------
|
||||
(this is the beginning of the file)
|
||||
1| print("hello world")
|
||||
(this is the end of the file)
|
||||
-------------------------------------------------
|
||||
|
||||
[This is the original code before your edit]
|
||||
-------------------------------------------------
|
||||
(this is the beginning of the file)
|
||||
1|
|
||||
(this is the end of the file)
|
||||
-------------------------------------------------
|
||||
Your changes have NOT been applied. Please fix your edit command and try again.
|
||||
You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.
|
||||
DO NOT re-run the same failed edit command. Running it again will lead to the same error.
|
||||
[Jupyter current working directory: {sandbox_dir}/test]
|
||||
[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]
|
||||
"""
|
||||
).strip().split('\n')
|
||||
|
||||
# edit file with correct indentation
|
||||
action = IPythonRunCellAction(
|
||||
code="insert_content_at_line('hello.py', 1, 'print(\"hello world\")')"
|
||||
)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
f"""
|
||||
[File: {sandbox_dir}/test/hello.py (1 lines total after edit)]
|
||||
(this is the beginning of the file)
|
||||
1|print("hello world")
|
||||
(this is the end of the file)
|
||||
[File updated (edited at line 1). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]
|
||||
[Jupyter current working directory: {sandbox_dir}/test]
|
||||
[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]
|
||||
"""
|
||||
).strip().split('\n')
|
||||
|
||||
action = CmdRunAction(command=f'rm -rf {sandbox_dir}/*')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
|
||||
def test_ipython_agentskills_fileop_pwd_with_lint(
|
||||
temp_dir, box_class, run_as_openhands
|
||||
):
|
||||
runtime = _load_runtime(
|
||||
temp_dir, box_class, run_as_openhands, enable_auto_lint=True
|
||||
)
|
||||
_test_ipython_agentskills_fileop_pwd_impl(runtime, True)
|
||||
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_ipython_agentskills_fileop_pwd_without_lint(
|
||||
temp_dir, box_class, run_as_openhands
|
||||
):
|
||||
runtime = _load_runtime(
|
||||
temp_dir, box_class, run_as_openhands, enable_auto_lint=False
|
||||
)
|
||||
_test_ipython_agentskills_fileop_pwd_impl(runtime, False)
|
||||
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_ipython_agentskills_fileop_pwd_with_userdir(temp_dir, box_class):
|
||||
"""Make sure that cd in bash also update the current working directory in ipython.
|
||||
|
||||
Handle special case where the pwd is provided as "~", which should be expanded using os.path.expanduser
|
||||
on the client side.
|
||||
"""
|
||||
|
||||
runtime = _load_runtime(
|
||||
temp_dir,
|
||||
box_class,
|
||||
run_as_openhands=False,
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='cd ~')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = CmdRunAction(command='mkdir test && ls -la')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
'[File: /root/hello.py (1 lines total)]\n'
|
||||
'(this is the beginning of the file)\n'
|
||||
'1|\n'
|
||||
'(this is the end of the file)\n'
|
||||
'[File hello.py created.]\n'
|
||||
'[Jupyter current working directory: /root]\n'
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
|
||||
).strip().split('\n')
|
||||
|
||||
action = CmdRunAction(command='cd test')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
# This should create a file in the current working directory
|
||||
# i.e., /workspace/test/hello.py instead of /workspace/hello.py
|
||||
action = IPythonRunCellAction(code="create_file('hello.py')")
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
assert obs.content.replace('\r\n', '\n').strip().split('\n') == (
|
||||
'[File: /root/test/hello.py (1 lines total)]\n'
|
||||
'(this is the beginning of the file)\n'
|
||||
'1|\n'
|
||||
'(this is the end of the file)\n'
|
||||
'[File hello.py created.]\n'
|
||||
'[Jupyter current working directory: /root/test]\n'
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
|
||||
).strip().split('\n')
|
||||
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_ipython_package_install(temp_dir, box_class, run_as_openhands):
|
||||
"""Make sure that cd in bash also update the current working directory in ipython."""
|
||||
runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
|
||||
|
||||
@ -1,22 +1,15 @@
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import docx
|
||||
import pytest
|
||||
|
||||
from openhands.runtime.plugins.agent_skills.file_ops.file_ops import (
|
||||
MSG_FILE_UPDATED,
|
||||
WINDOW,
|
||||
_print_window,
|
||||
append_file,
|
||||
create_file,
|
||||
edit_file_by_replace,
|
||||
find_file,
|
||||
goto_line,
|
||||
insert_content_at_line,
|
||||
open_file,
|
||||
scroll_down,
|
||||
scroll_up,
|
||||
@ -182,28 +175,6 @@ def test_open_file_long_with_lineno(tmp_path):
|
||||
assert result.split('\n') == expected.split('\n')
|
||||
|
||||
|
||||
def test_create_file_unexist_path():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
create_file('/unexist/path/a.txt')
|
||||
|
||||
|
||||
def test_create_file(tmp_path):
|
||||
temp_file_path = tmp_path / 'a.txt'
|
||||
with io.StringIO() as buf:
|
||||
with contextlib.redirect_stdout(buf):
|
||||
create_file(str(temp_file_path))
|
||||
result = buf.getvalue()
|
||||
|
||||
expected = (
|
||||
f'[File: {temp_file_path} (1 lines total)]\n'
|
||||
'(this is the beginning of the file)\n'
|
||||
'1|\n'
|
||||
'(this is the end of the file)\n'
|
||||
f'[File {temp_file_path} created.]\n'
|
||||
)
|
||||
assert result.split('\n') == expected.split('\n')
|
||||
|
||||
|
||||
def test_goto_line(tmp_path):
|
||||
temp_file_path = tmp_path / 'a.txt'
|
||||
total_lines = 1000
|
||||
@ -405,7 +376,7 @@ def test_scroll_down_edge(tmp_path):
|
||||
|
||||
def test_print_window_internal(tmp_path):
|
||||
test_file_path = tmp_path / 'a.txt'
|
||||
create_file(str(test_file_path))
|
||||
test_file_path.write_text('')
|
||||
open_file(str(test_file_path))
|
||||
with open(test_file_path, 'w') as file:
|
||||
for i in range(1, 101):
|
||||
@ -432,7 +403,7 @@ def test_print_window_internal(tmp_path):
|
||||
|
||||
def test_open_file_large_line_number(tmp_path):
|
||||
test_file_path = tmp_path / 'a.txt'
|
||||
create_file(str(test_file_path))
|
||||
test_file_path.write_text('')
|
||||
open_file(str(test_file_path))
|
||||
with open(test_file_path, 'w') as file:
|
||||
for i in range(1, 1000):
|
||||
@ -457,648 +428,6 @@ def test_open_file_large_line_number(tmp_path):
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_edit_file_by_replace_window(tmp_path):
    """Check the lint-failure report produced when an edit over-indents one line.

    With auto-lint enabled, line 9 of the fixture is re-indented so that it
    triggers ``E999 IndentationError`` at column 9.  The report must show a
    window of the would-be result, the same window of the original file, and
    the "changes have NOT been applied" trailer.

    NOTE(review): the indentation inside the fixture and the expected views was
    reconstructed.  The column in the expected error (``:9:9:``) fixes the
    edited line at eight leading spaces; the four-space indent of the fixture
    body is assumed -- confirm against the upstream test.
    """
    with patch.dict(os.environ, {'ENABLE_AUTO_LINT': 'True'}):
        content = """def any_int(a, b, c):
    return isinstance(a, int) and isinstance(b, int) and isinstance(c, int)

def test_any_int():
    assert any_int(1, 2, 3) == True
    assert any_int(1.5, 2, 3) == False
    assert any_int(1, 2.5, 3) == False
    assert any_int(1, 2, 3.5) == False
    assert any_int(1.0, 2, 3) == False
    assert any_int(1, 2.0, 3) == False
    assert any_int(1, 2, 3.0) == False
    assert any_int(0, 0, 0) == True
    assert any_int(-1, -2, -3) == True
    assert any_int(1, -2, 3) == True
    assert any_int(1.5, -2, 3) == False
    assert any_int(1, -2.5, 3) == False

def check(any_int):
    # Check some simple cases
    assert any_int(2, 3, 1)==True, "This prints if this assert fails 1 (good for debugging!)"
    assert any_int(2.5, 2, 3)==False, "This prints if this assert fails 2 (good for debugging!)"
    assert any_int(1.5, 5, 3.5)==False, "This prints if this assert fails 3 (good for debugging!)"
    assert any_int(2, 6, 2)==False, "This prints if this assert fails 4 (good for debugging!)"
    assert any_int(4, 2, 2)==True, "This prints if this assert fails 5 (good for debugging!)"
    assert any_int(2.2, 2.2, 2.2)==False, "This prints if this assert fails 6 (good for debugging!)"
    assert any_int(-4, 6, 2)==True, "This prints if this assert fails 7 (good for debugging!)"

    # Check some edge cases that are easy to work out by hand.
    assert any_int(2,1,1)==True, "This prints if this assert fails 8 (also good for debugging!)"
    assert any_int(3,4,7)==True, "This prints if this assert fails 9 (also good for debugging!)"
    assert any_int(3.0,4,7)==False, "This prints if this assert fails 10 (also good for debugging!)"

check(any_int)"""

        temp_file_path = tmp_path / 'error-test.py'
        temp_file_path.write_text(content)

        open_file(str(temp_file_path))

        with io.StringIO() as buf:
            with contextlib.redirect_stdout(buf):
                # The two arguments must differ (only in indentation); identical
                # values would be rejected before linting ever runs.
                edit_file_by_replace(
                    str(temp_file_path),
                    to_replace='    assert any_int(1.0, 2, 3) == False',
                    new_content='        assert any_int(1.0, 2, 3) == False',
                )
            result = buf.getvalue()

        # Both views show the same 21-line window; only line 9 differs.
        window_before_line_9 = (
            '(this is the beginning of the file)\n'
            '1|def any_int(a, b, c):\n'
            '2|    return isinstance(a, int) and isinstance(b, int) and isinstance(c, int)\n'
            '3|\n'
            '4|def test_any_int():\n'
            '5|    assert any_int(1, 2, 3) == True\n'
            '6|    assert any_int(1.5, 2, 3) == False\n'
            '7|    assert any_int(1, 2.5, 3) == False\n'
            '8|    assert any_int(1, 2, 3.5) == False\n'
        )
        window_after_line_9 = (
            '10|    assert any_int(1, 2.0, 3) == False\n'
            '11|    assert any_int(1, 2, 3.0) == False\n'
            '12|    assert any_int(0, 0, 0) == True\n'
            '13|    assert any_int(-1, -2, -3) == True\n'
            '14|    assert any_int(1, -2, 3) == True\n'
            '15|    assert any_int(1.5, -2, 3) == False\n'
            '16|    assert any_int(1, -2.5, 3) == False\n'
            '17|\n'
            '18|def check(any_int):\n'
            '19|    # Check some simple cases\n'
            '20|    assert any_int(2, 3, 1)==True, "This prints if this assert fails 1 (good for debugging!)"\n'
            '21|    assert any_int(2.5, 2, 3)==False, "This prints if this assert fails 2 (good for debugging!)"\n'
            '(12 more lines below)\n'
        )

        expected = (
            '[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n'
            'ERRORS:\n'
            + str(temp_file_path)
            + ':9:9: '
            + 'E999 IndentationError: unexpected indent\n'
            '[This is how your edit would have looked if applied]\n'
            + SEP
            + window_before_line_9
            + '9|        assert any_int(1.0, 2, 3) == False\n'
            + window_after_line_9
            + SEP
            + '\n'
            '[This is the original code before your edit]\n'
            + SEP
            + window_before_line_9
            + '9|    assert any_int(1.0, 2, 3) == False\n'
            + window_after_line_9
            + SEP
            + 'Your changes have NOT been applied. Please fix your edit command and try again.\n'
            'You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n'
            'DO NOT re-run the same failed edit command. Running it again will lead to the same error.\n'
        )
        assert result == expected
|
||||
|
||||
|
||||
def test_edit_file_by_replace_with_multiple_errors(tmp_path):
    """An edit is applied even when the file still contains another lint error.

    The fixture references the undefined ``ANOTHER_CONSTANT``; the edit only
    reformats ``a + b`` to ``a+b``.  With auto-lint enabled the edit must still
    go through and the updated eight-line file must be printed.

    NOTE(review): indentation inside the fixture, the edit arguments and the
    expected view was reconstructed with a four-space indent so that
    ``to_replace`` actually occurs in the file -- confirm against the upstream
    test.
    """
    # If the file has multiple errors, but the suggested modification can only
    # fix one error, make sure it is applied.
    with patch.dict(os.environ, {'ENABLE_AUTO_LINT': 'True'}):
        content = """def Sum(a,b):
    try:
        answer = a + b
        return answer
    except Exception:
        answer = ANOTHER_CONSTANT
        return answer
Sum(1,1)
"""

        temp_file_path = tmp_path / 'problematic-file-test.py'
        temp_file_path.write_text(content)

        open_file(str(temp_file_path))

        with io.StringIO() as buf:
            with contextlib.redirect_stdout(buf):
                edit_file_by_replace(
                    str(temp_file_path),
                    to_replace='        answer = a + b',
                    new_content='        answer = a+b',
                )
            result = buf.getvalue()

        expected = (
            f'[File: {temp_file_path} (8 lines total after edit)]\n'
            '(this is the beginning of the file)\n'
            '1|def Sum(a,b):\n'
            '2|    try:\n'
            '3|        answer = a+b\n'
            '4|        return answer\n'
            '5|    except Exception:\n'
            '6|        answer = ANOTHER_CONSTANT\n'
            '7|        return answer\n'
            '8|Sum(1,1)\n'
            '(this is the end of the file)\n'
            + MSG_FILE_UPDATED.format(line_number=3)
            + '\n'
        )
        assert result.split('\n') == expected.split('\n')
|
||||
|
||||
|
||||
# ================================
|
||||
|
||||
|
||||
def test_edit_file_by_replace(tmp_path):
    """Replacing a three-line span with one line leaves a three-line file."""
    target = tmp_path / 'a.txt'
    target.write_text('Line 1\nLine 2\nLine 3\nLine 4\nLine 5')

    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        edit_file_by_replace(
            file_name=str(target),
            to_replace='Line 1\nLine 2\nLine 3',
            new_content='REPLACE TEXT',
        )
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (3 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|REPLACE TEXT\n',
            '2|Line 4\n',
            '3|Line 5\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=1),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    # The file on disk must match what was printed.
    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 3
    assert [row.rstrip() for row in on_disk] == ['REPLACE TEXT', 'Line 4', 'Line 5']
|
||||
|
||||
|
||||
def test_edit_file_by_replace_sameline(tmp_path):
    """A two-line match whose first line is unchanged only rewrites the second line."""
    target = tmp_path / 'a.txt'
    target.write_text('Line 1\nLine 2\nLine 2\nLine 4\nLine 5')

    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        edit_file_by_replace(
            file_name=str(target),
            to_replace='Line 2\nLine 2',
            new_content='Line 2\nREPLACE TEXT',
        )
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (5 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|Line 1\n',
            '2|Line 2\n',
            '3|REPLACE TEXT\n',
            '4|Line 4\n',
            '5|Line 5\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=2),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 5
    assert [row.rstrip() for row in on_disk] == [
        'Line 1',
        'Line 2',
        'REPLACE TEXT',
        'Line 4',
        'Line 5',
    ]
|
||||
|
||||
|
||||
def test_edit_file_by_replace_multiline(tmp_path):
    """An ambiguous `to_replace` (matching two lines) is reported as an error."""
    target = tmp_path / 'a.txt'
    target.write_text('Line 1\nLine 2\nLine 2\nLine 4\nLine 5')

    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        edit_file_by_replace(
            file_name=str(target),
            to_replace='Line 2',
            new_content='REPLACE TEXT',
        )
        output = captured.getvalue()

    complaint = 'ERROR: `to_replace` appears more than once, please include enough lines to make code in `to_replace` unique'
    assert output.strip().startswith(complaint)
|
||||
|
||||
|
||||
def test_edit_file_by_replace_no_diff(tmp_path):
    """Identical `to_replace` and `new_content` are rejected with an error message."""
    target = tmp_path / 'a.txt'
    target.write_text('Line 1\nLine 2\nLine 2\nLine 4\nLine 5')

    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        edit_file_by_replace(
            file_name=str(target),
            to_replace='Line 1',
            new_content='Line 1',
        )
        output = captured.getvalue()

    complaint = 'ERROR: `to_replace` and `new_content` must be different'
    assert output.strip().startswith(complaint)
|
||||
|
||||
|
||||
def test_edit_file_by_replace_toreplace_empty(tmp_path):
    """An empty `to_replace` must surface the 'must not be empty' error."""
    target = tmp_path / 'a.txt'
    target.write_text('Line 1\nLine 2\nLine 2\nLine 4\nLine 5')

    open_file(str(target))

    def attempt_edit():
        # Deferred so the shared helper can invoke it and inspect the failure.
        return edit_file_by_replace(
            file_name=str(target),
            to_replace='',
            new_content='Line 1',
        )

    _capture_file_operation_error(
        attempt_edit,
        'ERROR: `to_replace` must not be empty.',
    )
|
||||
|
||||
|
||||
def test_edit_file_by_replace_unknown_file():
    """Editing a file that does not exist must surface a 'not found' error."""

    def attempt_edit():
        # Deferred so the shared helper can invoke it and inspect the failure.
        return edit_file_by_replace(
            'unknown file',
            'ORIGINAL TEXT',
            'REPLACE TEXT',
        )

    _capture_file_operation_error(
        attempt_edit,
        'ERROR: File unknown file not found.',
    )
|
||||
|
||||
|
||||
def test_insert_content_at_line(tmp_path):
    """Inserting at line 2 pushes the former lines 2 and 3 down by one."""
    target = tmp_path / 'b.txt'
    target.write_text('Line 1\nLine 2\nLine 3')
    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        insert_content_at_line(
            file_name=str(target),
            line_number=2,
            content='Inserted Line',
        )
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (4 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|Line 1\n',
            '2|Inserted Line\n',
            '3|Line 2\n',
            '4|Line 3\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=2),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 4
    assert [row.rstrip() for row in on_disk] == [
        'Line 1',
        'Inserted Line',
        'Line 2',
        'Line 3',
    ]
|
||||
|
||||
|
||||
def test_insert_content_at_line_from_scratch(tmp_path):
    """Inserting at line 1 of a freshly created file yields a one-line file."""
    target = tmp_path / 'a.txt'
    create_file(str(target))
    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        insert_content_at_line(
            file_name=str(target),
            line_number=1,
            content='REPLACE TEXT',
        )
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (1 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|REPLACE TEXT\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=1),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 1
    assert on_disk[0].rstrip() == 'REPLACE TEXT'
|
||||
|
||||
|
||||
def test_insert_content_at_line_from_scratch_emptyfile(tmp_path):
    """Inserting at line 1 of a zero-byte file yields a one-line file."""
    target = tmp_path / 'a.txt'
    # Start from a truly empty file written directly, not via create_file.
    with open(target, 'w') as blank:
        blank.write('')
    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        insert_content_at_line(
            file_name=str(target),
            line_number=1,
            content='REPLACE TEXT',
        )
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (1 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|REPLACE TEXT\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=1),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 1
    assert on_disk[0].rstrip() == 'REPLACE TEXT'
|
||||
|
||||
|
||||
def test_insert_content_at_line_emptyline(tmp_path):
    """Inserting before an empty line keeps that empty line as the last line."""
    target = tmp_path / 'b.txt'
    target.write_text('Line 1\n\n')
    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        insert_content_at_line(
            file_name=str(target),
            line_number=2,
            content='Inserted Line',
        )
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (3 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|Line 1\n',
            '2|Inserted Line\n',
            '3|\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=2),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 3
    # Only the first two lines are checked for content.
    assert [row.rstrip() for row in on_disk[:2]] == ['Line 1', 'Inserted Line']
|
||||
|
||||
|
||||
def test_insert_content_at_line_from_scratch_multiline_with_backticks_and_second_edit(
    tmp_path,
):
    """Backtick-wrapped lines survive an insert and a follow-up replace unescaped."""
    target = tmp_path / 'a.txt'
    create_file(str(target))
    open_file(str(target))

    # First edit: insert three backtick-wrapped lines into the new file.
    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        insert_content_at_line(
            str(target),
            1,
            '`REPLACE TEXT1`\n`REPLACE TEXT2`\n`REPLACE TEXT3`',
        )
        first_output = captured.getvalue()

    first_wanted = ''.join(
        [
            f'[File: {target} (3 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|`REPLACE TEXT1`\n',
            '2|`REPLACE TEXT2`\n',
            '3|`REPLACE TEXT3`\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=1),
            '\n',
        ]
    )
    assert first_output.split('\n') == first_wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 3
    assert [row.rstrip() for row in on_disk] == [
        '`REPLACE TEXT1`',
        '`REPLACE TEXT2`',
        '`REPLACE TEXT3`',
    ]

    # No backtick may have been escaped in the output of the insert.
    assert '\\`' not in first_output

    # Second edit: replace all three lines in one call.
    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        edit_file_by_replace(
            str(target),
            '`REPLACE TEXT1`\n`REPLACE TEXT2`\n`REPLACE TEXT3`',
            '`REPLACED TEXT1`\n`REPLACED TEXT2`\n`REPLACED TEXT3`',
        )
        second_output = captured.getvalue()

    second_wanted = ''.join(
        [
            f'[File: {target} (3 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|`REPLACED TEXT1`\n',
            '2|`REPLACED TEXT2`\n',
            '3|`REPLACED TEXT3`\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=1),
            '\n',
        ]
    )
    assert second_output.split('\n') == second_wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 3
    assert [row.rstrip() for row in on_disk] == [
        '`REPLACED TEXT1`',
        '`REPLACED TEXT2`',
        '`REPLACED TEXT3`',
    ]

    # No backtick may have been escaped in the output of the replace either.
    assert '\\`' not in second_output
|
||||
|
||||
|
||||
def test_insert_content_at_line_from_scratch_multiline(tmp_path):
    """Inserting three lines at once into a new file yields a three-line file."""
    target = tmp_path / 'a.txt'
    create_file(str(target))
    # The path object itself (not its str form) is handed to open_file here.
    open_file(target)

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        insert_content_at_line(
            str(target),
            1,
            content='REPLACE TEXT1\nREPLACE TEXT2\nREPLACE TEXT3',
        )
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (3 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|REPLACE TEXT1\n',
            '2|REPLACE TEXT2\n',
            '3|REPLACE TEXT3\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=1),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 3
    assert [row.rstrip() for row in on_disk] == [
        'REPLACE TEXT1',
        'REPLACE TEXT2',
        'REPLACE TEXT3',
    ]
|
||||
|
||||
|
||||
def test_insert_content_at_line_not_opened():
    """Inserting into a nonexistent file must surface the invalid-path error."""

    def attempt_insert():
        # Deferred so the shared helper can invoke it and inspect the failure.
        return insert_content_at_line(
            'unknown file',
            1,
            'REPLACE TEXT',
        )

    _capture_file_operation_error(
        attempt_insert,
        'ERROR: Invalid path or file name.',
    )
|
||||
|
||||
|
||||
def test_append_file(tmp_path):
    """Appending to a two-line file adds the new text as line 3."""
    target = tmp_path / 'a.txt'
    target.write_text('Line 1\nLine 2')

    open_file(str(target))

    with io.StringIO() as captured, contextlib.redirect_stdout(captured):
        append_file(str(target), content='APPENDED TEXT')
        output = captured.getvalue()

    wanted = ''.join(
        [
            f'[File: {target} (3 lines total after edit)]\n',
            '(this is the beginning of the file)\n',
            '1|Line 1\n',
            '2|Line 2\n',
            '3|APPENDED TEXT\n',
            '(this is the end of the file)\n',
            MSG_FILE_UPDATED.format(line_number=2),
            '\n',
        ]
    )
    assert output.split('\n') == wanted.split('\n')

    with open(target, 'r') as handle:
        on_disk = handle.readlines()
    assert len(on_disk) == 3
    assert [row.rstrip() for row in on_disk] == ['Line 1', 'Line 2', 'APPENDED TEXT']
|
||||
|
||||
|
||||
def test_append_file_from_scratch(tmp_path):
    """Appending to a freshly created file yields a one-line file."""
    target = tmp_path / 'a.txt'
    create_file(str(target))
    try:
        open_file(str(target))
        with io.StringIO() as captured, contextlib.redirect_stdout(captured):
            append_file(str(target), content='APPENDED TEXT')
            output = captured.getvalue()

        wanted = ''.join(
            [
                f'[File: {target} (1 lines total after edit)]\n',
                '(this is the beginning of the file)\n',
                '1|APPENDED TEXT\n',
                '(this is the end of the file)\n',
                MSG_FILE_UPDATED.format(line_number=1),
                '\n',
            ]
        )
        assert output.split('\n') == wanted.split('\n')

        with open(target, 'r') as handle:
            on_disk = handle.readlines()
        assert len(on_disk) == 1
        assert on_disk[0].rstrip() == 'APPENDED TEXT'
    finally:
        # Remove the file whether or not the assertions passed.
        os.remove(target)
|
||||
|
||||
|
||||
def test_append_file_from_scratch_multiline(tmp_path):
    """Appending three lines at once to a new file yields a three-line file."""
    target = tmp_path / 'a3.txt'
    create_file(str(target))
    try:
        # The path object itself (not its str form) is handed to open_file here.
        open_file(target)
        with io.StringIO() as captured, contextlib.redirect_stdout(captured):
            append_file(
                str(target),
                content='APPENDED TEXT1\nAPPENDED TEXT2\nAPPENDED TEXT3',
            )
            output = captured.getvalue()

        wanted = ''.join(
            [
                f'[File: {target} (3 lines total after edit)]\n',
                '(this is the beginning of the file)\n',
                '1|APPENDED TEXT1\n',
                '2|APPENDED TEXT2\n',
                '3|APPENDED TEXT3\n',
                '(this is the end of the file)\n',
                MSG_FILE_UPDATED.format(line_number=1),
                '\n',
            ]
        )
        assert output.split('\n') == wanted.split('\n')

        with open(target, 'r') as handle:
            on_disk = handle.readlines()
        assert len(on_disk) == 3
        assert [row.rstrip() for row in on_disk] == [
            'APPENDED TEXT1',
            'APPENDED TEXT2',
            'APPENDED TEXT3',
        ]
    finally:
        # Remove the file whether or not the assertions passed.
        os.remove(target)
|
||||
|
||||
|
||||
def test_append_file_not_opened():
    """Appending to a nonexistent file must surface the invalid-path error."""

    def attempt_append():
        # Deferred so the shared helper can invoke it and inspect the failure.
        return append_file('unknown file', content='APPENDED TEXT')

    _capture_file_operation_error(
        attempt_append,
        'ERROR: Invalid path or file name.',
    )
|
||||
|
||||
|
||||
def test_search_dir(tmp_path):
|
||||
# create files with the search term "bingo"
|
||||
for i in range(1, 101):
|
||||
@ -1269,143 +598,6 @@ def test_find_file_not_exist_file_specific_path(tmp_path):
|
||||
assert result.split('\n') == expected.split('\n')
|
||||
|
||||
|
||||
def test_edit_lint_file_pass(tmp_path):
    """With auto-lint on, a lint-clean insert is applied and the new view printed."""
    # Enable linting for the duration of the test.
    with patch.dict(os.environ, {'ENABLE_AUTO_LINT': 'True'}):
        script = _generate_test_file_with_lines(tmp_path, 1)

        # Both the open and the insert print to stdout; capture them together.
        with io.StringIO() as captured, contextlib.redirect_stdout(captured):
            open_file(str(script))
            insert_content_at_line(str(script), 1, "print('hello')\n")
            output = captured.getvalue()

        assert output is not None
        wanted = ''.join(
            [
                f'[File: {script} (1 lines total)]\n',
                '(this is the beginning of the file)\n',
                '1|\n',
                '(this is the end of the file)\n',
                f'[File: {script} (1 lines total after edit)]\n',
                '(this is the beginning of the file)\n',
                "1|print('hello')\n",
                '(this is the end of the file)\n',
                MSG_FILE_UPDATED.format(line_number=1),
                '\n',
            ]
        )
        assert output.split('\n') == wanted.split('\n')
|
||||
|
||||
|
||||
def test_lint_file_fail_undefined_name(tmp_path, capsys):
    """With auto-lint on, inserting a call to an undefined name is rejected with F821."""
    with patch.dict(os.environ, {'ENABLE_AUTO_LINT': 'True'}):
        script = _generate_test_file_with_lines(tmp_path, 1)

        open_file(str(script), 1)
        insert_content_at_line(str(script), 1, 'undefined_name()\n')

        output = capsys.readouterr().out
        assert output is not None

        wanted = ''.join(
            [
                # Output of open_file.
                f'[File: {script} (1 lines total)]\n',
                '(this is the beginning of the file)\n',
                '1|\n',
                '(this is the end of the file)\n',
                # Lint report from the rejected insert.
                '[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n',
                'ERRORS:\n',
                f"{script}:1:1: F821 undefined name 'undefined_name'\n",
                '[This is how your edit would have looked if applied]\n',
                SEP,
                '(this is the beginning of the file)\n',
                '1|undefined_name()\n',
                '(this is the end of the file)\n',
                SEP,
                '\n[This is the original code before your edit]\n',
                SEP,
                '(this is the beginning of the file)\n',
                '1|\n',
                '(this is the end of the file)\n',
                SEP,
                'Your changes have NOT been applied. Please fix your edit command and try again.\n',
                'You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n',
                'DO NOT re-run the same failed edit command. Running it again will lead to the same error.\n',
            ]
        )
        assert output.split('\n') == wanted.split('\n')
|
||||
|
||||
|
||||
def test_lint_file_fail_undefined_name_long(tmp_path, capsys):
    """In a 1000-line file, the F821 report shows windows centred on the edited line 500."""
    with patch.dict(os.environ, {'ENABLE_AUTO_LINT': 'True'}):
        num_lines = 1000
        error_line = 500

        script = _generate_test_file_with_lines(tmp_path, num_lines)
        lint_complaint = (
            f"{script}:{error_line}:1: F821 undefined name 'undefined_name'"
        )

        open_file(str(script))
        insert_content_at_line(str(script), error_line, 'undefined_name()\n')

        output = capsys.readouterr().out
        assert output is not None

        # open_file shows the first WINDOW lines, each rendered as "<n>|".
        first_window = '\n'.join(f'{number}|' for number in range(1, WINDOW + 1))
        wanted = ''.join(
            [
                f'[File: {script} ({num_lines} lines total)]\n',
                '(this is the beginning of the file)\n',
                first_window + '\n',
                f'({num_lines - WINDOW} more lines below)\n',
                '[Use `scroll_down` to view the next 100 lines of the file!]\n',
                '[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n',
                f'ERRORS:\n{lint_complaint}\n',
                '[This is how your edit would have looked if applied]\n',
                SEP,
                '(489 more lines above)\n',
                _numbered_test_lines(error_line - 10, error_line - 1),
                '500|undefined_name()\n',
                _numbered_test_lines(error_line + 1, error_line + 10),
                '(491 more lines below)\n',
                SEP,
                '\n[This is the original code before your edit]\n',
                SEP,
                '(489 more lines above)\n',
                _numbered_test_lines(error_line - 10, error_line + 10),
                '(490 more lines below)\n',
                SEP,
                'Your changes have NOT been applied. Please fix your edit command and try again.\n',
                'You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n',
                'DO NOT re-run the same failed edit command. Running it again will lead to the same error.\n',
            ]
        )
        assert output.split('\n') == wanted.split('\n')
|
||||
|
||||
|
||||
def test_lint_file_disabled_undefined_name(tmp_path, capsys):
    """With auto-lint off, an insert that references an undefined name is still applied."""
    with patch.dict(os.environ, {'ENABLE_AUTO_LINT': 'False'}):
        script = _generate_test_file_with_lines(tmp_path, 1)

        open_file(str(script))
        insert_content_at_line(str(script), 1, 'undefined_name()\n')

        output = capsys.readouterr().out
        assert output is not None
        wanted = ''.join(
            [
                f'[File: {script} (1 lines total)]\n',
                '(this is the beginning of the file)\n',
                '1|\n',
                '(this is the end of the file)\n',
                f'[File: {script} (1 lines total after edit)]\n',
                '(this is the beginning of the file)\n',
                '1|undefined_name()\n',
                '(this is the end of the file)\n',
                MSG_FILE_UPDATED.format(line_number=1),
                '\n',
            ]
        )
        assert output.split('\n') == wanted.split('\n')
|
||||
|
||||
|
||||
def test_parse_docx(tmp_path):
|
||||
# Create a DOCX file with some content
|
||||
test_docx_path = tmp_path / 'test.docx'
|
||||
@ -1523,42 +715,3 @@ def test_parse_pptx(tmp_path):
|
||||
'Hello, this is the second test PPTX slide.\n\n'
|
||||
)
|
||||
assert output == expected_output, f'Expected output does not match. Got: {output}'
|
||||
|
||||
|
||||
def test_lint_file_fail_non_python(tmp_path, capsys):
    """With auto-lint on, a syntactically broken insert into a Ruby file is rejected."""
    with patch.dict(os.environ, {'ENABLE_AUTO_LINT': 'True'}):
        ruby_script = _generate_ruby_test_file_with_lines(tmp_path, 1)

        open_file(str(ruby_script), 1)
        insert_content_at_line(
            str(ruby_script), 1, "def print_hello_world()\n puts 'Hello World'"
        )
        output = capsys.readouterr().out
        assert output is not None
        wanted = ''.join(
            [
                # Output of open_file.
                f'[File: {ruby_script} (1 lines total)]\n',
                '(this is the beginning of the file)\n',
                '1|\n',
                '(this is the end of the file)\n',
                # Lint report from the rejected insert.
                '[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n',
                'ERRORS:\n',
                f'{ruby_script}:1:1: Syntax error\n',
                '[This is how your edit would have looked if applied]\n',
                SEP,
                '(this is the beginning of the file)\n',
                '1|def print_hello_world()\n',
                "2| puts 'Hello World'\n",
                '(this is the end of the file)\n',
                '-------------------------------------------------\n\n',
                '[This is the original code before your edit]\n',
                SEP,
                '(this is the beginning of the file)\n',
                '1|\n',
                '(this is the end of the file)\n',
                SEP,
                'Your changes have NOT been applied. Please fix your edit command and try again.\n',
                'You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n',
                'DO NOT re-run the same failed edit command. Running it again will lead to the same error.\n',
            ]
        )
        assert output.split('\n') == wanted.split('\n')
|
||||
|
||||
136
tests/unit/test_chunk_localizer.py
Normal file
136
tests/unit/test_chunk_localizer.py
Normal file
@ -0,0 +1,136 @@
|
||||
import pytest
|
||||
|
||||
from openhands.utils.chunk_localizer import (
|
||||
Chunk,
|
||||
create_chunks,
|
||||
get_top_k_chunk_matches,
|
||||
normalized_lcs,
|
||||
)
|
||||
|
||||
|
||||
def test_chunk_creation():
    """A Chunk keeps its text and line range and starts without an LCS score."""
    sample = Chunk(text='test chunk', line_range=(1, 1))
    assert sample.text == 'test chunk'
    assert sample.line_range == (1, 1)
    assert sample.normalized_lcs is None
|
||||
|
||||
|
||||
def test_chunk_visualization(capsys):
    """visualize() prefixes each line with its line number and a pipe."""
    two_liner = Chunk(text='line1\nline2', line_range=(1, 2))
    rendered = two_liner.visualize()
    assert rendered == '1|line1\n2|line2\n'
|
||||
|
||||
|
||||
def test_create_chunks_raw_string():
    """Five lines split with size=2 give chunks covering lines 1-2, 3-4 and 5."""
    pieces = create_chunks('line1\nline2\nline3\nline4\nline5', size=2)
    assert len(pieces) == 3

    first, second, third = pieces
    assert first.text == 'line1\nline2'
    assert first.line_range == (1, 2)
    assert second.text == 'line3\nline4'
    assert second.line_range == (3, 4)
    assert third.text == 'line5'
    assert third.line_range == (5, 5)
|
||||
|
||||
|
||||
def test_normalized_lcs():
    """'abcdef' against 'abcxyz' scores exactly 0.5."""
    score = normalized_lcs('abcdef', 'abcxyz')
    assert score == 0.5
|
||||
|
||||
|
||||
def test_get_top_k_chunk_matches():
    """The exact-match line ranks first; the next-best line ranks second."""
    document = 'chunk1\nchunk2\nchunk3\nchunk4'
    needle = 'chunk2'
    top_two = get_top_k_chunk_matches(document, needle, k=2, max_chunk_size=1)
    assert len(top_two) == 2

    best, runner_up = top_two

    # The line identical to the query is expected first with a perfect score.
    assert best.text == 'chunk2'
    assert best.line_range == (2, 2)
    assert best.normalized_lcs == 1.0

    # The runner-up is expected to be the first line of the document.
    assert runner_up.text == 'chunk1'
    assert runner_up.line_range == (1, 1)
    assert runner_up.normalized_lcs == 5 / 6

    assert best.normalized_lcs > runner_up.normalized_lcs
|
||||
|
||||
|
||||
def test_create_chunks_with_empty_lines():
    """Blank lines count as lines: six lines with size=2 give three two-line chunks."""
    pieces = create_chunks('line1\n\nline3\n\n\nline6', size=2)
    assert len(pieces) == 3

    head, middle, tail = pieces
    assert (head.text, head.line_range) == ('line1\n', (1, 2))
    assert (middle.text, middle.line_range) == ('line3\n', (3, 4))
    assert (tail.text, tail.line_range) == ('\nline6', (5, 6))
|
||||
|
||||
|
||||
def test_create_chunks_with_large_size():
    """A chunk size larger than the line count gives one chunk holding the whole text."""
    whole = 'line1\nline2\nline3'
    result = create_chunks(whole, size=10)
    assert len(result) == 1

    only = result[0]
    assert only.text == whole
    assert only.line_range == (1, 3)
|
||||
|
||||
|
||||
def test_create_chunks_with_last_chunk_smaller():
    """Three lines with size=2 give one full chunk and a shorter final chunk."""
    pieces = create_chunks('line1\nline2\nline3', size=2)
    assert len(pieces) == 2

    full, remainder = pieces
    assert (full.text, full.line_range) == ('line1\nline2', (1, 2))
    assert (remainder.text, remainder.line_range) == ('line3', (3, 3))
|
||||
|
||||
|
||||
def test_normalized_lcs_edge_cases():
    """An empty string on either side scores 0.0; 'abcde' vs 'ace' scores 0.6."""
    # Every pair with at least one empty side is expected to score zero.
    for left, right in (('', ''), ('a', ''), ('', 'a')):
        assert normalized_lcs(left, right) == 0.0

    partial = normalized_lcs('abcde', 'ace')
    assert partial == 0.6
|
||||
|
||||
|
||||
def test_get_top_k_chunk_matches_with_ties():
    """With equally scored lines, the top three cover chunk1, chunk2 and chunk3."""
    document = 'chunk1\nchunk2\nchunk3\nchunk1'
    hits = get_top_k_chunk_matches(document, 'chunk', k=3, max_chunk_size=1)
    assert len(hits) == 3

    # Every returned chunk is expected to carry the same score of 5/6.
    for hit in hits:
        assert hit.normalized_lcs == 5 / 6

    returned_texts = {hit.text for hit in hits}
    assert returned_texts == {'chunk1', 'chunk2', 'chunk3'}
|
||||
|
||||
|
||||
def test_get_top_k_chunk_matches_with_large_k():
    """Asking for more matches than there are chunks returns every chunk."""
    document = 'chunk1\nchunk2\nchunk3'
    hits = get_top_k_chunk_matches(document, 'chunk', k=10, max_chunk_size=1)
    # k=10 exceeds the three one-line chunks, so all three should come back.
    assert len(hits) == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize('chunk_size', [1, 2, 3, 4])
def test_create_chunks_different_sizes(chunk_size):
    """For each chunk size, the chunk count is ceil(4 / size) and no line is lost."""
    pieces = create_chunks('line1\nline2\nline3\nline4', size=chunk_size)

    # Integer form of ceil(4 / chunk_size).
    expected_count = (4 + chunk_size - 1) // chunk_size
    assert len(pieces) == expected_count

    # Counting lines per chunk must add back up to the four input lines.
    total_lines = 0
    for piece in pieces:
        total_lines += len(piece.text.split('\n'))
    assert total_lines == 4
|
||||
|
||||
|
||||
def test_chunk_visualization_with_special_characters():
    """Tabs and carriage returns inside lines are expected to pass through visualize() unchanged."""
    tricky = Chunk(text='line1\nline2\t\nline3\r', line_range=(1, 3))
    rendered = tricky.visualize()
    expected = '1|line1\n2|line2\t\n3|line3\r\n'
    assert rendered == expected
|
||||
|
||||
|
||||
def test_normalized_lcs_with_unicode():
    """A chunk with non-ASCII characters scores strictly between 0 and 1 against a partly matching draft."""
    score = normalized_lcs('Hello, 世界!', 'Hello, world!')
    assert 0 < score
    assert score < 1
|
||||
|
||||
|
||||
def test_get_top_k_chunk_matches_with_overlapping_chunks():
    """A query straddling two chunks matches both equally, in document order.

    With max_chunk_size=2 the expected chunks are lines 1-2 and lines 3-4; the
    query 'chunk2\\nchunk3' overlaps one line of each.
    """
    document = 'chunk1\nchunk2\nchunk3\nchunk4'
    straddling_query = 'chunk2\nchunk3'
    hits = get_top_k_chunk_matches(
        document, straddling_query, k=2, max_chunk_size=2
    )
    assert len(hits) == 2

    upper, lower = hits
    assert upper.text == 'chunk1\nchunk2'
    assert upper.line_range == (1, 2)
    assert lower.text == 'chunk3\nchunk4'
    assert lower.line_range == (3, 4)
    # Both chunks are expected to tie on score.
    assert upper.normalized_lcs == lower.normalized_lcs
|
||||
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -10,8 +11,8 @@ from litellm.exceptions import (
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.exceptions import OperationCancelled
|
||||
from openhands.core.metrics import Metrics
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -39,6 +40,7 @@ def test_llm_init_with_default_config(default_config):
|
||||
assert llm.config.model == 'gpt-4o'
|
||||
assert llm.config.api_key == 'test_key'
|
||||
assert isinstance(llm.metrics, Metrics)
|
||||
assert llm.metrics.model_name == 'gpt-4o'
|
||||
|
||||
|
||||
@patch('openhands.llm.llm.litellm.get_model_info')
|
||||
@ -83,13 +85,18 @@ def test_llm_init_with_metrics():
|
||||
metrics = Metrics()
|
||||
llm = LLM(config, metrics=metrics)
|
||||
assert llm.metrics is metrics
|
||||
assert (
|
||||
llm.metrics.model_name == 'default'
|
||||
) # because we didn't specify model_name in Metrics init
|
||||
|
||||
|
||||
def test_llm_reset():
    """After reset(), the LLM's metrics differ from a pre-reset snapshot that carries a cost."""
    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))

    # Take an independent deep copy as the snapshot. The earlier plain
    # `initial_metrics = llm.metrics` assignment was overwritten immediately,
    # so it has been dropped as a dead store.
    initial_metrics = copy.deepcopy(llm.metrics)
    initial_metrics.add_cost(1.0)
    # NOTE(review): the cost is recorded on the snapshot, not on llm.metrics,
    # so the inequality checks below may hold even without reset() - confirm
    # that this is the intended coverage.

    llm.reset()

    assert llm.metrics is not initial_metrics
    assert llm.metrics._accumulated_cost != initial_metrics._accumulated_cost
    assert llm.metrics._costs != initial_metrics._costs
    assert isinstance(llm.metrics, Metrics)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user