mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
356 lines
12 KiB
Python
356 lines
12 KiB
Python
import ast
|
||
import logging
|
||
import re
|
||
import traceback
|
||
from typing import Any
|
||
|
||
import numpy as np
|
||
from sympy import Rational
|
||
|
||
from tasks.base import Task
|
||
|
||
# Module-level logger registered under the name 'MINT'.
LOGGER = logging.getLogger('MINT')
|
||
|
||
|
||
class ReasoningTask(Task):
    """Task graded by comparing a free-form solution with one reference string."""

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        """Record the task id, the trimmed prompt and the normalized reference.

        All remaining keyword arguments are forwarded to the base class.
        """
        super().__init__(**kwargs)
        self._id = id
        self._prompt = prompt.strip()
        # The reference is stringified, trimmed and lower-cased so that it
        # lines up with the normalization done in `extract_answer`.
        reference_text = str(reference)
        self._reference = reference_text.strip().lower()

    def extract_answer(self, solution: str) -> str | None:
        """Return the solution lower-cased with surrounding whitespace removed."""
        lowered = solution.lower()
        return lowered.strip()

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare reference and answer numerically when both parse as floats.

        When both values convert with ``float()`` the answer is accepted if it
        lies within 5% of the reference's magnitude. When a conversion raises
        ``ValueError`` the check falls back to ``reference in answer``. Any
        other conversion failure is re-raised as ``ValueError``.
        """
        try:
            reference_value = float(reference)
            answer_value = float(answer)
        except ValueError:
            # At least one side is not numeric text: use containment instead.
            return reference in answer
        except Exception:
            raise ValueError(f'Cannot compare {reference} and {answer}')
        allowed_gap = 0.05 * abs(reference_value)
        return abs(reference_value - answer_value) <= allowed_gap

    def success(self, solution: str) -> bool:
        """Return whether the normalized solution matches the stored reference."""
        candidate = self.extract_answer(solution)
        return self.compare_w_digits(self._reference, candidate)
|
||
|
||
|
||
class MultipleChoiceTask(Task):
    """Subclass of Task for multiple choice tasks."""

    task_name = 'reasoning'

    def __init__(self, id, prompt: str, reference: str, **kwargs):
        """Store the prompt, the normalized reference and the parsed options.

        Args:
            id: Identifier stored on the task.
            prompt: Question text; the part after ``'Options: '`` is parsed
                into the option table.
            reference: Expected answer; stored stripped and lower-cased.
            **kwargs: Forwarded to the base class. ``hide_options`` (default
                ``False``) makes the stored prompt stop before ``'Options:'``.
        """
        super().__init__(**kwargs)
        self._id = id
        self.hide_options = kwargs.get('hide_options', False)
        if self.hide_options:
            self._prompt = prompt.split('Options:')[0].strip()
        else:
            self._prompt = prompt
        self._reference = reference.strip().lower()
        self._options = self.extract_options(prompt)
        # If all options can be converted to float, strictly perform hide
        # options. An empty option table does not count as "all numeric".
        if self._options and self._all_numeric(self._options.values()):
            self.hide_options = True
        self.metadata.update({'options': self._options})

    @staticmethod
    def _all_numeric(values) -> bool:
        """Return True when every value converts with ``float()``."""
        try:
            for value in values:
                float(value)
        except ValueError:
            return False
        return True

    def extract_answer(self, solution: str) -> str | None:
        """Extract the selected option letter from the solution.

        The solution is lower-cased and stripped. The first letter (in
        alphabetical order) that appears as ``'<letter>)'`` or ``'<letter> )'``
        is returned; when no such marker exists the normalized solution itself
        is returned.
        """
        solution = solution.lower().strip()
        for letter in 'abcdefghijklmnopqrstuvwxyz':
            if f'{letter})' in solution or f'{letter} )' in solution:
                print('SOLUTION', letter)
                return letter
        print('SOLUTION', solution)
        return solution

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare two strings, numerically when both consist of digits only.

        Digit-only strings match when they differ by at most 5% of the
        reference; anything else falls back to ``reference in answer``.
        """
        if reference.isdigit() and answer.isdigit():
            return abs(float(reference) - float(answer)) <= 0.05 * float(reference)
        return reference in answer

    def success(self, solution: str) -> bool:
        """Return whether the solution selects the reference option.

        The extracted answer is first compared with the reference directly.
        Otherwise the answer is rejected when it matches the text of any wrong
        option, and accepted when it matches the text of the correct option.

        Raises:
            KeyError: If the direct comparison fails and the reference is not
                a key of the parsed option table.
        """
        answer = self.extract_answer(solution)
        if self.compare_w_digits(self._reference, answer):
            return True

        correct_option = self._options[self._reference]
        all_options = list(self._options.values())
        print('OPTIONS', correct_option, all_options)
        print('ANSWER', answer)

        # Build a new list instead of removing items from the list being
        # iterated: in-place removal skips the element that follows each
        # removed one, so some options were never filtered.
        wrong_options = [
            option for option in all_options if option not in correct_option
        ]
        for wrong_option in wrong_options:
            if self.compare_w_digits(wrong_option, answer) or (wrong_option in answer):
                return False
        return self.compare_w_digits(correct_option, answer) or (
            correct_option in answer
        )

    def extract_options(self, prompt: str) -> dict:
        """Parse the option table that follows ``'Options: '`` in the prompt.

        Options are separated by ``' , '`` and have the form
        ``<letter>)<content>``. Letters and contents are lower-cased and
        stripped, and a trailing "which option/one is correct?" question is
        removed from the content. Fragments without a ``')'`` are skipped.

        Returns:
            Mapping from option letter to option text.
        """
        options_text = prompt.split('Options: ')[-1]
        # The content is lower-cased before the suffixes are removed, so the
        # suffixes themselves must be lower-case to ever match.
        trailing_questions = (
            '. which option is correct?',
            '. which one is correct?',
        )
        options = {}
        for fragment in options_text.split(' , '):
            # Split on the first ')' only so option text that itself contains
            # parentheses is kept whole.
            letter, separator, content = fragment.strip("[]' ").partition(')')
            if not separator:
                continue
            content = content.lower().strip('.')
            for question in trailing_questions:
                content = content.replace(question, '')
            options[letter.lower().strip()] = content.strip()
        return options
