mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Refactor messages serialization (#3832)
Co-authored-by: Robert Brennan <accounts@rbren.io>
This commit is contained in:
parent
ad0b549d8b
commit
8fdfece059
@ -216,10 +216,8 @@ class BrowsingAgent(Agent):
|
||||
prompt = get_prompt(error_prefix, cur_url, cur_axtree_txt, prev_action_str)
|
||||
messages.append(Message(role='user', content=[TextContent(text=prompt)]))
|
||||
|
||||
flat_messages = self.llm.format_messages_for_llm(messages)
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=flat_messages,
|
||||
messages=self.llm.format_messages_for_llm(messages),
|
||||
temperature=0.0,
|
||||
stop=[')```', ')\n```'],
|
||||
)
|
||||
|
||||
@ -164,12 +164,12 @@ model = "gpt-4o"
|
||||
# If the model is vision capable, this option lets you disable image processing (useful for cost reduction).
|
||||
#disable_vision = true
|
||||
|
||||
[llm.gpt3]
|
||||
[llm.gpt4o-mini]
|
||||
# API key to use
|
||||
api_key = "your-api-key"
|
||||
|
||||
# Model to use
|
||||
model = "gpt-3.5"
|
||||
model = "gpt-4o-mini"
|
||||
|
||||
#################################### Agent ###################################
|
||||
# Configuration for agents (group name starts with 'agent')
|
||||
|
||||
@ -14,9 +14,9 @@ To run the tests for OpenHands project, you can use the provided test runner scr
|
||||
3. Navigate to the root directory of the project.
|
||||
4. Run the test suite using the test runner script with the required arguments:
|
||||
```
|
||||
python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-3.5-turbo
|
||||
python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-4o
|
||||
```
|
||||
Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-3.5-turbo`, but you can specify a different model if needed.
|
||||
Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-4o`, but you can specify a different model if needed.
|
||||
|
||||
The test runner will discover and execute all the test cases in the `cases/` directory, and display the results of the test suite, including the status of each individual test case and the overall summary.
|
||||
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
from typing_extensions import Literal
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class ContentType(Enum):
|
||||
@ -60,60 +57,24 @@ class Message(BaseModel):
|
||||
|
||||
@model_serializer
def serialize_model(self) -> dict:
    """Serialize this message into a ``{'content': ..., 'role': ...}`` dict.

    The shape of ``content`` depends on the role:

    * ``system`` messages, and ``assistant`` messages for which
      ``self.contains_image`` is falsy, get a single string made of the
      ``text`` of every ``TextContent`` item joined with newlines. Items
      that are not ``TextContent`` are left out of that string.
    * Every other message gets a list built from the ``model_dump()`` of
      each ``TextContent`` / ``ImageContent`` item.

    Returns:
        A dict with exactly two keys, ``'content'`` and ``'role'``.
    """
    content: list[dict] | str

    text_only = self.role == 'system' or (
        self.role == 'assistant' and not self.contains_image
    )
    if text_only:
        content = '\n'.join(
            item.text for item in self.content if isinstance(item, TextContent)
        )
    else:
        content = []
        for item in self.content:
            if isinstance(item, TextContent):
                content.append(item.model_dump())
            elif isinstance(item, ImageContent):
                # extend (not append): the dump of an ImageContent is spliced
                # in element by element -- presumably one entry per image;
                # confirm against ImageContent's serializer.
                content.extend(item.model_dump())

    return {'content': content, 'role': self.role}
|
||||
|
||||
|
||||
def format_messages(
    messages: 'Message | str | dict | list[Message | str | dict]',
    with_images: bool,
    with_prompt_caching: bool,
) -> list[dict]:
    """Convert messages into the list-of-dicts form handed to the LLM.

    Args:
        messages: A single message or a list of messages. On the plain-text
            path each entry may be a ``Message``, a raw string (treated as a
            ``user`` message) or a dict with optional ``role`` / ``content``
            keys.
        with_images: When true, every message is returned as its
            ``model_dump()`` and no flattening happens.
        with_prompt_caching: Same effect as ``with_images``.

    Returns:
        With either flag set, ``[m.model_dump() for m in messages]``.
        Otherwise one ``{'role': ..., 'content': ...}`` dict per message that
        has any non-empty text, with the text pieces joined by newlines.
        Messages that contribute no text are omitted.
    """
    if not isinstance(messages, list):
        messages = [messages]

    if with_images or with_prompt_caching:
        return [message.model_dump() for message in messages]

    converted_messages = []
    for message in messages:
        role = 'user'
        content_parts = []

        # The type test is kept separate from the emptiness test so that an
        # empty string is simply skipped instead of falling through to the
        # "unsupported type" error below.
        if isinstance(message, str):
            if message:
                content_parts.append(message)
        elif isinstance(message, dict):
            role = message.get('role', 'user')
            if message.get('content'):
                content_parts.append(message['content'])
        elif isinstance(message, Message):
            role = message.role
            for content in message.content:
                # A content entry may itself be a list of items; treat a
                # single item as a one-element list so both go one way.
                nested = content if isinstance(content, list) else [content]
                content_parts.extend(
                    item.text
                    for item in nested
                    if isinstance(item, TextContent) and item.text
                )
        else:
            logger.error(
                f'>>> `message` is not a string, dict, or Message: {type(message)}'
            )

        if content_parts:
            converted_messages.append(
                {
                    'role': role,
                    'content': '\n'.join(content_parts),
                }
            )

    return converted_messages
