mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Feat: Support Gorilla APIBench (#2081)
* removed unused files from gorilla * Update run_infer.py, removed unused imports * Update utils.py * Update ast_eval_hf.py * Update ast_eval_tf.py * Update ast_eval_th.py * Create README.md * Update run_infer.py * make lint * Update run_infer.py * fix lint --------- Co-authored-by: yufansong <yufan@risingwave-labs.com>
This commit is contained in:
41
evaluation/gorilla/README.md
Normal file
41
evaluation/gorilla/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# Gorilla APIBench Evaluation with OpenDevin
|
||||
|
||||
This folder contains the evaluation harness we built on top of the original [Gorilla APIBench](https://github.com/ShishirPatil/gorilla) ([paper](https://arxiv.org/pdf/2305.15334)).
|
||||
|
||||
## Setup Environment
|
||||
|
||||
Please follow [this document](https://github.com/OpenDevin/OpenDevin/blob/main/Development.md) to set up a local development environment for OpenDevin.
|
||||
|
||||
## Configure OpenDevin and your LLM
|
||||
|
||||
Run `make setup-config` to set up the `config.toml` file if it does not exist at the root of the workspace.
|
||||
|
||||
## Run Inference on APIBench Instances
|
||||
|
||||
Make sure your Docker daemon is running, then run this bash script:
|
||||
|
||||
```bash
|
||||
bash evaluation/gorilla/scripts/run_infer.sh [model_config] [agent] [eval_limit] [hubs]
|
||||
```
|
||||
|
||||
where `model_config` is mandatory, while all other arguments are optional.
|
||||
|
||||
`model_config`, e.g. `llm`, is the config group name for your
|
||||
LLM settings, as defined in your `config.toml`.
|
||||
|
||||
`agent`, e.g. `CodeActAgent`, is the name of the agent for benchmarks, defaulting
|
||||
to `CodeActAgent`.
|
||||
|
||||
`eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances.
|
||||
By default, the script evaluates 1 instance.
|
||||
|
||||
`hubs` selects which APIBench hub(s) to evaluate. Choose one or more of `torch` (or its abbreviation `th`), `hf` (short for huggingface), and `tf` (short for tensorflow), separated by commas. The default is `hf,torch,tf`.
|
||||
|
||||
Note: in order to use `eval_limit`, you must also set `agent`; in order to use `hubs`, you must also set `eval_limit`.
|
||||
|
||||
Let's say you'd like to run 10 instances using `llm` and CodeActAgent on `th` test,
|
||||
then your command would be:
|
||||
|
||||
```bash
|
||||
bash evaluation/gorilla/scripts/run_infer.sh llm CodeActAgent 10 th
|
||||
```
|
||||
127
evaluation/gorilla/ast_eval_hf.py
Normal file
127
evaluation/gorilla/ast_eval_hf.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright 2023 https://github.com/ShishirPatil/gorilla
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py
|
||||
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
# Get all the subtrees given a root_node
|
||||
def get_all_sub_trees(root_node):
    """Collect one record per node that has children (the root is always included).

    The tree is walked depth-first with an explicit stack. Each record is a
    list ``[sexp, depth, node, first_child_text]`` where ``depth`` starts at 1
    for ``root_node`` and ``first_child_text`` is ``None`` when the node has
    no children (only possible for the root, since childless descendants are
    never pushed).
    """
    records = []
    pending = [[root_node, 1]]
    while pending:
        node, level = pending.pop()
        records.append(
            [
                node.sexp(),
                level,
                node,
                node.children[0].text if node.child_count > 0 else None,
            ]
        )
        # Only descend into children that themselves have children.
        pending.extend(
            [child, level + 1] for child in node.children if len(child.children) != 0
        )
    return records
|
||||
|
||||
|
||||
# Parse the program into AST trees
|
||||
def ast_parse(candidate, lang='python'):
    """Parse ``candidate`` source text with tree-sitter and return the root node.

    The grammar is loaded from ``evaluation/gorilla/my-languages.so`` (a path
    relative to the current working directory) on every call.
    """
    grammar = Language('evaluation/gorilla/my-languages.so', lang)
    ts_parser = Parser()
    ts_parser.set_language(grammar)
    return ts_parser.parse(bytes(candidate, 'utf8')).root_node
|
||||
|
||||
|
||||
# Get all the arguments in the ast tree
|
||||
def get_args(node):
    """Return the argument values (as bytes) of the call held by ``node``.

    The argument nodes are read from
    ``node.children[0].children[0].children[1].children``. For any argument
    whose text contains ``=``, the text of its third child is kept; every
    other token except ``(``, ``)`` and ``,`` is kept verbatim. A node without
    children yields an empty list.
    """
    if node.child_count == 0:
        return []
    collected = []
    for arg_node in node.children[0].children[0].children[1].children:
        token = arg_node.text.decode()
        if '=' in token:
            collected.append(arg_node.children[2].text)
        elif token not in ('(', ')', ','):
            collected.append(arg_node.text)
    return collected
|
||||
|
||||
|
||||
# Check if there is an api match
|
||||
def ast_check(candidate_subtree_list, base_tree_list):
    """Return the index of the first reference tree matched by the candidate, or -1.

    ``candidate_subtree_list`` holds records as produced by
    ``get_all_sub_trees``; ``base_tree_list`` holds reference trees. A
    reference matches when every argument from ``get_args`` (with surrounding
    single quotes stripped) occurs as a substring of the selected candidate
    sub-tree's text. References with no API name node or no arguments are
    skipped.

    NOTE(review): when no candidate record carries the reference's API name,
    the search loop simply ends on the last record and that record is used
    for the argument comparison; this mirrors the existing scoring behaviour.
    """
    for position, reference in enumerate(base_tree_list):
        call_node = reference.children[0].children[0]
        if call_node.child_count == 0:
            continue
        api_name = call_node.children[0].text
        for record in candidate_subtree_list:
            if record[3] == api_name:
                break
        # `record` is the matching entry, or the last one if nothing matched.
        candidate_node = record[2]
        expected_args = get_args(reference)
        if not expected_args:
            continue
        if all(
            arg.decode().strip("'") in candidate_node.text.decode()
            for arg in expected_args
        ):
            return position
    return -1
|
||||
|
||||
|
||||
def ast_eval_hf(api_database, qa_pairs, ast_database, question_id, response):
    """Score one model response against the reference API database.

    Args:
        api_database: Reference API entries; the matched entry's ``'domain'``
            is compared with the question's.
        qa_pairs: Question records; entry ``question_id - 1`` supplies the
            expected ``'domain'``.
        ast_database: Parsed reference trees handed to ``ast_check``; the
            index it returns is used to index ``api_database``.
        question_id: 1-based question identifier.
        response: Raw text of the model's answer.

    Returns:
        ``(correct, hallucination)``. ``hallucination`` is True when no
        reference tree matches the response; in that case ``correct`` is
        always False.
    """
    segments = response.split('api_call')
    if len(segments) == 1:
        # No "api_call" marker: treat the whole response as the call text.
        api_call = segments[0]
    else:
        # Keep the text between the "api_call" and "api_provider" markers,
        # then cut from just after the first ':' to the last ')'.
        field = segments[1].split('api_provider')[0]
        start = field.index(':') if ':' in field else 0
        end = field.rindex(')') if ')' in field else -2
        api_call = field[start + 2 : end + 1]

    ast_tree = ast_parse(api_call)
    ast_subtree_list = get_all_sub_trees(ast_tree)
    database_index = ast_check(ast_subtree_list, ast_database)

    # No reference matched: report a hallucination and stop here. Indexing
    # api_database with -1 would silently compare against its last entry.
    if database_index == -1:
        return False, True

    ref_api_call = api_database[database_index]
    correct = False
    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
        correct = True
    return correct, False
|
||||
127
evaluation/gorilla/ast_eval_tf.py
Normal file
127
evaluation/gorilla/ast_eval_tf.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright 2023 https://github.com/ShishirPatil/gorilla
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py
|
||||
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
# Get all the subtrees given a root_node
|
||||
def get_all_sub_trees(root_node):
    """Collect one record per node that has children (the root is always included).

    The tree is walked depth-first with an explicit stack. Each record is a
    list ``[sexp, depth, node, first_child_text]`` where ``depth`` starts at 1
    for ``root_node`` and ``first_child_text`` is ``None`` when the node has
    no children (only possible for the root, since childless descendants are
    never pushed).
    """
    records = []
    pending = [[root_node, 1]]
    while pending:
        node, level = pending.pop()
        records.append(
            [
                node.sexp(),
                level,
                node,
                node.children[0].text if node.child_count > 0 else None,
            ]
        )
        # Only descend into children that themselves have children.
        pending.extend(
            [child, level + 1] for child in node.children if len(child.children) != 0
        )
    return records
|
||||
|
||||
|
||||
# Parse the program into AST trees
|
||||
def ast_parse(candidate, lang='python'):
    """Parse ``candidate`` source text with tree-sitter and return the root node.

    The grammar is loaded from ``evaluation/gorilla/my-languages.so`` (a path
    relative to the current working directory) on every call.
    """
    grammar = Language('evaluation/gorilla/my-languages.so', lang)
    ts_parser = Parser()
    ts_parser.set_language(grammar)
    return ts_parser.parse(bytes(candidate, 'utf8')).root_node
|
||||
|
||||
|
||||
# Get all the arguments in the ast tree
|
||||
def get_args(node):
    """Return the argument values (as bytes) of the call held by ``node``.

    The argument nodes are read from
    ``node.children[0].children[0].children[1].children``. For an argument
    whose text contains ``model=`` or ``model =``, the text of its third child
    is kept; every other token except ``(``, ``)`` and ``,`` is kept verbatim.
    A node without children yields an empty list.
    """
    if node.child_count == 0:
        return []
    collected = []
    for arg_node in node.children[0].children[0].children[1].children:
        token = arg_node.text.decode()
        if 'model=' in token or 'model =' in token:
            collected.append(arg_node.children[2].text)
        elif token not in ('(', ')', ','):
            collected.append(arg_node.text)
    return collected
|
||||
|
||||
|
||||
# Check if there is an api match
|
||||
def ast_check(candidate_subtree_list, base_tree_list):
    """Return the index of the first reference tree matched by the candidate, or -1.

    ``candidate_subtree_list`` holds records as produced by
    ``get_all_sub_trees``; ``base_tree_list`` holds reference trees. A
    reference matches when every argument from ``get_args`` (with surrounding
    single quotes stripped) occurs as a substring of the selected candidate
    sub-tree's text. References with no API name node or no arguments are
    skipped.

    NOTE(review): when no candidate record carries the reference's API name,
    the search loop simply ends on the last record and that record is used
    for the argument comparison; this mirrors the existing scoring behaviour.
    """
    for position, reference in enumerate(base_tree_list):
        call_node = reference.children[0].children[0]
        if call_node.child_count == 0:
            continue
        api_name = call_node.children[0].text
        for record in candidate_subtree_list:
            if record[3] == api_name:
                break
        # `record` is the matching entry, or the last one if nothing matched.
        candidate_node = record[2]
        expected_args = get_args(reference)
        if not expected_args:
            continue
        if all(
            arg.decode().strip("'") in candidate_node.text.decode()
            for arg in expected_args
        ):
            return position
    return -1
|
||||
|
||||
|
||||
def ast_eval_tf(api_database, qa_pairs, ast_database, question_id, response):
    """Score one model response against the reference API database.

    Args:
        api_database: Reference API entries; the matched entry's ``'domain'``
            is compared with the question's.
        qa_pairs: Question records; entry ``question_id - 1`` supplies the
            expected ``'domain'``.
        ast_database: Parsed reference trees handed to ``ast_check``; the
            index it returns is used to index ``api_database``.
        question_id: 1-based question identifier.
        response: Raw text of the model's answer.

    Returns:
        ``(correct, hallucination)``. ``hallucination`` is True when no
        reference tree matches the response; in that case ``correct`` is
        always False.
    """
    segments = response.split('api_call')
    if len(segments) == 1:
        # No "api_call" marker: treat the whole response as the call text.
        api_call = segments[0]
    else:
        # Keep the text between the "api_call" and "api_provider" markers,
        # then cut from just after the first ':' to the last ')'.
        field = segments[1].split('api_provider')[0]
        start = field.index(':') if ':' in field else 0
        end = field.rindex(')') if ')' in field else -2
        api_call = field[start + 2 : end + 1]

    ast_tree = ast_parse(api_call)
    ast_subtree_list = get_all_sub_trees(ast_tree)
    database_index = ast_check(ast_subtree_list, ast_database)

    # No reference matched: report a hallucination and stop here. Indexing
    # api_database with -1 would silently compare against its last entry.
    if database_index == -1:
        return False, True

    ref_api_call = api_database[database_index]
    correct = False
    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
        correct = True
    return correct, False
|
||||
123
evaluation/gorilla/ast_eval_th.py
Normal file
123
evaluation/gorilla/ast_eval_th.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright 2023 https://github.com/ShishirPatil/gorilla
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_th.py
|
||||
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
# Get all the subtrees given a root_node
|
||||
def get_all_sub_trees(root_node):
    """Collect one record per node that has children (the root is always included).

    The tree is walked depth-first with an explicit stack. Each record is a
    list ``[sexp, depth, node, first_child_text]`` where ``depth`` starts at 1
    for ``root_node`` and ``first_child_text`` is ``None`` when the node has
    no children (only possible for the root, since childless descendants are
    never pushed).
    """
    records = []
    pending = [[root_node, 1]]
    while pending:
        node, level = pending.pop()
        records.append(
            [
                node.sexp(),
                level,
                node,
                node.children[0].text if node.child_count > 0 else None,
            ]
        )
        # Only descend into children that themselves have children.
        pending.extend(
            [child, level + 1] for child in node.children if len(child.children) != 0
        )
    return records
|
||||
|
||||
|
||||
# Parse the program into AST trees
|
||||
def ast_parse(candidate, lang='python'):
    """Parse ``candidate`` source text with tree-sitter and return the root node.

    The grammar is loaded from ``evaluation/gorilla/my-languages.so`` (a path
    relative to the current working directory) on every call.
    """
    grammar = Language('evaluation/gorilla/my-languages.so', lang)
    ts_parser = Parser()
    ts_parser.set_language(grammar)
    return ts_parser.parse(bytes(candidate, 'utf8')).root_node
|
||||
|
||||
|
||||
# Get all the arguments in the ast tree
|
||||
def get_args(node):
    """Return the values (as bytes) of the ``repo_or_dir`` / ``model`` arguments.

    The argument nodes are read from
    ``node.children[0].children[0].children[1].children``. Only arguments
    whose text contains ``repo_or_dir`` or ``model`` contribute, and for those
    the text of the third child is returned. A node without children yields
    an empty list.
    """
    if node.child_count == 0:
        return []
    argument_nodes = node.children[0].children[0].children[1].children
    return [
        arg.children[2].text
        for arg in argument_nodes
        if any(key in arg.text.decode() for key in ('repo_or_dir', 'model'))
    ]
|
||||
|
||||
|
||||
# Check if there is an api match
|
||||
def ast_check(candidate_subtree_list, base_tree_list):
    """Return the index of the first reference tree matched by the candidate, or -1.

    ``candidate_subtree_list`` holds records as produced by
    ``get_all_sub_trees``; ``base_tree_list`` holds reference trees. A
    reference matches when every argument from ``get_args`` (with surrounding
    single quotes stripped) occurs as a substring of the selected candidate
    sub-tree's text. References with no API name node or no arguments are
    skipped.

    NOTE(review): when no candidate record carries the reference's API name,
    the search loop simply ends on the last record and that record is used
    for the argument comparison; this mirrors the existing scoring behaviour.
    """
    for position, reference in enumerate(base_tree_list):
        call_node = reference.children[0].children[0]
        if call_node.child_count == 0:
            continue
        api_name = call_node.children[0].text
        for record in candidate_subtree_list:
            if record[3] == api_name:
                break
        # `record` is the matching entry, or the last one if nothing matched.
        candidate_node = record[2]
        expected_args = get_args(reference)
        if not expected_args:
            continue
        if all(
            arg.decode().strip("'") in candidate_node.text.decode()
            for arg in expected_args
        ):
            return position
    return -1
|
||||
|
||||
|
||||
def process_response(question_id, output, api_database, qa_pairs, ast_database):
    """Score one raw model answer and return ``(correct, hallucination)``.

    * No ``api_call`` marker in ``output``  -> ``(False, False)``.
    * No reference tree matches the call    -> ``(False, True)``.
    * Otherwise ``correct`` tells whether the matched reference's ``'domain'``
      equals the domain of question ``question_id`` (1-based) in ``qa_pairs``.
    """
    segments = output.split('api_call')
    if len(segments) == 1:
        return False, False

    # Text between the "api_call" and "api_provider" markers, trimmed from
    # just after the first ':' to the last ')'.
    field = segments[1].split('api_provider')[0]
    begin = max(field.find(':'), 0)
    finish = field.rindex(')') if ')' in field else -2
    api_call = field[begin + 2 : finish + 1]

    subtrees = get_all_sub_trees(ast_parse(api_call))
    database_index = ast_check(subtrees, ast_database)
    if database_index == -1:
        # Nothing in the reference database matches this call.
        return False, True

    expected_domain = qa_pairs[question_id - 1]['domain']
    if api_database[database_index]['domain'] == expected_domain:
        return True, False
    return False, False
|
||||
|
||||
|
||||
def ast_eval_th(api_database, qa_pairs, ast_database, question_id, response):
    """Evaluate ``response`` by delegating to ``process_response``.

    Returns the ``(correct, hallucination)`` pair it produces.
    """
    return process_response(
        question_id=question_id,
        output=response,
        api_database=api_database,
        qa_pairs=qa_pairs,
        ast_database=ast_database,
    )
|
||||
355
evaluation/gorilla/run_infer.py
Normal file
355
evaluation/gorilla/run_infer.py
Normal file
@@ -0,0 +1,355 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
from tqdm import tqdm
|
||||
from utils import encode_question, get_data
|
||||
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import config, get_llm_config_arg, get_parser
|
||||
from opendevin.core.logger import get_console_handler
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.main import main
|
||||
from opendevin.events.action import MessageAction
|
||||
from opendevin.events.serialization.event import event_to_dict
|
||||
|
||||
|
||||
def cleanup():
    """Terminate and join every live child process of the current process."""
    print('Cleaning up child processes...')
    for child in mp.active_children():
        print(f'Terminating child process: {child.name}')
        child.terminate()
        child.join()
|
||||
|
||||
|
||||
def codeact_user_response(state: State) -> str:
    """Build the canned "user" reply for CodeActAgent.

    The reply always asks the agent to run ``exit``. Once ``state.history``
    contains at least two user ``MessageAction``s, a second sentence reminding
    the agent how to give up is appended.
    """
    exit_prompt = (
        'Please run the following command: <execute_bash> exit </execute_bash>.\n'
    )
    if not state.history:
        return exit_prompt

    user_turns = sum(
        1
        for action, _ in state.history
        if isinstance(action, MessageAction) and action.source == 'user'
    )
    if user_turns < 2:
        return exit_prompt
    return (
        exit_prompt
        + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
    )
|
||||
|
||||
|
||||
def monologue_user_response(state: State) -> str:
    """Placeholder reply function for MonologueAgent; always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = 'MonologueAgent should never ask for user responses.'
    raise NotImplementedError(reason)
|
||||
|
||||
|
||||
# Agent class name -> function that produces the simulated user reply; the
# selected function is passed to `main` as `fake_user_response_fn`.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = dict(
    CodeActAgent=codeact_user_response,
    MonologueAgent=monologue_user_response,
)

# Agent class name -> extra text appended to the end of the task instruction.
AGENT_CLS_TO_INST_SUFFIX = dict(
    CodeActAgent='When you think you have completed the request, please run the following command: <execute_bash> exit </execute_bash>.\n',
)
|
||||
|
||||
|
||||
def process_instance(
    question_id, question, agent_class, metadata, reset_logger: bool = True
):
    """Run the agent on one APIBench question and score its final message.

    Args:
        question_id: Identifier of the question; used for the log file name
            and passed to the evaluator.
        question: Question text, encoded into the instruction via
            ``encode_question``.
        agent_class: Agent class name used to pick the instruction suffix and
            the fake-user-response function.
        metadata: Run metadata; ``'eval_output_dir'``, ``'hub'`` and
            ``'model_name'`` are read here.
        reset_logger: When True, replace the logger's handlers so this
            instance logs to its own file under ``<eval_output_dir>/logs``.

    Returns:
        A dict with the answer text, the correctness / hallucination flags,
        the serialized history, metrics and any error recorded on the state.

    Raises:
        ValueError: if the agent run returns no state.
    """
    # The global config is mutated below; remember the shared mount path so
    # the `finally` clause can put it back.
    original_mount_path = config.workspace_mount_path
    try:
        # One workspace directory per worker process (keyed by PID) so that
        # concurrently running agents do not interfere with each other.
        workspace_mount_path = os.path.join(
            config.workspace_mount_path, '_eval_workspace', str(os.getpid())
        )
        pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
        config.workspace_mount_path = workspace_mount_path

        # Per-instance logging, so evaluation can run in parallel processes.
        eval_output_dir = metadata['eval_output_dir']
        if reset_logger:
            log_file = os.path.join(
                eval_output_dir, 'logs', f'instance_{question_id}.log'
            )
            # Drop whatever handlers exist, then print ONE line on the console
            # pointing at the per-instance log file.
            for existing in list(logger.handlers):
                logger.removeHandler(existing)
            logger.addHandler(get_console_handler())
            logger.info(
                f'Starting evaluation for instance {question_id}.\nLOG: tail -f {log_file}'
            )
            # From here on, log only to the per-instance file.
            for existing in list(logger.handlers):
                logger.removeHandler(existing)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            )
            logger.addHandler(file_handler)
        logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')

        # Build the instruction: encoded question + fixed rule + agent suffix.
        instruction = encode_question(question, metadata['hub'])
        instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
        instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')

        # Run the agent to completion and obtain the final task state.
        fake_user_fn = AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(agent_class)
        state: State = asyncio.run(
            main(instruction, fake_user_response_fn=fake_user_fn)
        )
        if state is None:
            raise ValueError('State should not be None.')

        # The answer is the most recent message sent by the agent ('' if none).
        model_answer_raw = next(
            (
                act.content
                for act, _ in reversed(state.history)
                if isinstance(act, MessageAction) and act.source == 'agent'
            ),
            '',
        )

        # The third value returned by get_data is the evaluator for this hub.
        _, _, ast_eval = get_data(metadata['hub'])
        correct, hallucination = ast_eval(question_id, model_answer_raw)
        metrics = state.metrics.get() if state.metrics else None
        logger.info(
            f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
        )

        serialized_history = [
            (event_to_dict(action), event_to_dict(obs))
            for action, obs in state.history
        ]
        output = {
            'question_id': question_id,
            'text': model_answer_raw,
            'correct': correct,
            'hallucination': hallucination,
            'answer_id': 'None',
            'model_id': metadata['model_name'],
            'metadata': metadata,
            'history': serialized_history,
            'metrics': metrics,
            'error': state.error if state and state.error else None,
        }
    except Exception:
        logger.error('Process instance failed')
        raise
    finally:
        config.workspace_mount_path = original_mount_path
    return output
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Local import: only the script entry point needs it.
    from functools import partial

    parser = get_parser()
    parser.add_argument(
        '--hubs',
        type=str,
        help='Which hubs to evaluate from APIBench. APIBench contains 3 hubs, namely huggingface, torch, and tensorflow. You could choose one or more from hf, torch, or tf, seperated by commas. For example, the default is --hub hf,torch,tf.',
        default='hf,torch,tf',
    )
    args, _ = parser.parse_known_args()
    if args.directory:
        config.workspace_base = os.path.abspath(args.directory)
        print(f'Setting workspace base to {config.workspace_base}')

    # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
    # for details of how to set `llm_config`
    if args.llm_config:
        specified_llm_config = get_llm_config_arg(args.llm_config)
        if specified_llm_config:
            config.llm = specified_llm_config
    logger.info(f'Config for evaluation: {config}')

    agent_class = args.agent_cls
    assert (
        agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
    ), f'Unsupported agent class: {agent_class}'
    model_name = config.llm.model.split('/')[-1]
    max_iterations = args.max_iterations
    eval_note = ''
    if args.eval_note is not None:
        eval_note += '_N_' + args.eval_note
    eval_output_dir = os.path.join(
        args.eval_output_dir,
        'gorilla',
        agent_class,
        model_name + '_maxiter_' + str(max_iterations) + eval_note,
    )
    pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
        parents=True, exist_ok=True
    )
    logger.info(f'Using evaluation output directory: {eval_output_dir}')

    # Resolve the requested hubs ('th' is accepted as an alias for 'torch').
    hubs = []
    if 'hf' in args.hubs:
        hubs.append('hf')
    if 'torch' in args.hubs or 'th' in args.hubs:
        hubs.append('torch')
    if 'tf' in args.hubs:
        hubs.append('tf')
    if not hubs:
        raise ValueError('Please choose at least one from hf, torch, and tf for hubs.')

    def update_progress(future, pbar, output_fp, finished_task_ids):
        """Advance the progress bar and append one finished result to the JSONL file."""
        pbar.update(1)
        output = future.result()
        pbar.set_description(f'Instance {output["question_id"]}')
        pbar.set_postfix_str(f'Test Result: {output["correct"]}')
        logger.info(
            f'Finished evaluation for instance {output["question_id"]}: {output["correct"]}'
        )
        output_fp.write(json.dumps(output) + '\n')
        output_fp.flush()
        finished_task_ids.add(output['question_id'])

    for hub in hubs:
        logger.info(f'Evaluating APIBench {hub} test')
        questions, question_ids, ast_eval = get_data(hub)

        # TEST METADATA
        metadata = {
            'hub': hub,
            'agent_class': agent_class,
            'model_name': model_name,
            'max_iterations': max_iterations,
            'eval_output_dir': eval_output_dir,
            'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
            # get the commit id of current repo for reproducibility
            'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
            .decode('utf-8')
            .strip(),
        }
        logger.info(f'Metadata: {metadata}')
        with open(os.path.join(eval_output_dir, f'metadata_{hub}.json'), 'w') as f:
            json.dump(metadata, f)

        # LIMIT EVALUATION: the total limit is split evenly across the hubs.
        eval_n_limit = args.eval_n_limit
        if eval_n_limit:
            per_hub_limit = eval_n_limit // len(hubs)
            questions = questions[:per_hub_limit]
            question_ids = question_ids[:per_hub_limit]
            logger.info(
                f'Limiting evaluation to a total of first {eval_n_limit} instances -> first {per_hub_limit} instances per hub.'
            )

        output_file = os.path.join(eval_output_dir, f'output_{model_name}_{hub}.jsonl')
        logger.info(f'Writing evaluation output to {output_file}')

        # Resume support: collect ids already present in the output file.
        finished_task_ids = set()
        if os.path.exists(output_file):
            selected_ids = set(question_ids)
            with open(output_file, 'r') as f:
                for line in f:
                    data = json.loads(line)
                    finished_id = int(data['question_id'])
                    if finished_id in selected_ids:
                        finished_task_ids.add(finished_id)
            logger.warning(
                f'Output file {output_file} already exists. Loaded {len(finished_task_ids)} finished instances.'
            )

        logger.info(
            f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
        )

        # =============================================
        # filter out finished instances
        new_questions = []
        new_question_ids = []
        for question, question_id in zip(questions, question_ids):
            if question_id in finished_task_ids:
                logger.info(
                    f'Skipping instance {question_id} as it is already finished.'
                )
                continue
            new_questions.append(question)
            new_question_ids.append(question_id)

        finished_task_number = len(finished_task_ids)
        questions = new_questions
        question_ids = new_question_ids
        logger.info(
            f'Finished instances: {finished_task_number}, Remaining instances: {len(question_ids)}'
        )
        # =============================================

        pbar = tqdm(total=len(question_ids))

        # This sets the multi-processing
        num_workers = args.eval_num_workers
        logger.info(f'Using {num_workers} workers for evaluation.')

        # `with` guarantees the results file is closed even on failure.
        with open(output_file, 'a') as output_fp:
            # Future.add_done_callback accepts only the callable, so the extra
            # arguments are bound up front with functools.partial.
            on_done = partial(
                update_progress,
                pbar=pbar,
                output_fp=output_fp,
                finished_task_ids=finished_task_ids,
            )
            try:
                with ProcessPoolExecutor(num_workers) as executor:
                    futures = []
                    # This is how we perform multi-processing
                    for question_id, question in zip(question_ids, questions):
                        try:
                            future = executor.submit(
                                process_instance,
                                question_id,
                                question,
                                agent_class,
                                metadata,
                                reset_logger=bool(num_workers > 1),
                            )
                            future.add_done_callback(on_done)
                            futures.append(future)
                        except Exception:
                            # Keep going with the other instances, but make
                            # the failure visible instead of hiding it.
                            logger.exception(
                                f'Failed to schedule instance {question_id}'
                            )
                            continue

                    # Wait for all futures to complete
                    for future in futures:
                        try:
                            future.result()
                        except Exception:
                            logger.exception('Instance evaluation failed')
                            continue
            except KeyboardInterrupt:
                logger.info('KeyboardInterrupt received. Cleaning up...')
                cleanup()

        # Re-read everything written so far, tally the results and rewrite
        # the file sorted by question id.
        total_correct = 0
        total_hallucination = 0
        records = []
        with open(output_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                records.append(data)
                if int(data['question_id']) in finished_task_ids:
                    if str(data['correct']).lower() == 'true':
                        total_correct += 1
                    if str(data['hallucination']).lower() == 'true':
                        total_hallucination += 1
        # sort all output by question_id
        records = sorted(records, key=lambda x: x['question_id'])
        with open(output_file, 'w') as f:
            for dat in records:
                f.write(json.dumps(dat) + '\n')
                f.flush()

        # The total can be 0 (e.g. when the per-hub limit rounds down to 0);
        # report 0.0 accuracy instead of dividing by zero.
        total = len(question_ids) + finished_task_number
        accuracy = total_correct / total if total else 0.0
        logger.info(
            f'Evaluation finished for {hub}. Total: {total}; Correct: {total_correct}; Hallucination: {total_hallucination}. Accuracy: {accuracy}'
        )
|
||||
42
evaluation/gorilla/scripts/run_infer.sh
Normal file
42
evaluation/gorilla/scripts/run_infer.sh
Normal file
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
# Run Gorilla APIBench inference.
# Usage: run_infer.sh <model_config> [agent] [eval_limit] [hubs]

MODEL_CONFIG=$1
AGENT=$2
EVAL_LIMIT=$3
HUBS=$4

# model_config is mandatory; fail early with a usage hint.
if [ -z "$MODEL_CONFIG" ]; then
  echo "Usage: $0 <model_config> [agent] [eval_limit] [hubs]" >&2
  exit 1
fi

if [ -z "$AGENT" ]; then
  echo "Agent not specified, use default CodeActAgent"
  AGENT="CodeActAgent"
fi

if [ -z "$HUBS" ]; then
  HUBS="hf,torch,tf"
  echo "Hubs not specified, use default $HUBS"
fi

# IMPORTANT: Because Agent's prompt changes fairly often in the rapidly evolving codebase of OpenDevin
# We need to track the version of Agent in the evaluation to make sure results are comparable
AGENT_VERSION=v$(poetry run python -c "import agenthub; from opendevin.controller.agent import Agent; print(Agent.get_cls('$AGENT').VERSION)")

echo "AGENT: $AGENT"
echo "AGENT_VERSION: $AGENT_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"
echo "HUBS: $HUBS"

# Build the command as an array so arguments are passed through verbatim
# (no `eval`, no word-splitting surprises).
# The eval note is just the agent version: the script previously appended an
# undefined $LEVELS variable, which only produced a trailing underscore.
COMMAND=(
  poetry run python evaluation/gorilla/run_infer.py
  --agent-cls "$AGENT"
  --llm-config "$MODEL_CONFIG"
  --max-iterations 30
  --hubs "$HUBS"
  --data-split validation
  --max-chars 10000000
  --eval-num-workers 1
  --eval-note "$AGENT_VERSION"
)

if [ -n "$EVAL_LIMIT" ]; then
  echo "EVAL_LIMIT: $EVAL_LIMIT"
  COMMAND+=(--eval-n-limit "$EVAL_LIMIT")
fi

# Run the command
"${COMMAND[@]}"
|
||||
101
evaluation/gorilla/utils.py
Normal file
101
evaluation/gorilla/utils.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
import requests
|
||||
from ast_eval_hf import ast_eval_hf, ast_parse
|
||||
from ast_eval_tf import ast_eval_tf
|
||||
from ast_eval_th import ast_eval_th
|
||||
|
||||
|
||||
# This function is modified from Gorilla's APIBench implementations (https://github.com/ShishirPatil/gorilla/blob/main/eval/get_llm_responses.py).
|
||||
def encode_question(question, api_name):
    """Encode multiple prompt instructions into a single string.

    Args:
        question: The task description to place at the top of the prompt.
        api_name: Hub key selecting the API family and its list of allowed
            domains. One of ``'torch'``, ``'hf'`` or ``'tf'``.

    Returns:
        The full prompt string: a fixed system-style sentence, the question,
        the required answer format, and the numbered requirements.

    Raises:
        ValueError: If ``api_name`` is not one of the supported keys.
    """
    # The prompt text below is kept unchanged (including its spelling) so
    # that prompts stay comparable across runs.
    if api_name == 'torch':
        api_name = 'torchhub'
        domains = '1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}.'
    elif api_name == 'hf':
        api_name = 'huggingface'
        # NOTE(review): the original literal used backslash line
        # continuations; any leading whitespace on the continued lines is not
        # visible here, so the pieces are joined exactly as shown. Confirm
        # against the checked-in file if exact prompt bytes matter.
        domains = (
            '1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, '
            'Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation,'
            'Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, '
            'Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shor Image Classification, '
            'Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, '
            'Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, '
            'Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask,'
            'Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, '
            'Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, '
            'Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics }'
        )
    elif api_name == 'tf':
        api_name = 'tensorhub'
        domains = '1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancemenmt, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}'
    else:
        # Previously this branch only printed a message and then crashed with
        # an UnboundLocalError on `domains`; fail with a clear error instead.
        raise ValueError(f'Error: API name is not supported: {api_name!r}')

    prompt = (
        question
        + '\nWrite a python program in 1 to 2 lines to call API in '
        + api_name
        + '.\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n'
        + domains
        + '\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer.'
    )
    return (
        'You are a helpful API writer who can write APIs based on requirements.\n'
        + prompt
    )
|
||||
|
||||
|
||||
def get_data(hub):
|
||||
if hub == 'hf':
|
||||
question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl'
|
||||
api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl'
|
||||
apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json'
|
||||
ast_eval = ast_eval_hf
|
||||
if hub == 'torch':
|
||||
question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl'
|
||||
api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl'
|
||||
apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json'
|
||||
ast_eval = ast_eval_th
|
||||
if hub == 'tf':
|
||||
question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl'
|
||||
api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl'
|
||||
apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json'
|
||||
ast_eval = ast_eval_tf
|
||||
|
||||
# get questions and question_ids
|
||||
questions = []
|
||||
question_ids = []
|
||||
question_data = requests.get(question_data)
|
||||
if question_data.status_code == 200:
|
||||
lines = question_data.text.splitlines()
|
||||
for line in lines:
|
||||
questions.append(json.loads(line)['text'])
|
||||
question_ids.append(json.loads(line)['question_id'])
|
||||
|
||||
# get the api datasest
|
||||
api_database = []
|
||||
api_dataset = requests.get(api_dataset)
|
||||
if api_dataset.status_code == 200:
|
||||
lines = api_dataset.text.splitlines()
|
||||
for line in lines:
|
||||
api_database.append(json.loads(line))
|
||||
|
||||
# get the question answer pair datasest
|
||||
qa_pairs = []
|
||||
apibench = requests.get(apibench)
|
||||
if apibench.status_code == 200:
|
||||
lines = apibench.text.splitlines()
|
||||
for line in lines:
|
||||
qa_pairs.append(json.loads(line)['api_data'])
|
||||
|
||||
# Parse all apis to ast trees
|
||||
ast_database = []
|
||||
for data in api_database:
|
||||
ast_tree = ast_parse(data['api_call'])
|
||||
ast_database.append(ast_tree)
|
||||
ast_eval = partial(ast_eval, api_database, qa_pairs, ast_database)
|
||||
return questions, question_ids, ast_eval
|
||||
Reference in New Issue
Block a user