diff --git a/evaluation/gorilla/README.md b/evaluation/gorilla/README.md
new file mode 100644
index 0000000000..b7a04a704d
--- /dev/null
+++ b/evaluation/gorilla/README.md
@@ -0,0 +1,41 @@
+# Gorilla APIBench Evaluation with OpenDevin
+
+This folder contains the evaluation harness we built on top of the original [Gorilla APIBench](https://github.com/ShishirPatil/gorilla) ([paper](https://arxiv.org/pdf/2305.15334)).
+
+## Setup Environment
+
+Please follow [this document](https://github.com/OpenDevin/OpenDevin/blob/main/Development.md) to set up a local development environment for OpenDevin.
+
+## Configure OpenDevin and your LLM
+
+Run `make setup-config` to create the `config.toml` file at the root of the workspace if it does not already exist.
+
+## Run Inference on APIBench Instances
+
+Make sure your Docker daemon is running, then run this bash script:
+
+```bash
+bash evaluation/gorilla/scripts/run_infer.sh [model_config] [agent] [eval_limit] [hubs]
+```
+
+where `model_config` is mandatory, while all other arguments are optional.
+
+`model_config`, e.g. `llm`, is the config group name for your
+LLM settings, as defined in your `config.toml`.
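+
+For example, a minimal `[llm]` section in `config.toml` might look like the
+following sketch (the model name and API key are illustrative placeholders):
+
+```toml
+[llm]
+model = "gpt-4o"
+api_key = "sk-..."
+temperature = 0.0
+```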
+
+`agent`, e.g. `CodeActAgent`, is the name of the agent for benchmarks, defaulting
+to `CodeActAgent`.
+
+`eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances
+in total; the limit is split evenly across the selected hubs. If it is not set,
+all instances from the selected hubs are evaluated.
+
+`hubs` is the set of APIBench hubs to evaluate on. Choose one or more of `torch`
+(or its abbreviation `th`), `hf` (Hugging Face), and `tf` (TensorFlow), separated
+by commas. The default is `hf,torch,tf`.
+
+Note: because the arguments are positional, in order to set `eval_limit` you must also set `agent`, and in order to set `hubs` you must also set `eval_limit`.
+
+Let's say you'd like to run 10 instances with the `llm` config and CodeActAgent
+on the `th` (Torch Hub) test split; the command would be:
+
+```bash
+bash evaluation/gorilla/scripts/run_infer.sh llm CodeActAgent 10 th
+```
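+
+After a run finishes, per-hub results are appended to `output_<model>_<hub>.jsonl`
+in the evaluation output directory, alongside `metadata_<hub>.json` and
+per-instance logs under `logs/`.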
diff --git a/evaluation/gorilla/ast_eval_hf.py b/evaluation/gorilla/ast_eval_hf.py
new file mode 100644
index 0000000000..1c3ce6e3c7
--- /dev/null
+++ b/evaluation/gorilla/ast_eval_hf.py
@@ -0,0 +1,127 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
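+    # Each returned entry is [s-expression, depth, node, text of the node's
+    # first child]; the first-child text is later matched against the API name.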
+ node_stack = []
+ sub_tree_sexp_list = []
+ depth = 1
+ # text = root_node.text
+ node_stack.append([root_node, depth])
+ while len(node_stack) != 0:
+ cur_node, cur_depth = node_stack.pop()
+ if cur_node.child_count > 0:
+ sub_tree_sexp_list.append(
+ [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+ )
+ else:
+ sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+ for child_node in cur_node.children:
+ if len(child_node.children) != 0:
+ depth = cur_depth + 1
+ node_stack.append([child_node, depth])
+ return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
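+    # Assumes a pre-built tree-sitter grammar bundle at this path (e.g. compiled
+    # with Language.build_library from the tree-sitter-python grammar).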
+ LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+ parser = Parser()
+ parser.set_language(LANGUAGE)
+
+ candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+ return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+ if node.child_count == 0:
+ return []
+ args_list = []
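+    # Navigate module -> expression_statement -> call -> argument_list; keyword
+    # arguments contribute their value (children[2]), positional ones their text.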
+ for child in node.children[0].children[0].children[1].children:
+ if '=' in child.text.decode():
+ args_list.append(child.children[2].text)
+ elif (
+ child.text.decode() != '('
+ and child.text.decode() != ')'
+ and child.text.decode() != ','
+ ):
+ args_list.append(child.text)
+ return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+ for idx, base_tree in enumerate(base_tree_list):
+ if base_tree.children[0].children[0].child_count == 0:
+ continue
+ api_name = base_tree.children[0].children[0].children[0].text
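+        # Find the first candidate subtree whose first child matches api_name.
+        # NOTE: if none matches, the loop below falls through with the last
+        # candidate (behavior inherited from the upstream Gorilla script).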
+ for candidate_tree in candidate_subtree_list:
+ if candidate_tree[3] == api_name:
+ break
+ # Now we have a sub-tree
+ candidate_tree = candidate_tree[2]
+ args_list = get_args(base_tree)
+ if len(args_list) == 0:
+ continue
+ ast_match = True
+ for arg in args_list:
+ if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+ ast_match = False
+ break
+ if ast_match:
+ return idx
+ return -1
+
+
+def ast_eval_hf(api_database, qa_pairs, ast_database, question_id, response):
+ # Check correctness
+ correct = False
+ hallucination = False
+ output = response
+ # Index the "api_call" domain
+ output = output.split('api_call')
+ if len(output) == 1:
+ api_call = output[0]
+ else:
+ # Parse the output
+ output = output[1].split('api_provider')[0]
+ if ':' not in output:
+ start = 0
+ else:
+ start = output.index(':')
+ if ')' not in output:
+ end = -2
+ else:
+ end = output.rindex(')')
+ api_call = output[start + 2 : end + 1]
+ # Parse the api_call into AST tree
+ ast_tree = ast_parse(api_call)
+ # Search for a subtree
+ ast_subtree_list = get_all_sub_trees(ast_tree)
+ # Check which ast tree is matching
+ database_index = ast_check(ast_subtree_list, ast_database)
+ # We cannot index this ast in our database
+    if database_index == -1:
+        hallucination = True
+        # return early; otherwise api_database[-1] would be indexed spuriously
+        return correct, hallucination
+ # We index our reference api_call
+ ref_api_call = api_database[database_index]
+ # Check for functionality
+ if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+ correct = True
+ return correct, hallucination
diff --git a/evaluation/gorilla/ast_eval_tf.py b/evaluation/gorilla/ast_eval_tf.py
new file mode 100644
index 0000000000..4bfbeaca38
--- /dev/null
+++ b/evaluation/gorilla/ast_eval_tf.py
@@ -0,0 +1,127 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+ node_stack = []
+ sub_tree_sexp_list = []
+ depth = 1
+ # text = root_node.text
+ node_stack.append([root_node, depth])
+ while len(node_stack) != 0:
+ cur_node, cur_depth = node_stack.pop()
+ if cur_node.child_count > 0:
+ sub_tree_sexp_list.append(
+ [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+ )
+ else:
+ sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+ for child_node in cur_node.children:
+ if len(child_node.children) != 0:
+ depth = cur_depth + 1
+ node_stack.append([child_node, depth])
+ return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+ LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+ parser = Parser()
+ parser.set_language(LANGUAGE)
+
+ candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+ return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+ if node.child_count == 0:
+ return []
+ args_list = []
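+    # For TF Hub calls, extract the value of the `model=` keyword argument; all
+    # other tokens except parentheses and commas are kept verbatim.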
+ for child in node.children[0].children[0].children[1].children:
+ if 'model=' in child.text.decode() or 'model =' in child.text.decode():
+ args_list.append(child.children[2].text)
+ elif (
+ child.text.decode() != '('
+ and child.text.decode() != ')'
+ and child.text.decode() != ','
+ ):
+ args_list.append(child.text)
+ return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+ for idx, base_tree in enumerate(base_tree_list):
+ if base_tree.children[0].children[0].child_count == 0:
+ continue
+ api_name = base_tree.children[0].children[0].children[0].text
+ for candidate_tree in candidate_subtree_list:
+ if candidate_tree[3] == api_name:
+ break
+ # Now we have a sub-tree
+ candidate_tree = candidate_tree[2]
+ args_list = get_args(base_tree)
+ if len(args_list) == 0:
+ continue
+ ast_match = True
+ for arg in args_list:
+ if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+ ast_match = False
+ break
+ if ast_match:
+ return idx
+ return -1
+
+
+def ast_eval_tf(api_database, qa_pairs, ast_database, question_id, response):
+ # Check correctness
+ correct = False
+ hallucination = False
+ output = response
+ # Index the "api_call" domain
+ output = output.split('api_call')
+ if len(output) == 1:
+ api_call = output[0]
+ else:
+ # Parse the output
+ output = output[1].split('api_provider')[0]
+ if ':' not in output:
+ start = 0
+ else:
+ start = output.index(':')
+ if ')' not in output:
+ end = -2
+ else:
+ end = output.rindex(')')
+ api_call = output[start + 2 : end + 1]
+ # Parse the api_call into AST tree
+ ast_tree = ast_parse(api_call)
+ # Search for a subtree
+ ast_subtree_list = get_all_sub_trees(ast_tree)
+ # Check which ast tree is matching
+ database_index = ast_check(ast_subtree_list, ast_database)
+ # We cannot index this ast in our database
+    if database_index == -1:
+        hallucination = True
+        # return early; otherwise api_database[-1] would be indexed spuriously
+        return correct, hallucination
+ # We index our reference api_call
+ ref_api_call = api_database[database_index]
+ # Check for functionality
+ if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+ correct = True
+ return correct, hallucination
diff --git a/evaluation/gorilla/ast_eval_th.py b/evaluation/gorilla/ast_eval_th.py
new file mode 100644
index 0000000000..4da8777c16
--- /dev/null
+++ b/evaluation/gorilla/ast_eval_th.py
@@ -0,0 +1,123 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_th.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+ node_stack = []
+ sub_tree_sexp_list = []
+ depth = 1
+ # text = root_node.text
+ node_stack.append([root_node, depth])
+ while len(node_stack) != 0:
+ cur_node, cur_depth = node_stack.pop()
+ if cur_node.child_count > 0:
+ sub_tree_sexp_list.append(
+ [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+ )
+ else:
+ sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+ for child_node in cur_node.children:
+ if len(child_node.children) != 0:
+ depth = cur_depth + 1
+ node_stack.append([child_node, depth])
+ return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+ LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+ parser = Parser()
+ parser.set_language(LANGUAGE)
+
+ candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+ return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+ if node.child_count == 0:
+ return []
+ args_list = []
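+    # torch.hub calls are matched only on the values of arguments mentioning
+    # `repo_or_dir` or `model`; other arguments are ignored.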
+ for child in node.children[0].children[0].children[1].children:
+ if 'repo_or_dir' in child.text.decode() or 'model' in child.text.decode():
+ args_list.append(child.children[2].text)
+ return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+ for idx, base_tree in enumerate(base_tree_list):
+ if base_tree.children[0].children[0].child_count == 0:
+ continue
+ api_name = base_tree.children[0].children[0].children[0].text
+ for candidate_tree in candidate_subtree_list:
+ if candidate_tree[3] == api_name:
+ break
+ # Now we have a sub-tree
+ candidate_tree = candidate_tree[2]
+ args_list = get_args(base_tree)
+ if len(args_list) == 0:
+ continue
+ ast_match = True
+ for arg in args_list:
+ if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+ ast_match = False
+ break
+ if ast_match:
+ return idx
+ return -1
+
+
+def process_response(question_id, output, api_database, qa_pairs, ast_database):
+ # Index the "api_call" domain
+ output = output.split('api_call')
+ if len(output) == 1:
+ return False, False
+ else:
+ output = output[1].split('api_provider')[0]
+ if ':' not in output:
+ start = 0
+ else:
+ start = output.index(':')
+ if ')' not in output:
+ end = -2
+ else:
+ end = output.rindex(')')
+ api_call = output[start + 2 : end + 1]
+
+ # Parse the api_call into AST tree
+ ast_tree = ast_parse(api_call)
+ # Search for a subtree
+ ast_subtree_list = get_all_sub_trees(ast_tree)
+ # Check which ast tree is matching
+ database_index = ast_check(ast_subtree_list, ast_database)
+ # We cannot index this ast in our database
+ if database_index == -1:
+ return False, True
+ # We index our reference api_call
+ ref_api_call = api_database[database_index]
+ # Check for functionality
+ if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+ return True, False
+ else:
+ return False, False
+
+
+def ast_eval_th(api_database, qa_pairs, ast_database, question_id, response):
+ # Check correctness
+ return process_response(question_id, response, api_database, qa_pairs, ast_database)
diff --git a/evaluation/gorilla/run_infer.py b/evaluation/gorilla/run_infer.py
new file mode 100644
index 0000000000..df3fe34c4f
--- /dev/null
+++ b/evaluation/gorilla/run_infer.py
@@ -0,0 +1,355 @@
+import asyncio
+import json
+import logging
+import multiprocessing as mp
+import os
+import pathlib
+import subprocess
+import time
+from concurrent.futures import ProcessPoolExecutor
+from functools import partial
+
+from tqdm import tqdm
+from utils import encode_question, get_data
+
+from opendevin.controller.state.state import State
+from opendevin.core.config import config, get_llm_config_arg, get_parser
+from opendevin.core.logger import get_console_handler
+from opendevin.core.logger import opendevin_logger as logger
+from opendevin.core.main import main
+from opendevin.events.action import MessageAction
+from opendevin.events.serialization.event import event_to_dict
+
+
+def cleanup():
+ print('Cleaning up child processes...')
+ for process in mp.active_children():
+ print(f'Terminating child process: {process.name}')
+ process.terminate()
+ process.join()
+
+
+def codeact_user_response(state: State) -> str:
+ msg = (
+ #'Please continue working on the task on whatever approach you think is suitable.\n'
+ 'Please run the following command: exit .\n'
+ #'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK.\n'
+ )
+ if state.history:
+ user_msgs = [
+ action
+ for action, _ in state.history
+ if isinstance(action, MessageAction) and action.source == 'user'
+ ]
+ if len(user_msgs) >= 2:
+ # let the agent know that it can give up when it has tried 3 times
+ return (
+ msg
+ + 'If you want to give up, run: exit .\n'
+ )
+ return msg
+
+
+def monologue_user_response(state: State) -> str:
+ raise NotImplementedError('MonologueAgent should never ask for user responses.')
+
+
+AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
+ 'CodeActAgent': codeact_user_response,
+ 'MonologueAgent': monologue_user_response,
+}
+
+AGENT_CLS_TO_INST_SUFFIX = {
+ 'CodeActAgent': 'When you think you have completed the request, please run the following command: exit .\n'
+}
+
+
+def process_instance(
+ question_id, question, agent_class, metadata, reset_logger: bool = True
+):
+ # create process-specific workspace dir
+ # we will create a workspace directory for EACH process
+    # so that different agents don't interfere with each other.
+ old_workspace_mount_path = config.workspace_mount_path
+ try:
+ workspace_mount_path = os.path.join(
+ config.workspace_mount_path, '_eval_workspace'
+ )
+ workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
+ pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
+ config.workspace_mount_path = workspace_mount_path
+
+        # Set up the logger properly, so you can run multiprocessing to parallelize the evaluation
+ eval_output_dir = metadata['eval_output_dir']
+ if reset_logger:
+ # Set up logger
+ log_file = os.path.join(
+ eval_output_dir, 'logs', f'instance_{question_id}.log'
+ )
+ # Remove all existing handlers from logger
+ for handler in logger.handlers[:]:
+ logger.removeHandler(handler)
+ # add back the console handler to print ONE line
+ logger.addHandler(get_console_handler())
+ logger.info(
+ f'Starting evaluation for instance {question_id}.\nLOG: tail -f {log_file}'
+ )
+ # Remove all existing handlers from logger
+ for handler in logger.handlers[:]:
+ logger.removeHandler(handler)
+ file_handler = logging.FileHandler(log_file)
+ file_handler.setFormatter(
+ logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+ )
+ logger.addHandler(file_handler)
+ logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')
+
+ # Prepare instruction
+ instruction = encode_question(question, metadata['hub'])
+ instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
+ # NOTE: You can actually set slightly different instruction for different agents
+ instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
+ # logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
+
+ # Here's how you can run the agent (similar to the `main` function) and get the final task state
+ state: State = asyncio.run(
+ main(
+ instruction,
+ fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
+ agent_class
+ ),
+ )
+ )
+        # ======= Attempt to evaluate the agent's edits =======
+        # If you are working on a simpler benchmark that only evaluates the final model output (e.g., in a MessageAction),
+        # you can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
+
+        if state is None:
+            raise ValueError('State should not be None.')
+
+        model_answer_raw = ''
+        for act, _ in reversed(state.history):
+            if isinstance(act, MessageAction) and act.source == 'agent':
+                model_answer_raw = act.content
+                break
+        # attempt to parse model_answer
+        _, _, ast_eval = get_data(metadata['hub'])
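+        # get_data returns `ast_eval` as a functools.partial with the hub's API
+        # database, QA pairs, and parsed AST database already bound, so it only
+        # needs (question_id, response) and returns (correct, hallucination).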
+        correct, hallucination = ast_eval(question_id, model_answer_raw)
+        metrics = state.metrics.get() if state.metrics else None
+        logger.info(
+            f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
+        )
+        # Save the output
+        output = {
+            'question_id': question_id,
+            'text': model_answer_raw,
+            'correct': correct,
+            'hallucination': hallucination,
+            'answer_id': 'None',
+            'model_id': metadata['model_name'],
+            'metadata': metadata,
+            'history': [
+                (event_to_dict(action), event_to_dict(obs))
+                for action, obs in state.history
+            ],
+            'metrics': metrics,
+            'error': state.error if state and state.error else None,
+        }
+ except Exception:
+ logger.error('Process instance failed')
+ raise
+ finally:
+ config.workspace_mount_path = old_workspace_mount_path
+    # NOTE: return after the finally block so raised exceptions propagate
+    return output
+
+
+if __name__ == '__main__':
+ parser = get_parser()
+ parser.add_argument(
+ '--hubs',
+ type=str,
+        help='Which hubs to evaluate from APIBench. APIBench contains 3 hubs, namely huggingface, torch, and tensorflow. You could choose one or more from hf, torch, or tf, separated by commas. For example, the default is --hubs hf,torch,tf.',
+ default='hf,torch,tf',
+ )
+ args, _ = parser.parse_known_args()
+ if args.directory:
+ config.workspace_base = os.path.abspath(args.directory)
+ print(f'Setting workspace base to {config.workspace_base}')
+
+ # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
+ # for details of how to set `llm_config`
+ if args.llm_config:
+ specified_llm_config = get_llm_config_arg(args.llm_config)
+ if specified_llm_config:
+ config.llm = specified_llm_config
+ logger.info(f'Config for evaluation: {config}')
+ agent_class = args.agent_cls
+ assert (
+ agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
+ ), f'Unsupported agent class: {agent_class}'
+ model_name = config.llm.model.split('/')[-1]
+ max_iterations = args.max_iterations
+ eval_note = ''
+ if args.eval_note is not None:
+ eval_note += '_N_' + args.eval_note
+ eval_output_dir = os.path.join(
+ args.eval_output_dir,
+ 'gorilla',
+ agent_class,
+ model_name + '_maxiter_' + str(max_iterations) + eval_note,
+ )
+ pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
+ pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
+ parents=True, exist_ok=True
+ )
+ logger.info(f'Using evaluation output directory: {eval_output_dir}')
+
+ hubs = []
+ if 'hf' in args.hubs:
+ hubs.append('hf')
+ if 'torch' in args.hubs or 'th' in args.hubs:
+ hubs.append('torch')
+ if 'tf' in args.hubs:
+ hubs.append('tf')
+ if hubs == []:
+ raise ValueError('Please choose at least one from hf, torch, and tf for hubs.')
+
+ for hub in hubs:
+ logger.info(f'Evaluating APIBench {hub} test')
+ questions, question_ids, ast_eval = get_data(hub)
+
+ # TEST METADATA
+ metadata = {
+ 'hub': hub,
+ 'agent_class': agent_class,
+ 'model_name': model_name,
+ 'max_iterations': max_iterations,
+ 'eval_output_dir': eval_output_dir,
+ 'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
+            # get the commit id of current repo for reproducibility
+ 'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+ .decode('utf-8')
+ .strip(),
+ }
+ logger.info(f'Metadata: {metadata}')
+ with open(os.path.join(eval_output_dir, f'metadata_{hub}.json'), 'w') as f:
+ json.dump(metadata, f)
+
+ # LIMIT EVALUATION
+ eval_n_limit = args.eval_n_limit
+ if eval_n_limit:
+ questions = questions[: (eval_n_limit // len(hubs))]
+ question_ids = question_ids[: (eval_n_limit // len(hubs))]
+ logger.info(
+ f'Limiting evaluation to a total of first {eval_n_limit} instances -> first {eval_n_limit//len(hubs)} instances per hub.'
+ )
+ output_file = os.path.join(eval_output_dir, f'output_{model_name}_{hub}.jsonl')
+ logger.info(f'Writing evaluation output to {output_file}')
+ finished_task_ids = set()
+ if os.path.exists(output_file):
+ with open(output_file, 'r') as f:
+ for line in f:
+ data = json.loads(line)
+                    if int(data['question_id']) in question_ids:
+                        finished_task_ids.add(data['question_id'])
+ logger.warning(
+ f'Output file {output_file} already exists. Loaded {len(finished_task_ids)} finished instances.'
+ )
+ output_fp = open(output_file, 'a')
+ logger.info(
+ f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
+ )
+ # =============================================
+ # filter out finished instances
+ new_questions = []
+ new_question_ids = []
+ for i in range(len(question_ids)):
+ if question_ids[i] in finished_task_ids:
+ logger.info(
+ f'Skipping instance {question_ids[i]} as it is already finished.'
+ )
+ continue
+ new_questions.append(questions[i])
+ new_question_ids.append(question_ids[i])
+
+ finished_task_number = len(finished_task_ids)
+ questions = new_questions
+ question_ids = new_question_ids
+ logger.info(
+ f'Finished instances: {finished_task_number}, Remaining instances: {len(question_ids)}'
+ )
+ # =============================================
+ pbar = tqdm(total=len(question_ids))
+
+    # This function tracks the progress AND writes the output to a JSONL file
+ def update_progress(future, pbar, output_fp, finished_task_ids):
+ pbar.update(1)
+ output = future.result()
+ pbar.set_description(f'Instance {output["question_id"]}')
+ pbar.set_postfix_str(f'Test Result: {output["correct"]}')
+ logger.info(
+ f'Finished evaluation for instance {output["question_id"]}: {output["correct"]}'
+ )
+ output_fp.write(json.dumps(output) + '\n')
+ output_fp.flush()
+ finished_task_ids.add(output['question_id'])
+
+    # This sets up the multiprocessing worker pool
+ num_workers = args.eval_num_workers
+ logger.info(f'Using {num_workers} workers for evaluation.')
+ try:
+ with ProcessPoolExecutor(num_workers) as executor:
+ futures = []
+ # This is how we perform multi-processing
+ for i in range(len(question_ids)):
+ try:
+ question_id = question_ids[i]
+ question = questions[i]
+ future = executor.submit(
+ process_instance,
+ question_id,
+ question,
+ agent_class,
+ metadata,
+ reset_logger=bool(num_workers > 1),
+ )
+                    # add_done_callback accepts only a callable taking the
+                    # future, so bind the remaining arguments with partial
+                    future.add_done_callback(
+                        partial(
+                            update_progress,
+                            pbar=pbar,
+                            output_fp=output_fp,
+                            finished_task_ids=finished_task_ids,
+                        )
+                    )
+ futures.append(future)
+ except Exception:
+ continue
+
+ # Wait for all futures to complete
+ for future in futures:
+ try:
+ future.result()
+ except Exception:
+ continue
+ except KeyboardInterrupt:
+ logger.info('KeyboardInterrupt received. Cleaning up...')
+ cleanup()
+
+ output_fp.close()
+ total_correct = 0
+ total_hallucination = 0
+ output = []
+ with open(output_file, 'r') as f:
+ for line in f:
+ data = json.loads(line)
+ output.append(data)
+ if int(data['question_id']) in finished_task_ids:
+ if str(data['correct']).lower() == 'true':
+ total_correct += 1
+ if str(data['hallucination']).lower() == 'true':
+ total_hallucination += 1
+ # sort all output by question_id
+ output = sorted(output, key=lambda x: x['question_id'])
+ with open(output_file, 'w') as f:
+ for dat in output:
+ f.write(json.dumps(dat) + '\n')
+ f.flush()
+
+ logger.info(
+ f'Evaluation finished for {hub}. Total: {len(question_ids)+finished_task_number}; Correct: {total_correct}; Hallucination: {total_hallucination}. Accuracy: {total_correct / (len(question_ids)+finished_task_number)}'
+ )
diff --git a/evaluation/gorilla/scripts/run_infer.sh b/evaluation/gorilla/scripts/run_infer.sh
new file mode 100644
index 0000000000..8bef6abda4
--- /dev/null
+++ b/evaluation/gorilla/scripts/run_infer.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+MODEL_CONFIG=$1
+AGENT=$2
+EVAL_LIMIT=$3
+HUBS=$4
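+
+# Fail fast if the mandatory model config argument is missing
+if [ -z "$MODEL_CONFIG" ]; then
+  echo "Usage: $0 MODEL_CONFIG [AGENT] [EVAL_LIMIT] [HUBS]"
+  exit 1
+fi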
+
+if [ -z "$AGENT" ]; then
+ echo "Agent not specified, use default CodeActAgent"
+ AGENT="CodeActAgent"
+fi
+
+if [ -z "$HUBS" ]; then
+ HUBS="hf,torch,tf"
+ echo "Hubs not specified, use default $HUBS"
+fi
+
+# IMPORTANT: Because Agent's prompt changes fairly often in the rapidly evolving codebase of OpenDevin
+# We need to track the version of Agent in the evaluation to make sure results are comparable
+AGENT_VERSION=v$(poetry run python -c "import agenthub; from opendevin.controller.agent import Agent; print(Agent.get_cls('$AGENT').VERSION)")
+
+echo "AGENT: $AGENT"
+echo "AGENT_VERSION: $AGENT_VERSION"
+echo "MODEL_CONFIG: $MODEL_CONFIG"
+echo "HUBS: $HUBS"
+
+COMMAND="poetry run python evaluation/gorilla/run_infer.py \
+ --agent-cls $AGENT \
+ --llm-config $MODEL_CONFIG \
+ --max-iterations 30 \
+ --hubs $HUBS \
+ --data-split validation \
+ --max-chars 10000000 \
+ --eval-num-workers 1 \
+  --eval-note $AGENT_VERSION"
+
+if [ -n "$EVAL_LIMIT" ]; then
+ echo "EVAL_LIMIT: $EVAL_LIMIT"
+ COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
+fi
+
+# Run the command
+eval $COMMAND
diff --git a/evaluation/gorilla/utils.py b/evaluation/gorilla/utils.py
new file mode 100644
index 0000000000..f3031a831d
--- /dev/null
+++ b/evaluation/gorilla/utils.py
@@ -0,0 +1,101 @@
+import json
+from functools import partial
+
+import requests
+from ast_eval_hf import ast_eval_hf, ast_parse
+from ast_eval_tf import ast_eval_tf
+from ast_eval_th import ast_eval_th
+
+
+# This function is modified from Gorilla's APIBench implementations (https://github.com/ShishirPatil/gorilla/blob/main/eval/get_llm_responses.py).
+def encode_question(question, api_name):
+ """Encode multiple prompt instructions into a single string."""
+
+ if api_name == 'torch':
+ api_name = 'torchhub'
+ domains = '1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}.'
+ elif api_name == 'hf':
+ api_name = 'huggingface'
+ domains = '1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, \
+ Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation,\
+ Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, \
+ Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shot Image Classification, \
+ Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, \
+ Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, \
+ Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask,\
+ Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, \
+ Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, \
+ Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics }'
+ elif api_name == 'tf':
+ api_name = 'tensorhub'
+ domains = '1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancemenmt, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}'
+ else:
+        raise ValueError(f'API name {api_name} is not supported.')
+
+ prompt = (
+ question
+ + '\nWrite a python program in 1 to 2 lines to call API in '
+ + api_name
+        + '.\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n'
+ + domains
+ + '\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer.'
+ )
+ prompts = (
+ 'You are a helpful API writer who can write APIs based on requirements.\n'
+ + prompt
+ )
+ return prompts
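+
+
+# Example of an answer the AST evaluators can parse (illustrative values):
+#   <<<domain>>>: Natural Language Processing Text Generation,
+#   <<<api_call>>>: pipeline('text-generation', model='gpt2'),
+#   <<<api_provider>>>: transformers, ...
+# The ast_eval_* helpers split the response on 'api_call' and 'api_provider'
+# to recover the call string.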
+
+
+def get_data(hub):
+ if hub == 'hf':
+ question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl'
+ api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl'
+ apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json'
+ ast_eval = ast_eval_hf
+    elif hub == 'torch':
+ question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl'
+ api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl'
+ apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json'
+ ast_eval = ast_eval_th
+    elif hub == 'tf':
+        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl'
+        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl'
+        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json'
+        ast_eval = ast_eval_tf
+    else:
+        raise ValueError(f'Unsupported hub: {hub}')
+
+ # get questions and question_ids
+ questions = []
+ question_ids = []
+ question_data = requests.get(question_data)
+ if question_data.status_code == 200:
+ lines = question_data.text.splitlines()
+        for line in lines:
+            data = json.loads(line)
+            questions.append(data['text'])
+            question_ids.append(data['question_id'])
+
+    # get the api dataset
+ api_database = []
+ api_dataset = requests.get(api_dataset)
+ if api_dataset.status_code == 200:
+ lines = api_dataset.text.splitlines()
+ for line in lines:
+ api_database.append(json.loads(line))
+
+    # get the question-answer pair dataset
+ qa_pairs = []
+ apibench = requests.get(apibench)
+ if apibench.status_code == 200:
+ lines = apibench.text.splitlines()
+ for line in lines:
+ qa_pairs.append(json.loads(line)['api_data'])
+
+ # Parse all apis to ast trees
+ ast_database = []
+ for data in api_database:
+ ast_tree = ast_parse(data['api_call'])
+ ast_database.append(ast_tree)
+ ast_eval = partial(ast_eval, api_database, qa_pairs, ast_database)
+ return questions, question_ids, ast_eval