Refactor: shorter syntax (#4558)

This commit is contained in:
tofarr 2024-10-25 06:45:28 -06:00 committed by GitHub
parent 349e2dbe50
commit c4f5c07be1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 19 additions and 28 deletions

View File

@ -131,11 +131,9 @@ class MultipleChoiceTask(Task):
def compare_two_numbers(p, gt):
if isinstance(p, int) or isinstance(p, float):
if isinstance(p, (int, float)):
pass
elif isinstance(p, list) or isinstance(p, bool) or isinstance(p, str):
return False
elif isinstance(p, tuple) or isinstance(p, complex) or isinstance(p, dict):
elif isinstance(p, (bool, complex, dict, list, str, tuple)):
return False
else:
raise ValueError(p)
@ -227,8 +225,8 @@ class TheoremqaTask(Task):
prediction = prediction.replace('°', '')
# Detect the boolean keyword in the generation
if prediction in ['true', 'yes', 'false', 'no']:
if prediction == 'true' or prediction == 'yes':
if prediction in ('true', 'yes', 'false', 'no'):
if prediction in ('true', 'yes'):
prediction = 'True'
else:
prediction = 'False'
@ -342,7 +340,7 @@ class TheoremqaTask(Task):
answer_type = self._answer_type
gt = self.extract_answer(self.reference)
if isinstance(prediction, (str, int, float)) or isinstance(prediction, list):
if isinstance(prediction, (str, int, float, list)):
# Comparing prediction against the reference
if answer_type in ['bool', 'option', 'Option']:
cur_correct = int(prediction == f'({gt})') or int(prediction == gt)

View File

@ -113,14 +113,16 @@ class CodeActAgent(Agent):
return ''
def get_action_message(self, action: Action) -> Message | None:
if (
isinstance(action, AgentDelegateAction)
or isinstance(action, CmdRunAction)
or isinstance(action, IPythonRunCellAction)
or isinstance(action, MessageAction)
or isinstance(action, FileEditAction)
or (isinstance(action, AgentFinishAction) and action.source == 'agent')
):
if isinstance(
action,
(
AgentDelegateAction,
CmdRunAction,
IPythonRunCellAction,
MessageAction,
FileEditAction,
),
) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
content = [TextContent(text=self.action_to_str(action))]
if (

View File

@ -89,11 +89,7 @@ class CodeActSWEAgent(Agent):
return ''
def get_action_message(self, action: Action) -> Message | None:
if (
isinstance(action, CmdRunAction)
or isinstance(action, IPythonRunCellAction)
or isinstance(action, MessageAction)
):
if isinstance(action, (CmdRunAction, IPythonRunCellAction, MessageAction)):
content = [TextContent(text=self.action_to_str(action))]
if (

View File

@ -33,8 +33,7 @@ class StuckDetector:
(isinstance(event, MessageAction) and event.source == EventSource.USER)
or
# there might be some NullAction or NullObservation in the history at least for now
isinstance(event, NullAction)
or isinstance(event, NullObservation)
isinstance(event, (NullAction, NullObservation))
)
]

View File

@ -11,7 +11,7 @@ def remove_fields(obj, fields: set[str]):
del obj[field]
for _, value in obj.items():
remove_fields(value, fields)
elif isinstance(obj, list) or isinstance(obj, tuple):
elif isinstance(obj, (list, tuple)):
for item in obj:
remove_fields(item, fields)
elif hasattr(obj, '__dataclass_fields__'):

View File

@ -295,11 +295,7 @@ class RemoteRuntime(Runtime):
logger.info(
f'Runtime pod not found. Count: {not_found_count} / {max_not_found_count}'
)
elif (
pod_status == 'Failed'
or pod_status == 'Unknown'
or pod_status == 'Not Found'
):
elif pod_status in ('Failed', 'Unknown', 'Not Found'):
# clean up the runtime
self.close()
raise RuntimeError(