|
||
|
||
|
||
# ==== TheoremQA ====
|
||
|
||
|
||
def compare_two_numbers(p, gt):
    """Return whether prediction ``p`` matches the ground-truth number ``gt``.

    Args:
        p: Parsed prediction. Only real ``int``/``float`` values can match;
            ``bool``, ``complex``, ``dict``, ``list``, ``str`` and ``tuple``
            predictions are reported as a mismatch.
        gt: Ground truth. A ``float`` is compared with a relative tolerance
            via ``within_eps``; anything else is compared with ``round(p)``.

    Returns:
        ``True`` when the prediction matches, ``False`` otherwise.

    Raises:
        ValueError: If ``p`` is of any other, unsupported type.
    """
    # bool is a subclass of int, so it must be rejected before the numeric
    # check; otherwise True/False would be graded as the integers 1/0.
    if isinstance(p, (bool, complex, dict, list, str, tuple)):
        return False
    if not isinstance(p, (int, float)):
        raise ValueError(p)

    if isinstance(gt, float):
        return within_eps(pred=p, gt=gt)
    try:
        return round(p) == gt
    except (ValueError, OverflowError):
        # round() cannot turn NaN or an infinity into an integer; such a
        # prediction can never equal the ground truth.
        return False
|
||
|
||
|
||
def compare_two_list(pred, gt):
    """Return whether ``pred`` is a list of numbers matching ``gt`` in any order.

    The prediction must be a ``list`` of the same length as ``gt`` holding only
    ``int``/``float`` items. Both sides are sorted and then compared pairwise
    with ``compare_two_numbers``.
    """
    if not isinstance(pred, list) or len(pred) != len(gt):
        return False
    for item in pred:
        if not isinstance(item, (int, float)):
            return False

    ordered_pairs = zip(sorted(pred), sorted(gt))
    # Every pair is evaluated before the verdicts are combined.
    verdicts = [compare_two_numbers(guess, truth) for guess, truth in ordered_pairs]
    return all(verdicts)
