feat(eval): rewrite log_completions to save completions to directory (#4566)

This commit is contained in:
Xingyao Wang 2024-10-25 11:36:11 -05:00 committed by GitHub
parent c3da25febc
commit 7340b78962
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 46 additions and 20 deletions

View File

@ -33,6 +33,7 @@ FAKE_RESPONSES = {
def get_config(
metadata: EvalMetadata,
instance_id: str,
) -> AppConfig:
config = AppConfig(
default_agent=metadata.agent_class,
@ -49,6 +50,14 @@ def get_config(
workspace_base=None,
workspace_mount_path=None,
)
if metadata.llm_config.log_completions:
metadata.llm_config.log_completions_folder = os.path.join(
metadata.eval_output_dir, 'llm_completions', instance_id
)
logger.info(
f'Logging LLM completions for instance {instance_id} to '
f'{metadata.llm_config.log_completions_folder}'
)
config.set_llm_config(metadata.llm_config)
return config
@ -58,7 +67,7 @@ def process_instance(
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
config = get_config(metadata)
config = get_config(metadata, instance.instance_id)
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
if reset_logger:

View File

@ -143,6 +143,14 @@ def get_config(
workspace_base=None,
workspace_mount_path=None,
)
if metadata.llm_config.log_completions:
metadata.llm_config.log_completions_folder = os.path.join(
metadata.eval_output_dir, 'llm_completions', instance['instance_id']
)
logger.info(
f'Logging LLM completions for instance {instance["instance_id"]} to '
f'{metadata.llm_config.log_completions_folder}'
)
config.set_llm_config(metadata.llm_config)
return config
@ -432,7 +440,6 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
llm_completions=state.extra_data.get('llm_completions', []),
error=state.last_error if state and state.last_error else None,
)
return output

View File

@ -61,7 +61,6 @@ class EvalOutput(BaseModel):
history: (
list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
) = None
llm_completions: list[dict[str, Any]] | None = None
metrics: dict[str, Any] | None = None
error: str | None = None

View File

@ -132,10 +132,6 @@ class AgentController:
async def update_state_after_step(self):
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
if 'llm_completions' not in self.state.extra_data:
self.state.extra_data['llm_completions'] = []
self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
self.agent.llm.llm_completions.clear()
async def report_error(self, message: str, exception: Exception | None = None):
"""Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.

View File

@ -40,6 +40,7 @@ class LLMConfig:
disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider.
log_completions: Whether to log LLM completions to the state.
log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
"""
@ -73,6 +74,7 @@ class LLMConfig:
disable_vision: bool | None = None
caching_prompt: bool = True
log_completions: bool = False
log_completions_folder: str | None = None
draft_editor: Optional['LLMConfig'] = None
def defaults_to_dict(self) -> dict:

View File

@ -1,4 +1,6 @@
import copy
import json
import os
import time
import warnings
from functools import partial
@ -77,11 +79,6 @@ class LLM(RetryMixin, DebugMixin):
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
# list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
# - 'messages': list of messages
# - 'response': response from the LLM
self.llm_completions: list[dict[str, Any]] = []
# litellm actually uses base Exception here for unknown model
self.model_info: ModelInfo | None = None
try:
@ -95,6 +92,13 @@ class LLM(RetryMixin, DebugMixin):
except Exception as e:
logger.warning(f'Could not get model info for {config.model}:\n{e}')
if self.config.log_completions:
if self.config.log_completions_folder is None:
raise RuntimeError(
'log_completions_folder is required when log_completions is enabled'
)
os.makedirs(self.config.log_completions_folder, exist_ok=True)
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
@ -194,14 +198,24 @@ class LLM(RetryMixin, DebugMixin):
# log for evals or other scripts that need the raw completion
if self.config.log_completions:
self.llm_completions.append(
{
'messages': messages,
'response': resp,
'timestamp': time.time(),
'cost': self._completion_cost(resp),
}
assert self.config.log_completions_folder is not None
log_file = os.path.join(
self.config.log_completions_folder,
# use the metric model name (for draft editor)
f'{self.metrics.model_name}-{time.time()}.json',
)
with open(log_file, 'w') as f:
json.dump(
{
'messages': messages,
'response': resp,
'args': args,
'kwargs': kwargs,
'timestamp': time.time(),
'cost': self._completion_cost(resp),
},
f,
)
message_back: str = resp['choices'][0]['message']['content']
@ -400,7 +414,6 @@ class LLM(RetryMixin, DebugMixin):
def reset(self):
self.metrics.reset()
self.llm_completions = []
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
if isinstance(messages, Message):