From 68d9ad61cf59724c761c4f700d4b7be25690f320 Mon Sep 17 00:00:00 2001
From: yueqis <141804823+yueqis@users.noreply.github.com>
Date: Sat, 8 Jun 2024 12:54:54 -0400
Subject: [PATCH] Feat: Support Gorilla APIBench (#2081)

* removed unused files from gorilla
* Update run_infer.py, removed unused imports
* Update utils.py
* Update ast_eval_hf.py
* Update ast_eval_tf.py
* Update ast_eval_th.py
* Create README.md
* Update run_infer.py
* make lint
* Update run_infer.py
* fix lint

---------

Co-authored-by: yufansong
---
 evaluation/gorilla/README.md            |  41 +++
 evaluation/gorilla/ast_eval_hf.py       | 127 +++++++++
 evaluation/gorilla/ast_eval_tf.py       | 127 +++++++++
 evaluation/gorilla/ast_eval_th.py       | 123 ++++++++
 evaluation/gorilla/run_infer.py         | 355 ++++++++++++++++++++++++
 evaluation/gorilla/scripts/run_infer.sh |  42 +++
 evaluation/gorilla/utils.py             | 101 +++++++
 7 files changed, 916 insertions(+)
 create mode 100644 evaluation/gorilla/README.md
 create mode 100644 evaluation/gorilla/ast_eval_hf.py
 create mode 100644 evaluation/gorilla/ast_eval_tf.py
 create mode 100644 evaluation/gorilla/ast_eval_th.py
 create mode 100644 evaluation/gorilla/run_infer.py
 create mode 100644 evaluation/gorilla/scripts/run_infer.sh
 create mode 100644 evaluation/gorilla/utils.py

diff --git a/evaluation/gorilla/README.md b/evaluation/gorilla/README.md
new file mode 100644
index 0000000000..b7a04a704d
--- /dev/null
+++ b/evaluation/gorilla/README.md
@@ -0,0 +1,41 @@
+# Gorilla APIBench Evaluation with OpenDevin
+
+This folder contains the evaluation harness we built on top of the original [Gorilla APIBench](https://github.com/ShishirPatil/gorilla) ([paper](https://arxiv.org/pdf/2305.15334)).
+
+## Setup Environment
+
+Please follow [this document](https://github.com/OpenDevin/OpenDevin/blob/main/Development.md) to set up a local development environment for OpenDevin.
+
+## Configure OpenDevin and your LLM
+
+Run `make setup-config` to set up the `config.toml` file if it does not exist at the root of the workspace.
+
+## Run Inference on APIBench Instances
+
+Make sure your Docker daemon is running, then run this bash script:
+
+```bash
+bash evaluation/gorilla/scripts/run_infer.sh [model_config] [agent] [eval_limit] [hubs]
+```
+
+where `model_config` is mandatory, while all other arguments are optional.
+
+`model_config`, e.g. `llm`, is the config group name for your
+LLM settings, as defined in your `config.toml`.
+
+`agent`, e.g. `CodeActAgent`, is the name of the agent to run the benchmark with, defaulting
+to `CodeActAgent`.
+
+`eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances.
+By default, the script evaluates 1 instance.
+
+`hubs` selects which APIBench hub(s) to evaluate. Choose one or more of `torch` (or its abbreviation `th`), `hf` (short for huggingface), and `tf` (short for tensorflow), separated by commas. The default is `hf,torch,tf`.
+
+Note: the arguments are positional, so in order to use `eval_limit`, you must also set `agent`; in order to use `hubs`, you must also set `eval_limit`.
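+
+For reference, a minimal `llm` config group in `config.toml` might look like the
+following (the model name and API key below are placeholders, and exact keys can
+differ across OpenDevin versions):
+
+```toml
+[llm]
+model = "gpt-4-1106-preview"
+api_key = "sk-..."
+temperature = 0.0
+```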
+
+Let's say you'd like to run 10 instances using `llm` and CodeActAgent on the `th` test;
+then your command would be:
+
+```bash
+bash evaluation/gorilla/scripts/run_infer.sh llm CodeActAgent 10 th
+```
diff --git a/evaluation/gorilla/ast_eval_hf.py b/evaluation/gorilla/ast_eval_hf.py
new file mode 100644
index 0000000000..1c3ce6e3c7
--- /dev/null
+++ b/evaluation/gorilla/ast_eval_hf.py
@@ -0,0 +1,127 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+    node_stack = []
+    sub_tree_sexp_list = []
+    depth = 1
+    # text = root_node.text
+    node_stack.append([root_node, depth])
+    while len(node_stack) != 0:
+        cur_node, cur_depth = node_stack.pop()
+        if cur_node.child_count > 0:
+            sub_tree_sexp_list.append(
+                [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+            )
+        else:
+            sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+        for child_node in cur_node.children:
+            if len(child_node.children) != 0:
+                depth = cur_depth + 1
+                node_stack.append([child_node, depth])
+    return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+    LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+    parser = Parser()
+    parser.set_language(LANGUAGE)
+
+    candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+    return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+    if node.child_count == 0:
+        return []
+    args_list = []
+    for child in node.children[0].children[0].children[1].children:
+        if '=' in child.text.decode():
+            args_list.append(child.children[2].text)
+        elif (
+            child.text.decode() != '('
+            and child.text.decode() != ')'
+            and child.text.decode() != ','
+        ):
+            args_list.append(child.text)
+    return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+    for idx, base_tree in enumerate(base_tree_list):
+        if base_tree.children[0].children[0].child_count == 0:
+            continue
+        api_name = base_tree.children[0].children[0].children[0].text
+        for candidate_tree in candidate_subtree_list:
+            if candidate_tree[3] == api_name:
+                break
+        # Now we have a sub-tree
+        candidate_tree = candidate_tree[2]
+        args_list = get_args(base_tree)
+        if len(args_list) == 0:
+            continue
+        ast_match = True
+        for arg in args_list:
+            if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+                ast_match = False
+                break
+        if ast_match:
+            return idx
+    return -1
+
+
+def ast_eval_hf(api_database, qa_pairs, ast_database, question_id, response):
+    # Check correctness
+    correct = False
+    hallucination = False
+    output = response
+    # Index the "api_call" domain
+    output = output.split('api_call')
+    if len(output) == 1:
+        api_call = output[0]
+    else:
+        # Parse the output
+        output = output[1].split('api_provider')[0]
+        if ':' not in output:
+            start = 0
+        else:
+            start = output.index(':')
+        if ')' not in output:
+            end = -2
+        else:
+            end = output.rindex(')')
+        api_call = output[start + 2 : end + 1]
+    # Parse the api_call into AST tree
+    ast_tree = ast_parse(api_call)
+    # Search for a subtree
+    ast_subtree_list = get_all_sub_trees(ast_tree)
+    # Check which ast tree is matching
+    database_index = ast_check(ast_subtree_list, ast_database)
+    # We cannot index this ast in our database: the call is a hallucination.
+    # Return early; indexing api_database with -1 would silently pick the last entry.
+    if database_index == -1:
+        hallucination = True
+        return correct, hallucination
+    # We index our reference api_call
+    ref_api_call = api_database[database_index]
+    # Check for functionality
+    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+        correct = True
+    return correct, hallucination
diff --git a/evaluation/gorilla/ast_eval_tf.py b/evaluation/gorilla/ast_eval_tf.py
new file mode 100644
index 0000000000..4bfbeaca38
--- /dev/null
+++ b/evaluation/gorilla/ast_eval_tf.py
@@ -0,0 +1,127 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+    node_stack = []
+    sub_tree_sexp_list = []
+    depth = 1
+    # text = root_node.text
+    node_stack.append([root_node, depth])
+    while len(node_stack) != 0:
+        cur_node, cur_depth = node_stack.pop()
+        if cur_node.child_count > 0:
+            sub_tree_sexp_list.append(
+                [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+            )
+        else:
+            sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+        for child_node in cur_node.children:
+            if len(child_node.children) != 0:
+                depth = cur_depth + 1
+                node_stack.append([child_node, depth])
+    return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+    LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+    parser = Parser()
+    parser.set_language(LANGUAGE)
+
+    candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+    return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+    if node.child_count == 0:
+        return []
+    args_list = []
+    for child in node.children[0].children[0].children[1].children:
+        if 'model=' in child.text.decode() or 'model =' in child.text.decode():
+            args_list.append(child.children[2].text)
+        elif (
+            child.text.decode() != '('
+            and child.text.decode() != ')'
+            and child.text.decode() != ','
+        ):
+            args_list.append(child.text)
+    return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+    for idx, base_tree in enumerate(base_tree_list):
+        if base_tree.children[0].children[0].child_count == 0:
+            continue
+        api_name = base_tree.children[0].children[0].children[0].text
+        for candidate_tree in candidate_subtree_list:
+            if candidate_tree[3] == api_name:
+                break
+        # Now we have a sub-tree
+        candidate_tree = candidate_tree[2]
+        args_list = get_args(base_tree)
+        if len(args_list) == 0:
+            continue
+        ast_match = True
+        for arg in args_list:
+            if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+                ast_match = False
+                break
+        if ast_match:
+            return idx
+    return -1
+
+
+def ast_eval_tf(api_database, qa_pairs, ast_database, question_id, response):
+    # Check correctness
+    correct = False
+    hallucination = False
+    output = response
+    # Index the "api_call" domain
+    output = output.split('api_call')
+    if len(output) == 1:
+        api_call = output[0]
+    else:
+        # Parse the output
+        output = output[1].split('api_provider')[0]
+        if ':' not in output:
+            start = 0
+        else:
+            start = output.index(':')
+        if ')' not in output:
+            end = -2
+        else:
+            end = output.rindex(')')
+        api_call = output[start + 2 : end + 1]
+    # Parse the api_call into AST tree
+    ast_tree = ast_parse(api_call)
+    # Search for a subtree
+    ast_subtree_list = get_all_sub_trees(ast_tree)
+    # Check which ast tree is matching
+    database_index = ast_check(ast_subtree_list, ast_database)
+    # We cannot index this ast in our database: the call is a hallucination.
+    # Return early; indexing api_database with -1 would silently pick the last entry.
+    if database_index == -1:
+        hallucination = True
+        return correct, hallucination
+    # We index our reference api_call
+    ref_api_call = api_database[database_index]
+    # Check for functionality
+    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+        correct = True
+    return correct, hallucination
diff --git a/evaluation/gorilla/ast_eval_th.py b/evaluation/gorilla/ast_eval_th.py
new file mode 100644
index 0000000000..4da8777c16
--- /dev/null
+++ b/evaluation/gorilla/ast_eval_th.py
@@ -0,0 +1,123 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_th.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+    node_stack = []
+    sub_tree_sexp_list = []
+    depth = 1
+    # text = root_node.text
+    node_stack.append([root_node, depth])
+    while len(node_stack) != 0:
+        cur_node, cur_depth = node_stack.pop()
+        if cur_node.child_count > 0:
+            sub_tree_sexp_list.append(
+                [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+            )
+        else:
+            sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+        for child_node in cur_node.children:
+            if len(child_node.children) != 0:
+                depth = cur_depth + 1
+                node_stack.append([child_node, depth])
+    return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+    LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+    parser = Parser()
+    parser.set_language(LANGUAGE)
+
+    candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+    return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+    if node.child_count == 0:
+        return []
+    args_list = []
+    for child in node.children[0].children[0].children[1].children:
+        if 'repo_or_dir' in child.text.decode() or 'model' in child.text.decode():
+            args_list.append(child.children[2].text)
+    return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+    for idx, base_tree in enumerate(base_tree_list):
+        if base_tree.children[0].children[0].child_count == 0:
+            continue
+        api_name = base_tree.children[0].children[0].children[0].text
+        for candidate_tree in candidate_subtree_list:
+            if candidate_tree[3] == api_name:
+                break
+        # Now we have a sub-tree
+        candidate_tree = candidate_tree[2]
+        args_list = get_args(base_tree)
+        if len(args_list) == 0:
+            continue
+        ast_match = True
+        for arg in args_list:
+            if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+                ast_match = False
+                break
+        if ast_match:
+            return idx
+    return -1
+
+
+def process_response(question_id, output, api_database, qa_pairs, ast_database):
+    # Index the "api_call" domain
+    output = output.split('api_call')
+    if len(output) == 1:
+        return False, False
+    else:
+        output = output[1].split('api_provider')[0]
+        if ':' not in output:
+            start = 0
+        else:
+            start = output.index(':')
+        if ')' not in output:
+            end = -2
+        else:
+            end = output.rindex(')')
+        api_call = output[start + 2 : end + 1]
+
+    # Parse the api_call into AST tree
+    ast_tree = ast_parse(api_call)
+    # Search for a subtree
+    ast_subtree_list = get_all_sub_trees(ast_tree)
+    # Check which ast tree is matching
+    database_index = ast_check(ast_subtree_list, ast_database)
+    # We cannot index this ast in our database
+    if database_index == -1:
+        return False, True
+    # We index our reference api_call
+    ref_api_call = api_database[database_index]
+    # Check for functionality
+    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+        return True, False
+    else:
+        return False, False
+
+
+def ast_eval_th(api_database, qa_pairs, ast_database, question_id, response):
+    # Check correctness
+    return process_response(question_id, response, api_database, qa_pairs, ast_database)
diff --git a/evaluation/gorilla/run_infer.py b/evaluation/gorilla/run_infer.py
new file mode 100644
index 0000000000..df3fe34c4f
--- /dev/null
+++ b/evaluation/gorilla/run_infer.py
@@ -0,0 +1,355 @@
+import asyncio
+import json
+import logging
+import multiprocessing as mp
+import os
+import pathlib
+import subprocess
+import time
+from concurrent.futures import ProcessPoolExecutor
+from functools import partial
+
+from tqdm import tqdm
+from utils import encode_question, get_data
+
+from opendevin.controller.state.state import State
+from opendevin.core.config import config, get_llm_config_arg, get_parser
+from opendevin.core.logger import get_console_handler
+from opendevin.core.logger import opendevin_logger as logger
+from opendevin.core.main import main
+from opendevin.events.action import MessageAction
+from opendevin.events.serialization.event import event_to_dict
+
+
+def cleanup():
+    print('Cleaning up child processes...')
+    for process in mp.active_children():
+        print(f'Terminating child process: {process.name}')
+        process.terminate()
+        process.join()
+
+
+def codeact_user_response(state: State) -> str:
+    msg = (
+        #'Please continue working on the task on whatever approach you think is suitable.\n'
+        'Please run the following command: <execute_bash> exit </execute_bash>.\n'
+        #'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK.\n'
+    )
+    if state.history:
+        user_msgs = [
+            action
+            for action, _ in state.history
+            if isinstance(action, MessageAction) and action.source == 'user'
+        ]
+        if len(user_msgs) >= 2:
+            # let the agent know that it can give up when it has tried 3 times
+            return (
+                msg
+                + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
+            )
+    return msg
+
+
+def monologue_user_response(state: State) -> str:
+    raise NotImplementedError('MonologueAgent should never ask for user responses.')
+
+
+AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
+    'CodeActAgent': codeact_user_response,
+    'MonologueAgent': monologue_user_response,
+}
+
+AGENT_CLS_TO_INST_SUFFIX = {
+    'CodeActAgent': 'When you think you have completed the request, please run the following command: <execute_bash> exit </execute_bash>.\n'
+}
+
+
+def process_instance(
+    question_id, question, agent_class, metadata, reset_logger: bool = True
+):
+    # create a process-specific workspace dir
+    # we will create a workspace directory for EACH process
+    # so that different agents don't interfere with each other.
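+    # Note: `config.workspace_mount_path` is temporarily overridden below and
+    # restored in the `finally` block, so the global config is unchanged once
+    # this function returns.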
+    old_workspace_mount_path = config.workspace_mount_path
+    try:
+        workspace_mount_path = os.path.join(
+            config.workspace_mount_path, '_eval_workspace'
+        )
+        workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
+        pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
+        config.workspace_mount_path = workspace_mount_path
+
+        # Set up the logger properly, so you can run multiprocessing to parallelize the evaluation
+        eval_output_dir = metadata['eval_output_dir']
+        if reset_logger:
+            # Set up logger
+            log_file = os.path.join(
+                eval_output_dir, 'logs', f'instance_{question_id}.log'
+            )
+            # Remove all existing handlers from logger
+            for handler in logger.handlers[:]:
+                logger.removeHandler(handler)
+            # add back the console handler to print ONE line
+            logger.addHandler(get_console_handler())
+            logger.info(
+                f'Starting evaluation for instance {question_id}.\nLOG: tail -f {log_file}'
+            )
+            # Remove all existing handlers from logger
+            for handler in logger.handlers[:]:
+                logger.removeHandler(handler)
+            file_handler = logging.FileHandler(log_file)
+            file_handler.setFormatter(
+                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+            )
+            logger.addHandler(file_handler)
+        logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')
+
+        # Prepare instruction
+        instruction = encode_question(question, metadata['hub'])
+        instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
+        # NOTE: You can actually set slightly different instructions for different agents
+        instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
+        # logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
+
+        # Here's how you can run the agent (similar to the `main` function) and get the final task state
+        state: State = asyncio.run(
+            main(
+                instruction,
+                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
+                    agent_class
+                ),
+            )
+        )
+        # ======= Attempt to evaluate the agent's edits =======
+        # If you are working on a simpler benchmark that only evaluates the final model output (e.g., in a MessageAction),
+        # you can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
+
+        if state is None:
+            raise ValueError('State should not be None.')
+
+        model_answer_raw = ''
+        for act, _ in reversed(state.history):
+            if isinstance(act, MessageAction) and act.source == 'agent':
+                model_answer_raw = act.content
+                break
+        # attempt to parse model_answer
+        _, _, ast_eval = get_data(metadata['hub'])
+        correct, hallucination = ast_eval(question_id, model_answer_raw)
+        metrics = state.metrics.get() if state.metrics else None
+        logger.info(
+            f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
+        )
+        # Save the output
+        output = {
+            'question_id': question_id,
+            'text': model_answer_raw,
+            'correct': correct,
+            'hallucination': hallucination,
+            'answer_id': 'None',
+            'model_id': metadata['model_name'],
+            'metadata': metadata,
+            'history': [
+                (event_to_dict(action), event_to_dict(obs))
+                for action, obs in state.history
+            ],
+            'metrics': metrics,
+            'error': state.error if state and state.error else None,
+        }
+    except Exception:
+        logger.error('Process instance failed')
+        raise
+    finally:
+        config.workspace_mount_path = old_workspace_mount_path
+    return output
+
+
+if __name__ == '__main__':
+    parser = get_parser()
+    parser.add_argument(
+        '--hubs',
+        type=str,
+        help='Which hubs to evaluate from APIBench. APIBench contains 3 hubs, namely huggingface, torch, and tensorflow. You could choose one or more from hf, torch, or tf, separated by commas. For example, the default is --hubs hf,torch,tf.',
+        default='hf,torch,tf',
+    )
+    args, _ = parser.parse_known_args()
+    if args.directory:
+        config.workspace_base = os.path.abspath(args.directory)
+        print(f'Setting workspace base to {config.workspace_base}')
+
+    # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
+    # for details of how to set `llm_config`
+    if args.llm_config:
+        specified_llm_config = get_llm_config_arg(args.llm_config)
+        if specified_llm_config:
+            config.llm = specified_llm_config
+    logger.info(f'Config for evaluation: {config}')
+    agent_class = args.agent_cls
+    assert (
+        agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
+    ), f'Unsupported agent class: {agent_class}'
+    model_name = config.llm.model.split('/')[-1]
+    max_iterations = args.max_iterations
+    eval_note = ''
+    if args.eval_note is not None:
+        eval_note += '_N_' + args.eval_note
+    eval_output_dir = os.path.join(
+        args.eval_output_dir,
+        'gorilla',
+        agent_class,
+        model_name + '_maxiter_' + str(max_iterations) + eval_note,
+    )
+    pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
+        parents=True, exist_ok=True
+    )
+    logger.info(f'Using evaluation output directory: {eval_output_dir}')
+
+    hubs = []
+    if 'hf' in args.hubs:
+        hubs.append('hf')
+    if 'torch' in args.hubs or 'th' in args.hubs:
+        hubs.append('torch')
+    if 'tf' in args.hubs:
+        hubs.append('tf')
+    if hubs == []:
+        raise ValueError('Please choose at least one from hf, torch, and tf for hubs.')
+
+    for hub in hubs:
+        logger.info(f'Evaluating APIBench {hub} test')
+        questions, question_ids, ast_eval = get_data(hub)
+
+        # TEST METADATA
+        metadata = {
+            'hub': hub,
+            'agent_class': agent_class,
+            'model_name': model_name,
+            'max_iterations': max_iterations,
+            'eval_output_dir': eval_output_dir,
+            'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
+            # get the commit id of the current repo for reproducibility
+            'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+            .decode('utf-8')
+            .strip(),
+        }
+        logger.info(f'Metadata: {metadata}')
+        with open(os.path.join(eval_output_dir, f'metadata_{hub}.json'), 'w') as f:
+            json.dump(metadata, f)
+
+        # LIMIT EVALUATION
+        eval_n_limit = args.eval_n_limit
+        if eval_n_limit:
+            questions = questions[: (eval_n_limit // len(hubs))]
+            question_ids = question_ids[: (eval_n_limit // len(hubs))]
+            logger.info(
+                f'Limiting evaluation to a total of the first {eval_n_limit} instances -> first {eval_n_limit//len(hubs)} instances per hub.'
+            )
+        output_file = os.path.join(eval_output_dir, f'output_{model_name}_{hub}.jsonl')
+        logger.info(f'Writing evaluation output to {output_file}')
+        finished_task_ids = set()
+        if os.path.exists(output_file):
+            with open(output_file, 'r') as f:
+                for line in f:
+                    data = json.loads(line)
+                    for i in range(len(question_ids)):
+                        if question_ids[i] == int(data['question_id']):
+                            finished_task_ids.add(data['question_id'])
+            logger.warning(
+                f'Output file {output_file} already exists. Loaded {len(finished_task_ids)} finished instances.'
+            )
+        output_fp = open(output_file, 'a')
+        logger.info(
+            f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
+        )
+        # =============================================
+        # filter out finished instances
+        new_questions = []
+        new_question_ids = []
+        for i in range(len(question_ids)):
+            if question_ids[i] in finished_task_ids:
+                logger.info(
+                    f'Skipping instance {question_ids[i]} as it is already finished.'
+                )
+                continue
+            new_questions.append(questions[i])
+            new_question_ids.append(question_ids[i])
+
+        finished_task_number = len(finished_task_ids)
+        questions = new_questions
+        question_ids = new_question_ids
+        logger.info(
+            f'Finished instances: {finished_task_number}, Remaining instances: {len(question_ids)}'
+        )
+        # =============================================
+        pbar = tqdm(total=len(question_ids))
+
+        # This function tracks the progress AND writes the output to a JSONL file
+        def update_progress(future, pbar, output_fp, finished_task_ids):
+            pbar.update(1)
+            output = future.result()
+            pbar.set_description(f'Instance {output["question_id"]}')
+            pbar.set_postfix_str(f'Test Result: {output["correct"]}')
+            logger.info(
+                f'Finished evaluation for instance {output["question_id"]}: {output["correct"]}'
+            )
+            output_fp.write(json.dumps(output) + '\n')
+            output_fp.flush()
+            finished_task_ids.add(output['question_id'])
+
+        # This sets up the multiprocessing
+        num_workers = args.eval_num_workers
+        logger.info(f'Using {num_workers} workers for evaluation.')
+        try:
+            with ProcessPoolExecutor(num_workers) as executor:
+                futures = []
+                # This is how we perform multiprocessing
+                for i in range(len(question_ids)):
+                    try:
+                        question_id = question_ids[i]
+                        question = questions[i]
+                        future = executor.submit(
+                            process_instance,
+                            question_id,
+                            question,
+                            agent_class,
+                            metadata,
+                            reset_logger=bool(num_workers > 1),
+                        )
+                        # `add_done_callback` passes only the future to the callback,
+                        # so bind the remaining arguments with `functools.partial`.
+                        future.add_done_callback(
+                            partial(
+                                update_progress,
+                                pbar=pbar,
+                                output_fp=output_fp,
+                                finished_task_ids=finished_task_ids,
+                            )
+                        )
+                        futures.append(future)
+                    except Exception:
+                        continue
+
+                # Wait for all futures to complete
+                for future in futures:
+                    try:
+                        future.result()
+                    except Exception:
+                        continue
+        except KeyboardInterrupt:
+            logger.info('KeyboardInterrupt received. Cleaning up...')
+            cleanup()
+
+        output_fp.close()
+        total_correct = 0
+        total_hallucination = 0
+        output = []
+        with open(output_file, 'r') as f:
+            for line in f:
+                data = json.loads(line)
+                output.append(data)
+                if int(data['question_id']) in finished_task_ids:
+                    if str(data['correct']).lower() == 'true':
+                        total_correct += 1
+                    if str(data['hallucination']).lower() == 'true':
+                        total_hallucination += 1
+        # sort all output by question_id
+        output = sorted(output, key=lambda x: x['question_id'])
+        with open(output_file, 'w') as f:
+            for dat in output:
+                f.write(json.dumps(dat) + '\n')
+                f.flush()
+
+        logger.info(
+            f'Evaluation finished for {hub}. Total: {len(question_ids)+finished_task_number}; Correct: {total_correct}; Hallucination: {total_hallucination}. Accuracy: {total_correct / (len(question_ids)+finished_task_number)}'
+        )
diff --git a/evaluation/gorilla/scripts/run_infer.sh b/evaluation/gorilla/scripts/run_infer.sh
new file mode 100644
index 0000000000..8bef6abda4
--- /dev/null
+++ b/evaluation/gorilla/scripts/run_infer.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+MODEL_CONFIG=$1
+AGENT=$2
+EVAL_LIMIT=$3
+HUBS=$4
+
+if [ -z "$AGENT" ]; then
+    echo "Agent not specified, use default CodeActAgent"
+    AGENT="CodeActAgent"
+fi
+
+if [ -z "$HUBS" ]; then
+    HUBS="hf,torch,tf"
+    echo "Hubs not specified, use default $HUBS"
+fi
+
+# IMPORTANT: Because the Agent's prompt changes fairly often in the rapidly evolving OpenDevin codebase,
+# we need to track the version of the Agent in the evaluation to make sure results are comparable
+AGENT_VERSION=v$(poetry run python -c "import agenthub; from opendevin.controller.agent import Agent; print(Agent.get_cls('$AGENT').VERSION)")

+echo "AGENT: $AGENT"
+echo "AGENT_VERSION: $AGENT_VERSION"
+echo "MODEL_CONFIG: $MODEL_CONFIG"
+echo "HUBS: $HUBS"
+
+COMMAND="poetry run python evaluation/gorilla/run_infer.py \
+  --agent-cls $AGENT \
+  --llm-config $MODEL_CONFIG \
+  --max-iterations 30 \
+  --hubs $HUBS \
+  --data-split validation \
+  --max-chars 10000000 \
+  --eval-num-workers 1 \
+  --eval-note ${AGENT_VERSION}_${HUBS}"
+
+if [ -n "$EVAL_LIMIT" ]; then
+    echo "EVAL_LIMIT: $EVAL_LIMIT"
+    COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
+fi
+
+# Run the command
+eval $COMMAND
diff --git a/evaluation/gorilla/utils.py b/evaluation/gorilla/utils.py
new file mode 100644
index 0000000000..f3031a831d
--- /dev/null
+++ b/evaluation/gorilla/utils.py
@@ -0,0 +1,101 @@
+import json
+from functools import partial
+
+import requests
+from ast_eval_hf import ast_eval_hf, ast_parse
+from ast_eval_tf import ast_eval_tf
+from ast_eval_th import ast_eval_th
+
+
+# This function is modified from Gorilla's APIBench implementation (https://github.com/ShishirPatil/gorilla/blob/main/eval/get_llm_responses.py).
+def encode_question(question, api_name):
+    """Encode multiple prompt instructions into a single string."""
+
+    prompts = []
+    if api_name == 'torch':
+        api_name = 'torchhub'
+        domains = '1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}.'
+    elif api_name == 'hf':
+        api_name = 'huggingface'
+        domains = '1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, \
+            Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation, \
+            Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, \
+            Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shot Image Classification, \
+            Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, \
+            Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, \
+            Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask, \
+            Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, \
+            Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, \
+            Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics}'
+    elif api_name == 'tf':
+        api_name = 'tensorhub'
+        domains = '1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancement, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}'
+    else:
+        print('Error: API name is not supported.')
+
+    prompt = (
+        question
+        + '\nWrite a python program in 1 to 2 lines to call API in '
+        + api_name
+        + '.\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n'
+        + domains
+        + '\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer.'
+    )
+    # prompts.append({"role": "system", "content": ""})
+    prompts = (
+        'You are a helpful API writer who can write APIs based on requirements.\n'
+        + prompt
+    )
+    return prompts
+
+
+def get_data(hub):
+    if hub == 'hf':
+        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl'
+        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl'
+        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json'
+        ast_eval = ast_eval_hf
+    if hub == 'torch':
+        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl'
+        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl'
+        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json'
+        ast_eval = ast_eval_th
+    if hub == 'tf':
+        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl'
+        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl'
+        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json'
+        ast_eval = ast_eval_tf
+
+    # get questions and question_ids
+    questions = []
+    question_ids = []
+    question_data = requests.get(question_data)
+    if question_data.status_code == 200:
+        lines = question_data.text.splitlines()
+        for line in lines:
+            questions.append(json.loads(line)['text'])
+            question_ids.append(json.loads(line)['question_id'])
+
+    # get the api dataset
+    api_database = []
+    api_dataset = requests.get(api_dataset)
+    if api_dataset.status_code == 200:
+        lines = api_dataset.text.splitlines()
+        for line in lines:
+            api_database.append(json.loads(line))
+
+    # get the question-answer pair dataset
+    qa_pairs = []
+    apibench = requests.get(apibench)
+    if apibench.status_code == 200:
+        lines = apibench.text.splitlines()
+        for line in lines:
+            qa_pairs.append(json.loads(line)['api_data'])
+
+    # Parse all apis to ast trees
+    ast_database = []
+    for data in api_database:
+        ast_tree = ast_parse(data['api_call'])
+        ast_database.append(ast_tree)
+    ast_eval = partial(ast_eval, api_database, qa_pairs, ast_database)
+    return questions, question_ids, ast_eval
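+
+
+if __name__ == '__main__':
+    # Minimal usage sketch: print the zero-shot prompt built for one sample
+    # question. The question text here is illustrative only. Note that
+    # `get_data` additionally needs network access to the Gorilla repo, and
+    # `ast_eval`/`ast_parse` need the tree-sitter grammar built at
+    # evaluation/gorilla/my-languages.so.
+    sample_prompt = encode_question(
+        'I want an API to detect objects in a street photo.', 'torch'
+    )
+    print(sample_prompt)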