|
||
|
||
|
||
def within_eps(pred: float, gt: float):
    """Return whether ``pred`` lies within 4% of ``gt``'s magnitude around ``gt``."""
    tolerance = abs(gt) * 0.04
    lower_bound = gt - tolerance
    upper_bound = gt + tolerance
    inside = pred >= lower_bound and pred <= upper_bound
    return bool(inside)
|
||
|
||
|
||
def parse_number_list(s: str):
    """Parse ``s`` as a Python literal and return the resulting object.

    Parsing is delegated to ``ast.literal_eval``; no check is made that the
    result is actually a list, and any exception it raises propagates.
    """
    return ast.literal_eval(s)
|
||
|
||
|
||
def is_number(string):
    """Return whether ``string`` is a plain decimal number.

    Accepts an optional sign, an integer part written either as bare digits or
    with comma-separated groups of three, and an optional fractional part.
    """
    number_re = re.compile(r'^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$')
    return number_re.match(string) is not None
|
||
|
||
|
||
def is_scientific_number(string):
    """Return whether ``string`` is written in lower-case ``e`` notation.

    The mantissa may be signed and may have a fractional part; the exponent
    may carry a minus sign only.
    """
    scientific_re = re.compile(r'^[-+]?\d+(\.\d+)?e[-]?\d+$')
    return scientific_re.match(string) is not None
|
||
|
||
|
||
def contain_num_and_str(string):
    """Return whether ``string`` holds at least one ASCII letter and one digit."""
    has_letter = re.search(r'[a-zA-Z]', string) is not None
    has_digit = re.search(r'[0-9]', string) is not None
    return has_letter and has_digit
