[eval] save eventstream & llm completions for SWE-Bench run_infer (#3923)

This commit is contained in:
Xingyao Wang 2024-09-21 23:39:13 -05:00 committed by GitHub
parent e0608af0b3
commit 714e46f29a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 33 additions and 5 deletions

View File

@ -30,6 +30,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.runtime import Runtime
from openhands.runtime.utils.shutdown_listener import sleep_if_should_continue
@ -383,10 +384,7 @@ def process_instance(
if state is None:
raise ValueError('State should not be None.')
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
histories = [event_to_dict(event) for event in state.history.get_events()]
metrics = state.metrics.get() if state.metrics else None
# Save the output
@ -398,6 +396,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
llm_completions=state.extra_data.get('llm_completions', []),
error=state.last_error if state and state.last_error else None,
)
return output

View File

@ -58,7 +58,11 @@ class EvalOutput(BaseModel):
# Interaction info
metadata: EvalMetadata | None = None
history: list[tuple[dict[str, Any], dict[str, Any]]] | None = None
# list[tuple[dict[str, Any], dict[str, Any]]] - for compatibility with the old format
history: (
list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
) = None
llm_completions: list[dict[str, Any]]
metrics: dict[str, Any] | None = None
error: str | None = None
@ -278,6 +282,7 @@ def _process_instance_wrapper(
+ '-' * 10
)
# Raise an error after all retries & stop the evaluation
logger.exception(e)
raise RuntimeError(
f'Maximum error retries reached for instance {instance.instance_id}'
) from e

View File

@ -132,6 +132,10 @@ class AgentController:
async def update_state_after_step(self):
# update metrics especially for cost
self.state.local_metrics = self.agent.llm.metrics
if 'llm_completions' not in self.state.extra_data:
self.state.extra_data['llm_completions'] = []
self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
self.agent.llm.llm_completions.clear()
async def report_error(self, message: str, exception: Exception | None = None):
"""Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.

View File

@ -53,6 +53,7 @@ class LLMConfig:
drop_params: Drop any unmapped (unsupported) params without causing an exception.
disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
caching_prompt: Using the prompt caching feature provided by the LLM.
log_completions: Whether to log LLM completions to the state.
"""
model: str = 'gpt-4o'
@ -82,6 +83,7 @@ class LLMConfig:
drop_params: bool | None = None
disable_vision: bool | None = None
caching_prompt: bool = False
log_completions: bool = False
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

View File

@ -1,7 +1,9 @@
import asyncio
import copy
import time
import warnings
from functools import partial
from typing import Any
from openhands.core.config import LLMConfig
from openhands.runtime.utils.shutdown_listener import should_continue
@ -73,6 +75,11 @@ class LLM:
self.cost_metric_supported = True
self.config = copy.deepcopy(config)
# list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
# - 'messages': list of messages
# - 'response': response from the LLM
self.llm_completions: list[dict[str, Any]] = []
# Set up config attributes with default values to prevent AttributeError
LLMConfig.set_missing_attributes(self.config)
@ -257,6 +264,16 @@ class LLM:
logger.debug('No completion messages!')
resp = {'choices': [{'message': {'content': ''}}]}
if self.config.log_completions:
self.llm_completions.append(
{
'messages': messages,
'response': resp,
'timestamp': time.time(),
'cost': self.completion_cost(resp),
}
)
# log the response
message_back = resp['choices'][0]['message']['content']
if message_back:
@ -659,6 +676,7 @@ class LLM:
def reset(self):
self.metrics = Metrics()
self.llm_completions = []
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
if isinstance(messages, Message):