[eval] save eventstream & llm completions for SWE-Bench run_infer (#3923)

Commit 714e46f29a (parent e0608af0b3) in OpenHands, mirror of https://github.com/OpenHands/OpenHands.git.

This change records the full event stream and the raw LLM completions in the SWE-Bench run_infer evaluation output: history becomes a flat list of serialized events, and a new log_completions flag on LLMConfig routes every LLM request/response pair into the agent state and from there into EvalOutput.
@@ -30,6 +30,7 @@ from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
 from openhands.events.action import CmdRunAction
 from openhands.events.observation import CmdOutputObservation, ErrorObservation
+from openhands.events.serialization.event import event_to_dict
 from openhands.runtime.runtime import Runtime
 from openhands.runtime.utils.shutdown_listener import sleep_if_should_continue

@@ -383,10 +384,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')

-    # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
-    # for compatibility with the existing output format, we can remake the pairs here
-    # remove when it becomes unnecessary
-    histories = state.history.compatibility_for_eval_history_pairs()
+    histories = [event_to_dict(event) for event in state.history.get_events()]
     metrics = state.metrics.get() if state.metrics else None

     # Save the output
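Note: with this hunk, histories changes from reconstructed (Action, Observation) pairs to the flat event stream, one serialized dict per event. A minimal sketch of the shape difference, using plain dicts as stand-ins for what event_to_dict produces (the keys shown are illustrative, not the actual event schema):

    # Stand-ins for event_to_dict(event) output for two events.
    action_evt = {'action': 'run', 'args': {'command': 'ls'}}
    obs_evt = {'observation': 'run', 'content': 'file.txt'}

    # Old format: list of (action, observation) pairs.
    old_history = [(action_evt, obs_evt)]

    # New format: the flat, ordered event stream.
    new_history = [action_evt, obs_evt]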
@@ -398,6 +396,7 @@ def process_instance(
         metadata=metadata,
         history=histories,
         metrics=metrics,
+        llm_completions=state.extra_data.get('llm_completions', []),
         error=state.last_error if state and state.last_error else None,
     )
     return output
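With llm_completions threaded into EvalOutput, each finished instance now carries its metadata, event history, metrics, raw LLM traffic, and any error in one record. A rough sketch of one serialized record, assuming the usual JSON-lines output of run_infer (field names follow the diff; all values here are made up):

    import json

    output_record = {
        'metadata': {'model': 'gpt-4o'},   # EvalMetadata, abbreviated
        'history': [{'action': 'run'}],    # flat list of event dicts
        'metrics': {'accumulated_cost': 0.12},
        'llm_completions': [
            {'messages': [], 'response': {}, 'timestamp': 0.0, 'cost': 0.0},
        ],
        'error': None,
    }

    with open('output.jsonl', 'a') as f:   # one JSON object per instance
        f.write(json.dumps(output_record) + '\n')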
@@ -58,7 +58,11 @@ class EvalOutput(BaseModel):

     # Interaction info
     metadata: EvalMetadata | None = None
-    history: list[tuple[dict[str, Any], dict[str, Any]]] | None = None
+    # list[tuple[dict[str, Any], dict[str, Any]]] - for compatibility with the old format
+    history: (
+        list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
+    ) = None
+    llm_completions: list[dict[str, Any]]
     metrics: dict[str, Any] | None = None
     error: str | None = None

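Because history now admits either the flat event-dict list or the legacy pair list, downstream consumers may need to branch on the shape. A minimal sketch of such a normalizer (a hypothetical helper, not part of this commit):

    from typing import Any

    def iter_history_events(history: list[Any]) -> list[dict[str, Any]]:
        """Flatten either history format into one ordered list of event dicts."""
        events: list[dict[str, Any]] = []
        for item in history:
            if isinstance(item, (tuple, list)):  # legacy (action, observation) pair
                events.extend(item)
            else:                                # new format: a single event dict
                events.append(item)
        return events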
@@ -278,6 +282,7 @@ def _process_instance_wrapper(
                 + '-' * 10
             )
             # Raise an error after all retries & stop the evaluation
+            logger.exception(e)
             raise RuntimeError(
                 f'Maximum error retries reached for instance {instance.instance_id}'
             ) from e
@@ -132,6 +132,10 @@ class AgentController:
     async def update_state_after_step(self):
         # update metrics especially for cost
         self.state.local_metrics = self.agent.llm.metrics
+        if 'llm_completions' not in self.state.extra_data:
+            self.state.extra_data['llm_completions'] = []
+        self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
+        self.agent.llm.llm_completions.clear()

     async def report_error(self, message: str, exception: Exception | None = None):
         """Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.
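The extend-then-clear pattern drains the LLM's per-step buffer into persistent state, so each completion is recorded exactly once even though update_state_after_step runs after every step. A self-contained sketch of that invariant, with plain objects standing in for the controller state and the LLM buffer:

    extra_data: dict = {}     # stands in for state.extra_data
    buffer: list[dict] = []   # stands in for agent.llm.llm_completions

    def update_state_after_step() -> None:
        # Mirrors the controller logic above: drain the buffer into state.
        if 'llm_completions' not in extra_data:
            extra_data['llm_completions'] = []
        extra_data['llm_completions'].extend(buffer)
        buffer.clear()

    buffer.append({'response': 'step-1'})
    update_state_after_step()
    buffer.append({'response': 'step-2'})
    update_state_after_step()

    # Each completion appears once, in step order: no duplicates, no losses.
    assert extra_data['llm_completions'] == [
        {'response': 'step-1'},
        {'response': 'step-2'},
    ]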
@@ -53,6 +53,7 @@ class LLMConfig:
         drop_params: Drop any unmapped (unsupported) params without causing an exception.
         disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
         caching_prompt: Using the prompt caching feature provided by the LLM.
+        log_completions: Whether to log LLM completions to the state.
     """

     model: str = 'gpt-4o'
@@ -82,6 +83,7 @@ class LLMConfig:
     drop_params: bool | None = None
     disable_vision: bool | None = None
     caching_prompt: bool = False
+    log_completions: bool = False

     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
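Opting in is a single flag. A sketch of enabling it programmatically, assuming openhands is importable and that LLMConfig accepts its fields as keyword arguments (its dataclass-style definition above suggests it does):

    from openhands.core.config import LLMConfig

    # log_completions defaults to False; evaluations that want the raw
    # LLM traffic saved into EvalOutput turn it on explicitly.
    llm_config = LLMConfig(model='gpt-4o', log_completions=True)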
@@ -1,7 +1,9 @@
 import asyncio
 import copy
+import time
 import warnings
 from functools import partial
+from typing import Any

 from openhands.core.config import LLMConfig
 from openhands.runtime.utils.shutdown_listener import should_continue
@@ -73,6 +75,11 @@ class LLM:
         self.cost_metric_supported = True
         self.config = copy.deepcopy(config)

+        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
+        # - 'messages': list of messages
+        # - 'response': response from the LLM
+        self.llm_completions: list[dict[str, Any]] = []
+
         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)

@@ -257,6 +264,16 @@ class LLM:
                 logger.debug('No completion messages!')
                 resp = {'choices': [{'message': {'content': ''}}]}

+            if self.config.log_completions:
+                self.llm_completions.append(
+                    {
+                        'messages': messages,
+                        'response': resp,
+                        'timestamp': time.time(),
+                        'cost': self.completion_cost(resp),
+                    }
+                )
+
             # log the response
             message_back = resp['choices'][0]['message']['content']
             if message_back:
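Each logged entry carries the request messages, the raw response, a wall-clock timestamp, and the per-call cost, so total spend can be recomputed offline from the saved output. A small sketch over that record shape (sample values made up):

    # Shape matches what LLM.completion appends when log_completions is on.
    llm_completions = [
        {'messages': [], 'response': {}, 'timestamp': 1726000000.0, 'cost': 0.0042},
        {'messages': [], 'response': {}, 'timestamp': 1726000042.5, 'cost': 0.0108},
    ]

    total_cost = sum(entry['cost'] for entry in llm_completions)
    print(f'{len(llm_completions)} calls, total cost ${total_cost:.4f}')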
@@ -659,6 +676,7 @@ class LLM:

     def reset(self):
         self.metrics = Metrics()
+        self.llm_completions = []

     def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
         if isinstance(messages, Message):