Refactor messages serialization (#3832)

Co-authored-by: Robert Brennan <accounts@rbren.io>
This commit is contained in:
Engel Nyst 2024-09-18 23:48:58 +02:00 committed by GitHub
parent ad0b549d8b
commit 8fdfece059
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 54 additions and 79 deletions

View File

@ -216,10 +216,8 @@ class BrowsingAgent(Agent):
prompt = get_prompt(error_prefix, cur_url, cur_axtree_txt, prev_action_str)
messages.append(Message(role='user', content=[TextContent(text=prompt)]))
flat_messages = self.llm.format_messages_for_llm(messages)
response = self.llm.completion(
messages=flat_messages,
messages=self.llm.format_messages_for_llm(messages),
temperature=0.0,
stop=[')```', ')\n```'],
)

View File

@ -164,12 +164,12 @@ model = "gpt-4o"
# If model is vision capable, this option allows to disable image processing (useful for cost reduction).
#disable_vision = true
[llm.gpt3]
[llm.gpt4o-mini]
# API key to use
api_key = "your-api-key"
# Model to use
model = "gpt-3.5"
model = "gpt-4o-mini"
#################################### Agent ###################################
# Configuration for agents (group name starts with 'agent')

View File

@ -14,9 +14,9 @@ To run the tests for OpenHands project, you can use the provided test runner scr
3. Navigate to the root directory of the project.
4. Run the test suite using the test runner script with the required arguments:
```
python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-3.5-turbo
python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-4o
```
Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-3.5-turbo`, but you can specify a different model if needed.
Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-4o`, but you can specify a different model if needed.
The test runner will discover and execute all the test cases in the `cases/` directory, and display the results of the test suite, including the status of each individual test case and the overall summary.

View File

@ -1,10 +1,7 @@
from enum import Enum
from typing import Union
from typing import Literal
from pydantic import BaseModel, Field, model_serializer
from typing_extensions import Literal
from openhands.core.logger import openhands_logger as logger
class ContentType(Enum):
@ -60,60 +57,24 @@ class Message(BaseModel):
@model_serializer
def serialize_model(self) -> dict:
content: list[dict[str, str | dict[str, str]]] = []
for item in self.content:
if isinstance(item, TextContent):
content.append(item.model_dump())
elif isinstance(item, ImageContent):
content.extend(item.model_dump())
content: list[dict] | str
if self.role == 'system':
# For system role, concatenate all text content into a single string
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
elif self.role == 'assistant' and not self.contains_image:
# For assistant role without vision, concatenate all text content into a single string
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
else:
# For user role or assistant role with vision enabled, serialize each content item
content = []
for item in self.content:
if isinstance(item, TextContent):
content.append(item.model_dump())
elif isinstance(item, ImageContent):
content.extend(item.model_dump())
return {'content': content, 'role': self.role}
def format_messages(
    messages: Union[Message, list[Message]],
    with_images: bool,
    with_prompt_caching: bool,
) -> list[dict]:
    """Convert messages into the dict form expected by the LLM API.

    When vision or prompt caching is active, each message keeps its full
    structured serialization (pydantic ``model_dump``). Otherwise every
    message is flattened into a plain ``{'role', 'content'}`` dict whose
    content is the newline-joined text parts; messages that yield no text
    are dropped from the result.
    """
    if not isinstance(messages, list):
        messages = [messages]

    # Structured (list-of-parts) payloads are required for images / caching.
    if with_images or with_prompt_caching:
        return [message.model_dump() for message in messages]

    flattened: list[dict] = []
    for message in messages:
        role = 'user'
        parts: list[str] = []
        if isinstance(message, str) and message:
            parts.append(message)
        elif isinstance(message, dict):
            role = message.get('role', 'user')
            if message.get('content'):
                parts.append(message['content'])
        elif isinstance(message, Message):
            role = message.role
            for piece in message.content:
                if isinstance(piece, list):
                    # Nested content list: keep only non-empty text items.
                    parts.extend(
                        item.text
                        for item in piece
                        if isinstance(item, TextContent) and item.text
                    )
                elif isinstance(piece, TextContent) and piece.text:
                    parts.append(piece.text)
        else:
            logger.error(
                f'>>> `message` is not a string, dict, or Message: {type(message)}'
            )
        if parts:
            flattened.append({'role': role, 'content': '\n'.join(parts)})
    return flattened

View File

@ -2,7 +2,6 @@ import asyncio
import copy
import warnings
from functools import partial
from typing import Union
from openhands.core.config import LLMConfig
from openhands.runtime.utils.shutdown_listener import should_continue
@ -32,7 +31,7 @@ from tenacity import (
from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message, format_messages
from openhands.core.message import Message
from openhands.core.metrics import Metrics
__all__ = ['LLM']
@ -633,9 +632,7 @@ class LLM:
def reset(self):
self.metrics = Metrics()
def format_messages_for_llm(
self, messages: Union[Message, list[Message]]
) -> list[dict]:
return format_messages(
messages, self.vision_is_active(), self.is_caching_prompt_active()
)
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
    """Serialize a Message, or a list of Messages, via pydantic ``model_dump``."""
    # Treat a single message as a one-element batch.
    if isinstance(messages, Message):
        messages = [messages]
    return [message.model_dump() for message in messages]

View File

@ -11,7 +11,6 @@ from http.server import HTTPServer, SimpleHTTPRequestHandler
import pytest
from litellm import completion
from openhands.core.message import format_messages
from openhands.llm.llm import message_separator
script_dir = os.environ.get('SCRIPT_DIR')
@ -78,6 +77,29 @@ def get_log_id(prompt_log_name):
return match.group(1)
def _format_messages(messages):
message_str = ''
for message in messages:
if isinstance(message, str):
message_str += message_separator + message if message_str else message
elif isinstance(message, dict):
if isinstance(message['content'], list):
for m in message['content']:
if isinstance(m, str):
message_str += message_separator + m if message_str else m
elif isinstance(m, dict) and m['type'] == 'text':
message_str += (
message_separator + m['text'] if message_str else m['text']
)
elif isinstance(message['content'], str):
message_str += (
message_separator + message['content']
if message_str
else message['content']
)
return message_str
def apply_prompt_and_get_mock_response(
test_name: str, messages: str, id: int
) -> str | None:
@ -185,10 +207,7 @@ def mock_user_response(*args, test_name, **kwargs):
def mock_completion(*args, test_name, **kwargs):
global cur_id
messages = kwargs['messages']
plain_messages = format_messages(
messages, with_images=False, with_prompt_caching=False
)
message_str = message_separator.join(msg['content'] for msg in plain_messages)
message_str = _format_messages(messages) # text only
# this assumes all response_(*).log filenames are in numerical order, starting from one
cur_id += 1