diff --git a/evaluation/benchmarks/gorilla/ast_eval_hf.py b/evaluation/benchmarks/gorilla/ast_eval_hf.py index 25229aee74..71f2c11daa 100644 --- a/evaluation/benchmarks/gorilla/ast_eval_hf.py +++ b/evaluation/benchmarks/gorilla/ast_eval_hf.py @@ -13,6 +13,7 @@ # limitations under the License. # This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py +import tree_sitter_python as tspython from tree_sitter import Language, Parser @@ -39,10 +40,9 @@ def get_all_sub_trees(root_node): # Parse the program into AST trees -def ast_parse(candidate, lang='python'): - LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang) - parser = Parser() - parser.set_language(LANGUAGE) +def ast_parse(candidate): + LANGUAGE = Language(tspython.language()) + parser = Parser(LANGUAGE) candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node return candidate_tree diff --git a/evaluation/benchmarks/gorilla/ast_eval_tf.py b/evaluation/benchmarks/gorilla/ast_eval_tf.py index 22067010c1..f1b9e648a3 100644 --- a/evaluation/benchmarks/gorilla/ast_eval_tf.py +++ b/evaluation/benchmarks/gorilla/ast_eval_tf.py @@ -13,6 +13,7 @@ # limitations under the License. # This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py +import tree_sitter_python as tspython from tree_sitter import Language, Parser @@ -39,10 +40,9 @@ def get_all_sub_trees(root_node): # Parse the program into AST trees -def ast_parse(candidate, lang='python'): - LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang) - parser = Parser() - parser.set_language(LANGUAGE) +def ast_parse(candidate): + LANGUAGE = Language(tspython.language()) + parser = Parser(LANGUAGE) candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node return candidate_tree diff --git a/evaluation/benchmarks/gorilla/ast_eval_th.py b/evaluation/benchmarks/gorilla/ast_eval_th.py index f55f70ed7c..c5fd15c733 100644 --- a/evaluation/benchmarks/gorilla/ast_eval_th.py +++ b/evaluation/benchmarks/gorilla/ast_eval_th.py @@ -13,6 +13,7 @@ # limitations under the License. # This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_th.py +import tree_sitter_python as tspython from tree_sitter import Language, Parser @@ -39,10 +40,9 @@ def get_all_sub_trees(root_node): # Parse the program into AST trees -def ast_parse(candidate, lang='python'): - LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang) - parser = Parser() - parser.set_language(LANGUAGE) +def ast_parse(candidate): + LANGUAGE = Language(tspython.language()) + parser = Parser(LANGUAGE) candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node return candidate_tree diff --git a/evaluation/benchmarks/gorilla/utils.py b/evaluation/benchmarks/gorilla/utils.py index 8c45cce58a..125420977a 100644 --- a/evaluation/benchmarks/gorilla/utils.py +++ b/evaluation/benchmarks/gorilla/utils.py @@ -71,19 +71,19 @@ def fetch_data(url, filename): def get_data_for_hub(hub: str): if hub == 'hf': - question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl' - api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl' - apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json' + question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl' + api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/data/api/huggingface_api.jsonl' + apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/data/apibench/huggingface_eval.json' ast_eval = ast_eval_hf elif hub == 'torch': - question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl' - api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl' - apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json' + question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl' + api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/data/api/torchhub_api.jsonl' + apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/data/apibench/torchhub_eval.json' ast_eval = ast_eval_th elif hub == 'tf': - question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl' - api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl' - apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json' + question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl' + api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/data/api/tensorflowhub_api.jsonl' + apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/refs/tags/v1.2/data/apibench/tensorflow_eval.json' ast_eval = ast_eval_tf question_data = fetch_data(question_data, 'question_data.jsonl')