Feat: Support Gorilla APIBench (#2081)

* removed unused files from gorilla

* Update run_infer.py, removed unused imports

* Update utils.py

* Update ast_eval_hf.py

* Update ast_eval_tf.py

* Update ast_eval_th.py

* Create README.md

* Update run_infer.py

* make lint

* Update run_infer.py

* fix lint

---------

Co-authored-by: yufansong <yufan@risingwave-labs.com>
This commit is contained in:
yueqis
2024-06-08 12:54:54 -04:00
committed by GitHub
parent b5a17efc45
commit 68d9ad61cf
7 changed files with 916 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
# Gorilla APIBench Evaluation with OpenDevin
This folder contains evaluation harness we built on top of the original [Gorilla APIBench](https://github.com/ShishirPatil/gorilla) ([paper](https://arxiv.org/pdf/2305.15334)).
## Setup Environment
Please follow [this document](https://github.com/OpenDevin/OpenDevin/blob/main/Development.md) to setup local development environment for OpenDevin.
## Configure OpenDevin and your LLM
Run `make setup-config` to set up the `config.toml` file if it does not exist at the root of the workspace.
## Run Inference on APIBench Instances
Make sure your Docker daemon is running, then run this bash script:
```bash
bash evaluation/gorilla/scripts/run_infer.sh [model_config] [agent] [eval_limit] [hubs]
```
where `model_config` is mandatory, while all other arguments are optional.
`model_config`, e.g. `llm`, is the config group name for your
LLM settings, as defined in your `config.toml`.
`agent`, e.g. `CodeActAgent`, is the name of the agent for benchmarks, defaulting
to `CodeActAgent`.
`eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances.
By default, the script evaluates 1 instance.
`hubs`, the hub from APIBench to evaluate from. You could choose one or more from `torch` or `th` (which is abbreviation of torch), `hf` (which is abbreviation of huggingface), and `tf` (which is abbreviation of tensorflow), for `hubs`. The default is `hf,torch,tf`.
Note: in order to use `eval_limit`, you must also set `agent`; in order to use `hubs`, you must also set `eval_limit`.
Let's say you'd like to run 10 instances using `llm` and CodeActAgent on `th` test,
then your command would be:
```bash
bash evaluation/gorilla/scripts/run_infer.sh llm CodeActAgent 10 th
```

View File

@@ -0,0 +1,127 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modifed from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py
from tree_sitter import Language, Parser
# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
node_stack = []
sub_tree_sexp_list = []
depth = 1
# text = root_node.text
node_stack.append([root_node, depth])
while len(node_stack) != 0:
cur_node, cur_depth = node_stack.pop()
if cur_node.child_count > 0:
sub_tree_sexp_list.append(
[cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
)
else:
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
for child_node in cur_node.children:
if len(child_node.children) != 0:
depth = cur_depth + 1
node_stack.append([child_node, depth])
return sub_tree_sexp_list
# Parse the program into AST trees
def ast_parse(candidate, lang='python'):
LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
parser = Parser()
parser.set_language(LANGUAGE)
candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
return candidate_tree
# Get all the arguments in the ast tree
def get_args(node):
if node.child_count == 0:
return []
args_list = []
for child in node.children[0].children[0].children[1].children:
if '=' in child.text.decode():
args_list.append(child.children[2].text)
elif (
child.text.decode() != '('
and child.text.decode() != ')'
and child.text.decode() != ','
):
args_list.append(child.text)
return args_list
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list):
for idx, base_tree in enumerate(base_tree_list):
if base_tree.children[0].children[0].child_count == 0:
continue
api_name = base_tree.children[0].children[0].children[0].text
for candidate_tree in candidate_subtree_list:
if candidate_tree[3] == api_name:
break
# Now we have a sub-tree
candidate_tree = candidate_tree[2]
args_list = get_args(base_tree)
if len(args_list) == 0:
continue
ast_match = True
for arg in args_list:
if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
ast_match = False
break
if ast_match:
return idx
return -1
def ast_eval_hf(api_database, qa_pairs, ast_database, question_id, response):
# Check correctness
correct = False
hallucination = False
output = response
# Index the "api_call" domain
output = output.split('api_call')
if len(output) == 1:
api_call = output[0]
else:
# Parse the output
output = output[1].split('api_provider')[0]
if ':' not in output:
start = 0
else:
start = output.index(':')
if ')' not in output:
end = -2
else:
end = output.rindex(')')
api_call = output[start + 2 : end + 1]
# Parse the api_call into AST tree
ast_tree = ast_parse(api_call)
# Search for a subtree
ast_subtree_list = get_all_sub_trees(ast_tree)
# Check which ast tree is matching
database_index = ast_check(ast_subtree_list, ast_database)
# We cannot index this ast in our database
if database_index == -1:
hallucination = True
# We index our reference api_call
ref_api_call = api_database[database_index]
# Check for functionality
if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
correct = True
return correct, hallucination

View File

@@ -0,0 +1,127 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modifed from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py
from tree_sitter import Language, Parser
# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
node_stack = []
sub_tree_sexp_list = []
depth = 1
# text = root_node.text
node_stack.append([root_node, depth])
while len(node_stack) != 0:
cur_node, cur_depth = node_stack.pop()
if cur_node.child_count > 0:
sub_tree_sexp_list.append(
[cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
)
else:
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
for child_node in cur_node.children:
if len(child_node.children) != 0:
depth = cur_depth + 1
node_stack.append([child_node, depth])
return sub_tree_sexp_list
# Parse the program into AST trees
def ast_parse(candidate, lang='python'):
LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
parser = Parser()
parser.set_language(LANGUAGE)
candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
return candidate_tree
# Get all the arguments in the ast tree
def get_args(node):
if node.child_count == 0:
return []
args_list = []
for child in node.children[0].children[0].children[1].children:
if 'model=' in child.text.decode() or 'model =' in child.text.decode():
args_list.append(child.children[2].text)
elif (
child.text.decode() != '('
and child.text.decode() != ')'
and child.text.decode() != ','
):
args_list.append(child.text)
return args_list
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list):
for idx, base_tree in enumerate(base_tree_list):
if base_tree.children[0].children[0].child_count == 0:
continue
api_name = base_tree.children[0].children[0].children[0].text
for candidate_tree in candidate_subtree_list:
if candidate_tree[3] == api_name:
break
# Now we have a sub-tree
candidate_tree = candidate_tree[2]
args_list = get_args(base_tree)
if len(args_list) == 0:
continue
ast_match = True
for arg in args_list:
if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
ast_match = False
break
if ast_match:
return idx
return -1
def ast_eval_tf(api_database, qa_pairs, ast_database, question_id, response):
# Check correctness
correct = False
hallucination = False
output = response
# Index the "api_call" domain
output = output.split('api_call')
if len(output) == 1:
api_call = output[0]
else:
# Parse the output
output = output[1].split('api_provider')[0]
if ':' not in output:
start = 0
else:
start = output.index(':')
if ')' not in output:
end = -2
else:
end = output.rindex(')')
api_call = output[start + 2 : end + 1]
# Parse the api_call into AST tree
ast_tree = ast_parse(api_call)
# Search for a subtree
ast_subtree_list = get_all_sub_trees(ast_tree)
# Check which ast tree is matching
database_index = ast_check(ast_subtree_list, ast_database)
# We cannot index this ast in our database
if database_index == -1:
hallucination = True
# We index our reference api_call
ref_api_call = api_database[database_index]
# Check for functionality
if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
correct = True
return correct, hallucination

View File

@@ -0,0 +1,123 @@
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modifed from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_th.py
from tree_sitter import Language, Parser
# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
node_stack = []
sub_tree_sexp_list = []
depth = 1
# text = root_node.text
node_stack.append([root_node, depth])
while len(node_stack) != 0:
cur_node, cur_depth = node_stack.pop()
if cur_node.child_count > 0:
sub_tree_sexp_list.append(
[cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
)
else:
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
for child_node in cur_node.children:
if len(child_node.children) != 0:
depth = cur_depth + 1
node_stack.append([child_node, depth])
return sub_tree_sexp_list
# Parse the program into AST trees
def ast_parse(candidate, lang='python'):
LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
parser = Parser()
parser.set_language(LANGUAGE)
candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
return candidate_tree
# Get all the arguments in the ast tree
def get_args(node):
if node.child_count == 0:
return []
args_list = []
for child in node.children[0].children[0].children[1].children:
if 'repo_or_dir' in child.text.decode() or 'model' in child.text.decode():
args_list.append(child.children[2].text)
return args_list
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list):
for idx, base_tree in enumerate(base_tree_list):
if base_tree.children[0].children[0].child_count == 0:
continue
api_name = base_tree.children[0].children[0].children[0].text
for candidate_tree in candidate_subtree_list:
if candidate_tree[3] == api_name:
break
# Now we have a sub-tree
candidate_tree = candidate_tree[2]
args_list = get_args(base_tree)
if len(args_list) == 0:
continue
ast_match = True
for arg in args_list:
if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
ast_match = False
break
if ast_match:
return idx
return -1
def process_response(question_id, output, api_database, qa_pairs, ast_database):
# Index the "api_call" domain
output = output.split('api_call')
if len(output) == 1:
return False, False
else:
output = output[1].split('api_provider')[0]
if ':' not in output:
start = 0
else:
start = output.index(':')
if ')' not in output:
end = -2
else:
end = output.rindex(')')
api_call = output[start + 2 : end + 1]
# Parse the api_call into AST tree
ast_tree = ast_parse(api_call)
# Search for a subtree
ast_subtree_list = get_all_sub_trees(ast_tree)
# Check which ast tree is matching
database_index = ast_check(ast_subtree_list, ast_database)
# We cannot index this ast in our database
if database_index == -1:
return False, True
# We index our reference api_call
ref_api_call = api_database[database_index]
# Check for functionality
if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
return True, False
else:
return False, False
def ast_eval_th(api_database, qa_pairs, ast_database, question_id, response):
# Check correctness
return process_response(question_id, response, api_database, qa_pairs, ast_database)

View File

@@ -0,0 +1,355 @@
import asyncio
import json
import logging
import multiprocessing as mp
import os
import pathlib
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
from utils import encode_question, get_data
from opendevin.controller.state.state import State
from opendevin.core.config import config, get_llm_config_arg, get_parser
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import main
from opendevin.events.action import MessageAction
from opendevin.events.serialization.event import event_to_dict
def cleanup():
print('Cleaning up child processes...')
for process in mp.active_children():
print(f'Terminating child process: {process.name}')
process.terminate()
process.join()
def codeact_user_response(state: State) -> str:
msg = (
#'Please continue working on the task on whatever approach you think is suitable.\n'
'Please run the following command: <execute_bash> exit </execute_bash>.\n'
#'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK.\n'
)
if state.history:
user_msgs = [
action
for action, _ in state.history
if isinstance(action, MessageAction) and action.source == 'user'
]
if len(user_msgs) >= 2:
# let the agent know that it can give up when it has tried 3 times
return (
msg
+ 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
)
return msg
def monologue_user_response(state: State) -> str:
raise NotImplementedError('MonologueAgent should never ask for user responses.')
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
'MonologueAgent': monologue_user_response,
}
AGENT_CLS_TO_INST_SUFFIX = {
'CodeActAgent': 'When you think you have completed the request, please run the following command: <execute_bash> exit </execute_bash>.\n'
}
def process_instance(
question_id, question, agent_class, metadata, reset_logger: bool = True
):
# create process-specific workspace dir
# we will create a workspace directory for EACH process
# so that different agent don't interfere with each other.
old_workspace_mount_path = config.workspace_mount_path
try:
workspace_mount_path = os.path.join(
config.workspace_mount_path, '_eval_workspace'
)
workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
config.workspace_mount_path = workspace_mount_path
# Setup the logger properly, so you can run multi-processing to parallize the evaluation
eval_output_dir = metadata['eval_output_dir']
if reset_logger:
# Set up logger
log_file = os.path.join(
eval_output_dir, 'logs', f'instance_{question_id}.log'
)
# Remove all existing handlers from logger
for handler in logger.handlers[:]:
logger.removeHandler(handler)
# add back the console handler to print ONE line
logger.addHandler(get_console_handler())
logger.info(
f'Starting evaluation for instance {question_id}.\nLOG: tail -f {log_file}'
)
# Remove all existing handlers from logger
for handler in logger.handlers[:]:
logger.removeHandler(handler)
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(
logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
)
logger.addHandler(file_handler)
logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')
# Prepare instruction
instruction = encode_question(question, metadata['hub'])
instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
# NOTE: You can actually set slightly different instruction for different agents
instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
# logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
# Here's how you can run the agent (similar to the `main` function) and get the final task state
state: State = asyncio.run(
main(
instruction,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
agent_class
),
)
)
# ======= Attempt to evaluate the agent's edits =======
# If you are working on simplier benchmark that only evaluates the final model output (e.g., in a MessageAction)
# You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
if state is None:
raise ValueError('State should not be None.')
model_answer_raw = ''
for act, _ in reversed(state.history):
if isinstance(act, MessageAction) and act.source == 'agent':
model_answer_raw = act.content
break
# attempt to parse model_answer
_, _, ast_eval = get_data(metadata['hub'])
correct, hallucination = ast_eval(question_id, model_answer_raw)
metrics = state.metrics.get() if state.metrics else None
logger.info(
f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
)
# Save the output
output = {
'question_id': question_id,
'text': model_answer_raw,
'correct': correct,
'hallucination': hallucination,
'answer_id': 'None',
'model_id': metadata['model_name'],
'metadata': metadata,
'history': [
(event_to_dict(action), event_to_dict(obs))
for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
}
except Exception:
logger.error('Process instance failed')
raise
finally:
config.workspace_mount_path = old_workspace_mount_path
return output
if __name__ == '__main__':
parser = get_parser()
parser.add_argument(
'--hubs',
type=str,
help='Which hubs to evaluate from APIBench. APIBench contains 3 hubs, namely huggingface, torch, and tensorflow. You could choose one or more from hf, torch, or tf, seperated by commas. For example, the default is --hub hf,torch,tf.',
default='hf,torch,tf',
)
args, _ = parser.parse_known_args()
if args.directory:
config.workspace_base = os.path.abspath(args.directory)
print(f'Setting workspace base to {config.workspace_base}')
# Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
# for details of how to set `llm_config`
if args.llm_config:
specified_llm_config = get_llm_config_arg(args.llm_config)
if specified_llm_config:
config.llm = specified_llm_config
logger.info(f'Config for evaluation: {config}')
agent_class = args.agent_cls
assert (
agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
), f'Unsupported agent class: {agent_class}'
model_name = config.llm.model.split('/')[-1]
max_iterations = args.max_iterations
eval_note = ''
if args.eval_note is not None:
eval_note += '_N_' + args.eval_note
eval_output_dir = os.path.join(
args.eval_output_dir,
'gorilla',
agent_class,
model_name + '_maxiter_' + str(max_iterations) + eval_note,
)
pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
parents=True, exist_ok=True
)
logger.info(f'Using evaluation output directory: {eval_output_dir}')
hubs = []
if 'hf' in args.hubs:
hubs.append('hf')
if 'torch' in args.hubs or 'th' in args.hubs:
hubs.append('torch')
if 'tf' in args.hubs:
hubs.append('tf')
if hubs == []:
raise ValueError('Please choose at least one from hf, torch, and tf for hubs.')
for hub in hubs:
logger.info(f'Evaluating APIBench {hub} test')
questions, question_ids, ast_eval = get_data(hub)
# TEST METADATA
metadata = {
'hub': hub,
'agent_class': agent_class,
'model_name': model_name,
'max_iterations': max_iterations,
'eval_output_dir': eval_output_dir,
'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
# get the commit id of current repo for reproduciblity
'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
.decode('utf-8')
.strip(),
}
logger.info(f'Metadata: {metadata}')
with open(os.path.join(eval_output_dir, f'metadata_{hub}.json'), 'w') as f:
json.dump(metadata, f)
# LIMIT EVALUATION
eval_n_limit = args.eval_n_limit
if eval_n_limit:
questions = questions[: (eval_n_limit // len(hubs))]
question_ids = question_ids[: (eval_n_limit // len(hubs))]
logger.info(
f'Limiting evaluation to a total of first {eval_n_limit} instances -> first {eval_n_limit//len(hubs)} instances per hub.'
)
output_file = os.path.join(eval_output_dir, f'output_{model_name}_{hub}.jsonl')
logger.info(f'Writing evaluation output to {output_file}')
finished_task_ids = set()
if os.path.exists(output_file):
with open(output_file, 'r') as f:
for line in f:
data = json.loads(line)
for i in range(len(question_ids)):
if question_ids[i] == int(data['question_id']):
finished_task_ids.add(data['question_id'])
logger.warning(
f'Output file {output_file} already exists. Loaded {len(finished_task_ids)} finished instances.'
)
output_fp = open(output_file, 'a')
logger.info(
f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
)
# =============================================
# filter out finished instances
new_questions = []
new_question_ids = []
for i in range(len(question_ids)):
if question_ids[i] in finished_task_ids:
logger.info(
f'Skipping instance {question_ids[i]} as it is already finished.'
)
continue
new_questions.append(questions[i])
new_question_ids.append(question_ids[i])
finished_task_number = len(finished_task_ids)
questions = new_questions
question_ids = new_question_ids
logger.info(
f'Finished instances: {finished_task_number}, Remaining instances: {len(question_ids)}'
)
# =============================================
pbar = tqdm(total=len(question_ids))
# This function tracks the progress AND write the output to a JSONL file
def update_progress(future, pbar, output_fp, finished_task_ids):
pbar.update(1)
output = future.result()
pbar.set_description(f'Instance {output["question_id"]}')
pbar.set_postfix_str(f'Test Result: {output["correct"]}')
logger.info(
f'Finished evaluation for instance {output["question_id"]}: {output["correct"]}'
)
output_fp.write(json.dumps(output) + '\n')
output_fp.flush()
finished_task_ids.add(output['question_id'])
# This sets the multi-processing
num_workers = args.eval_num_workers
logger.info(f'Using {num_workers} workers for evaluation.')
try:
with ProcessPoolExecutor(num_workers) as executor:
futures = []
# This is how we perform multi-processing
for i in range(len(question_ids)):
try:
question_id = question_ids[i]
question = questions[i]
future = executor.submit(
process_instance,
question_id,
question,
agent_class,
metadata,
reset_logger=bool(num_workers > 1),
)
future.add_done_callback(
update_progress, pbar, output_fp, finished_task_ids
)
futures.append(future)
except Exception:
continue
# Wait for all futures to complete
for future in futures:
try:
future.result()
except Exception:
continue
except KeyboardInterrupt:
logger.info('KeyboardInterrupt received. Cleaning up...')
cleanup()
output_fp.close()
total_correct = 0
total_hallucination = 0
output = []
with open(output_file, 'r') as f:
for line in f:
data = json.loads(line)
output.append(data)
if int(data['question_id']) in finished_task_ids:
if str(data['correct']).lower() == 'true':
total_correct += 1
if str(data['hallucination']).lower() == 'true':
total_hallucination += 1
# sort all output by question_id
output = sorted(output, key=lambda x: x['question_id'])
with open(output_file, 'w') as f:
for dat in output:
f.write(json.dumps(dat) + '\n')
f.flush()
logger.info(
f'Evaluation finished for {hub}. Total: {len(question_ids)+finished_task_number}; Correct: {total_correct}; Hallucination: {total_hallucination}. Accuracy: {total_correct / (len(question_ids)+finished_task_number)}'
)

View File

@@ -0,0 +1,42 @@
#!/bin/bash
MODEL_CONFIG=$1
AGENT=$2
EVAL_LIMIT=$3
HUBS=$4
if [ -z "$AGENT" ]; then
echo "Agent not specified, use default CodeActAgent"
AGENT="CodeActAgent"
fi
if [ -z "$HUBS" ]; then
HUBS="hf,torch,tf"
echo "Hubs not specified, use default $HUBS"
fi
# IMPORTANT: Because Agent's prompt changes fairly often in the rapidly evolving codebase of OpenDevin
# We need to track the version of Agent in the evaluation to make sure results are comparable
AGENT_VERSION=v$(poetry run python -c "import agenthub; from opendevin.controller.agent import Agent; print(Agent.get_cls('$AGENT').VERSION)")
echo "AGENT: $AGENT"
echo "AGENT_VERSION: $AGENT_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"
echo "HUBS: $HUBS"
COMMAND="poetry run python evaluation/gorilla/run_infer.py \
--agent-cls $AGENT \
--llm-config $MODEL_CONFIG \
--max-iterations 30 \
--hubs $HUBS \
--data-split validation \
--max-chars 10000000 \
--eval-num-workers 1 \
--eval-note ${AGENT_VERSION}_${LEVELS}"
if [ -n "$EVAL_LIMIT" ]; then
echo "EVAL_LIMIT: $EVAL_LIMIT"
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
fi
# Run the command
eval $COMMAND

101
evaluation/gorilla/utils.py Normal file
View File

@@ -0,0 +1,101 @@
import json
from functools import partial
import requests
from ast_eval_hf import ast_eval_hf, ast_parse
from ast_eval_tf import ast_eval_tf
from ast_eval_th import ast_eval_th
# This function is modified from Gorilla's APIBench implementations (https://github.com/ShishirPatil/gorilla/blob/main/eval/get_llm_responses.py).
def encode_question(question, api_name):
"""Encode multiple prompt instructions into a single string."""
prompts = []
if api_name == 'torch':
api_name = 'torchhub'
domains = '1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}.'
elif api_name == 'hf':
api_name = 'huggingface'
domains = '1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, \
Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation,\
Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, \
Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shor Image Classification, \
Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, \
Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, \
Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask,\
Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, \
Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, \
Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics }'
elif api_name == 'tf':
api_name = 'tensorhub'
domains = '1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancemenmt, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}'
else:
print('Error: API name is not supported.')
prompt = (
question
+ '\nWrite a python program in 1 to 2 lines to call API in '
+ api_name
+ '.\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n'
+ domains
+ '\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer.'
)
# prompts.append({"role": "system", "content": ""})
prompts = (
'You are a helpful API writer who can write APIs based on requirements.\n'
+ prompt
)
return prompts
def get_data(hub):
if hub == 'hf':
question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl'
api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl'
apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json'
ast_eval = ast_eval_hf
if hub == 'torch':
question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl'
api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl'
apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json'
ast_eval = ast_eval_th
if hub == 'tf':
question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl'
api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl'
apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json'
ast_eval = ast_eval_tf
# get questions and question_ids
questions = []
question_ids = []
question_data = requests.get(question_data)
if question_data.status_code == 200:
lines = question_data.text.splitlines()
for line in lines:
questions.append(json.loads(line)['text'])
question_ids.append(json.loads(line)['question_id'])
# get the api datasest
api_database = []
api_dataset = requests.get(api_dataset)
if api_dataset.status_code == 200:
lines = api_dataset.text.splitlines()
for line in lines:
api_database.append(json.loads(line))
# get the question answer pair datasest
qa_pairs = []
apibench = requests.get(apibench)
if apibench.status_code == 200:
lines = apibench.text.splitlines()
for line in lines:
qa_pairs.append(json.loads(line)['api_data'])
# Parse all apis to ast trees
ast_database = []
for data in api_database:
ast_tree = ast_parse(data['api_call'])
ast_database.append(ast_tree)
ast_eval = partial(ast_eval, api_database, qa_pairs, ast_database)
return questions, question_ids, ast_eval