|
||
|
||
|
||
class TheoremqaTask(Task):
    """Task whose answer is a number, a list of numbers, a boolean or an option."""

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        """Store the task id, the instruction-prefixed prompt and the reference.

        ``answer_type`` is read from ``kwargs`` and selects the comparison used
        by ``success``; all keyword arguments are forwarded to the base class.
        """
        super().__init__(**kwargs)
        self._id = id
        self._prompt = (
            'Answer the following question with a number, a list of numbers or True or False. '
            + prompt.strip()
        )
        self._reference = reference
        self._answer_type = kwargs.get('answer_type')

    def extract_answer(self, solution: str) -> Any:
        """Extract the answer from the given solution.

        The text is normalized (special tokens, boolean keywords, units,
        powers of ten, option letters) and then evaluated into a Python value.

        Returns:
            The parsed number, list, boolean or option string, or ``None``
            when the normalized text cannot be evaluated.
        """
        # Imported locally so that the `math.pow(10, n)` text generated below
        # resolves when the string is evaluated; without it every `10^n`
        # answer failed with NameError and was discarded.
        import math  # noqa: F401

        prediction = solution
        # Following the preprocessing steps from TheoremQA
        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170

        # Preprocessing the string [Stage 1]
        if not isinstance(prediction, str):
            prediction = str(prediction) if prediction is not None else '0'

        # Replace special tokens
        if '=' in prediction:
            prediction = prediction.split('=')[-1].strip()
        if '≈' in prediction:
            prediction = prediction.split('≈')[-1].strip()
        if '`' in prediction:
            prediction = prediction.replace('`', '')
        if '$' in prediction:
            prediction = prediction.replace('$', '')
        if '°' in prediction:
            prediction = prediction.replace('°', '')

        # Detect the boolean keyword in the generation
        if prediction in ('true', 'yes', 'false', 'no'):
            if prediction in ('true', 'yes'):
                prediction = 'True'
            else:
                prediction = 'False'
        if 'True' in prediction or 'False' in prediction:
            prediction = 'True' if 'True' in prediction else 'False'

        # Detect the approximation keyword
        if 'approximately' in prediction:
            prediction = prediction.replace('approximately', '').strip()
        if ' or ' in prediction:
            prediction = prediction.split(' or ')[0]

        # Drop the units before and after the number
        if re.match(r'[-+]?(?:[\d,]*\.*\d+) [^0-9 ]+$', prediction):
            prediction = re.search(
                r'([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$', prediction
            ).group(1)
        if re.match(r'[^0-9 ]+ [-+]?(?:[\d,]*\.*\d+)$', prediction):
            prediction = re.search(
                r'[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$', prediction
            ).group(1)
        if re.match(r'[-+]?(?:[\d,]*\.*\d+)[^\d]{1,2}$', prediction):
            prediction = re.search(
                r'([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$', prediction
            ).group(1)
        if re.match(r'[^-+\d]{1,2}(?:[\d,]*\.*\d+)$', prediction):
            prediction = re.search(
                r'[^-+\d]{1,2}((?:[\d,]*\.*\d+))$', prediction
            ).group(1)

        # Preprocessing the number [Stage 1]
        if '10^' in prediction:
            prediction = re.sub(r'10\^(-?\d+)', r'math.pow(10, \1)', prediction)
        if ' x ' in prediction:
            prediction = prediction.replace(' x ', '*')
        if ' × ' in prediction:
            prediction = prediction.replace(' × ', '*')
        if is_number(prediction):
            prediction = prediction.replace(',', '')

        # Preprocessing the option [Stage 3]
        if (
            'a)' in prediction
            or 'a )' in prediction
            or prediction.lower().strip() == 'a'
        ):
            prediction = '(a)'
        if (
            'b)' in prediction
            or 'b )' in prediction
            or prediction.lower().strip() == 'b'
        ):
            prediction = '(b)'
        if (
            'c)' in prediction
            or 'c )' in prediction
            or prediction.lower().strip() == 'c'
        ):
            prediction = '(c)'
        if (
            'd)' in prediction
            or 'd )' in prediction
            or prediction.lower().strip() == 'd'
        ):
            prediction = '(d)'

        if (
            '(a)' in prediction
            or '(b)' in prediction
            or '(c)' in prediction
            or '(d)' in prediction
        ):
            # Quote the option so that evaluating it yields the string itself.
            prediction = '"' + re.search(r'\([a-d]\)', prediction).group(0) + '"'

        # If the prediction is empty, use dummy '0'
        if not prediction:
            prediction = '0'

        # Converting the string answer to a number/list/bool/option
        # SECURITY NOTE(review): eval() executes arbitrary expressions taken
        # from the solution text; only run this on trusted or sandboxed input.
        try:
            prediction = eval(prediction)
        except Exception:
            LOGGER.warning(
                f'[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}'
            )
            return None  # failed to convert the answer

        # Performing common type conversion
        if isinstance(prediction, (set, tuple)):
            prediction = list(prediction)
            # Guard the element peek: an empty tuple/set has no first item.
            if prediction:
                if isinstance(prediction[0], complex):
                    prediction = [tmp.real for tmp in prediction]
                elif isinstance(prediction[0], Rational):
                    prediction = [float(tmp) for tmp in prediction]
        elif isinstance(prediction, np.ndarray):
            prediction = prediction.tolist()
        else:
            if isinstance(prediction, complex):
                prediction = prediction.real
            elif isinstance(prediction, Rational):
                prediction = float(prediction)

        return prediction

    def success(self, solution: str) -> bool:
        """This checks whether the given solution can complete the current task.

        Both the solution and the reference are parsed with ``extract_answer``
        and compared according to the task's answer type. Unparseable
        predictions and unrecognised answer types count as failures.
        """
        # Follow the implementation from TheoremQA
        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1
        prediction = self.extract_answer(solution)
        LOGGER.info(f'TheoremQA Parsed Prediction: {prediction}')
        answer_type = self._answer_type
        # NOTE(review): `self.reference` is presumably exposed by the base
        # class for the `_reference` value stored in __init__ -- confirm.
        gt = self.extract_answer(self.reference)

        # Start from "incorrect" so an unrecognised answer type yields False
        # instead of leaving the result unbound.
        cur_correct = 0
        if isinstance(prediction, (str, int, float, list)):
            # Comparing prediction against the reference
            if answer_type in ['bool', 'option', 'Option']:
                cur_correct = int(prediction == f'({gt})') or int(prediction == gt)
            elif answer_type == 'integer':
                cur_correct = int(compare_two_numbers(prediction, gt))
            elif answer_type == 'float':
                cur_correct = int(compare_two_numbers(prediction, gt))
            elif answer_type in ['list of integer', 'list of float']:
                cur_correct = int(compare_two_list(prediction, gt))
        return bool(cur_correct)
|