|
||||
|
||||
@ -2,7 +2,6 @@ import asyncio
|
||||
import copy
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Union
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.runtime.utils.shutdown_listener import should_continue
|
||||
@ -32,7 +31,7 @@ from tenacity import (
|
||||
from openhands.core.exceptions import LLMResponseError, UserCancelledError
|
||||
from openhands.core.logger import llm_prompt_logger, llm_response_logger
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message, format_messages
|
||||
from openhands.core.message import Message
|
||||
from openhands.core.metrics import Metrics
|
||||
|
||||
__all__ = ['LLM']
|
||||
@ -633,9 +632,7 @@ class LLM:
|
||||
def reset(self):
    """Replace ``self.metrics`` with a newly constructed ``Metrics`` object."""
    self.metrics = Metrics()
|
||||
|
||||
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
    """Serialize one message, or a list of messages, via ``model_dump()``.

    Args:
        messages: A single ``Message`` or a list of ``Message`` objects.

    Returns:
        A list containing the ``model_dump()`` of each message, in order.
        A single ``Message`` yields a one-element list.
    """
    if isinstance(messages, Message):
        messages = [messages]
    return [message.model_dump() for message in messages]
|
||||
|
||||
@ -11,7 +11,6 @@ from http.server import HTTPServer, SimpleHTTPRequestHandler
|
||||
import pytest
|
||||
from litellm import completion
|
||||
|
||||
from openhands.core.message import format_messages
|
||||
from openhands.llm.llm import message_separator
|
||||
|
||||
script_dir = os.environ.get('SCRIPT_DIR')
|
||||
@ -78,6 +77,29 @@ def get_log_id(prompt_log_name):
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def _format_messages(messages):
    """Flatten the text of ``messages`` into one string.

    Each entry of ``messages`` may be:

    * a string, used as-is;
    * a dict whose ``'content'`` is a string, used as-is;
    * a dict whose ``'content'`` is a list, from which plain strings and the
      ``'text'`` of ``{'type': 'text', ...}`` parts are taken. Parts of any
      other type are skipped.

    Entries of any other kind, and dicts without usable ``'content'``, are
    ignored.

    Returns:
        The collected pieces joined with ``message_separator``; ``''`` when
        there is no text at all.
    """
    pieces = []
    for message in messages:
        if isinstance(message, str):
            pieces.append(message)
        elif isinstance(message, dict):
            content = message.get('content')
            if isinstance(content, str):
                pieces.append(content)
            elif isinstance(content, list):
                for part in content:
                    if isinstance(part, str):
                        pieces.append(part)
                    elif isinstance(part, dict) and part.get('type') == 'text':
                        pieces.append(part['text'])

    # A separator is only emitted once some non-empty text has been
    # collected, so leading empty pieces are dropped rather than producing a
    # leading separator; empty pieces after that still get one.
    first_non_empty = next(
        (index for index, piece in enumerate(pieces) if piece), len(pieces)
    )
    return message_separator.join(pieces[first_non_empty:])
|
||||
|
||||
|
||||
def apply_prompt_and_get_mock_response(
|
||||
test_name: str, messages: str, id: int
|
||||
) -> str | None:
|
||||
@ -185,10 +207,7 @@ def mock_user_response(*args, test_name, **kwargs):
|
||||
def mock_completion(*args, test_name, **kwargs):
|
||||
global cur_id
|
||||
messages = kwargs['messages']
|
||||
plain_messages = format_messages(
|
||||
messages, with_images=False, with_prompt_caching=False
|
||||
)
|
||||
message_str = message_separator.join(msg['content'] for msg in plain_messages)
|
||||
message_str = _format_messages(messages) # text only
|
||||
|
||||
# this assumes all response_(*).log filenames are in numerical order, starting from one
|
||||
cur_id += 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user