Fix issue #5383: [Bug]: LLM Cost is added to the metrics twice (#5396)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
OpenHands 2024-12-04 15:32:08 -05:00 committed by GitHub
parent 9aa89e8f2f
commit 794408cd31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -219,32 +219,6 @@ class LLM(RetryMixin, DebugMixin):
)
resp.choices[0].message = fn_call_response_message
# log for evals or other scripts that need the raw completion
if self.config.log_completions:
assert self.config.log_completions_folder is not None
log_file = os.path.join(
self.config.log_completions_folder,
# use the metric model name (for draft editor)
f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
)
_d = {
'messages': messages,
'response': resp,
'args': args,
'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
'timestamp': time.time(),
'cost': self._completion_cost(resp),
}
if mock_function_calling:
# Overwrite response as non-fncall to be consistent with `messages``
_d['response'] = non_fncall_response
# Save fncall_messages/response separately
_d['fncall_messages'] = original_fncall_messages
_d['fncall_response'] = resp
with open(log_file, 'w') as f:
f.write(json.dumps(_d))
message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
if tool_calls:
@ -256,8 +230,38 @@ class LLM(RetryMixin, DebugMixin):
# log the LLM response
self.log_response(message_back)
# post-process the response
self._post_completion(resp)
# post-process the response first to calculate cost
cost = self._post_completion(resp)
# log for evals or other scripts that need the raw completion
if self.config.log_completions:
assert self.config.log_completions_folder is not None
log_file = os.path.join(
self.config.log_completions_folder,
# use the metric model name (for draft editor)
f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
)
# set up the dict to be logged
_d = {
'messages': messages,
'response': resp,
'args': args,
'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
'timestamp': time.time(),
'cost': cost,
}
# if non-native function calling, save messages/response separately
if mock_function_calling:
# Overwrite response as non-fncall to be consistent with messages
_d['response'] = non_fncall_response
# Save fncall_messages/response separately
_d['fncall_messages'] = original_fncall_messages
_d['fncall_response'] = resp
with open(log_file, 'w') as f:
f.write(json.dumps(_d))
return resp
except APIError as e:
@ -414,7 +418,7 @@ class LLM(RetryMixin, DebugMixin):
)
return model_name_supported
def _post_completion(self, response: ModelResponse) -> None:
def _post_completion(self, response: ModelResponse) -> float:
"""Post-process the completion response.
Logs the cost and usage stats of the completion call.
@ -472,6 +476,8 @@ class LLM(RetryMixin, DebugMixin):
if stats:
logger.debug(stats)
return cur_cost
def get_token_count(self, messages) -> int:
"""Get the number of tokens in a list of messages.