From 415843476c9a71cbbe26639c2ab1df8c7c72d354 Mon Sep 17 00:00:00 2001
From: Kaushik Deka <55996465+Kaushikdkrikhanu@users.noreply.github.com>
Date: Sat, 3 Aug 2024 11:26:22 -0700
Subject: [PATCH] Feat: Add Vision Input Support for LLM with Vision
 Capabilities (#2848)

* add image feature
* fix-linting
* check model support for images
* add comment
* Add image support to other models
* Add images to chat
* fix linting
* fix test issues
* refactor variable names and import
* fix tests
* fix chat message tests
* fix linting
* add pydantic class message
* use message
* remove redundant comments
* remove redundant comments
* change Message class
* remove unintended change
* fix integration tests using regenerate.sh
* rename image_bas64 to images_url, fix tests
* rename Message.py to message, change reminder append logic, add unit tests
* remove comment, fix error to merge
* codeact_swe_agent
* fix f string
* update eventstream integration tests
* add missing if check in codeact_swe_agent
* update integration tests
* Update frontend/src/components/chat/ChatInput.tsx
* Update frontend/src/components/chat/ChatInput.tsx
* Update frontend/src/components/chat/ChatInput.tsx
* Update frontend/src/components/chat/ChatInput.tsx
* Update frontend/src/components/chat/ChatMessage.tsx

---------

Co-authored-by: tobitege
Co-authored-by: Xingyao Wang
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
---
 agenthub/browsing_agent/browsing_agent.py | 14 +-
 agenthub/codeact_agent/codeact_agent.py | 83 ++-
 .../codeact_swe_agent/codeact_swe_agent.py | 74 +-
 agenthub/delegator_agent/agent.py | 4 +-
 agenthub/micro/agent.py | 11 +-
 agenthub/planner_agent/agent.py | 15 +-
 agenthub/planner_agent/prompt.py | 10 +-
 frontend/src/components/chat/Chat.test.tsx | 6 +-
 .../src/components/chat/ChatInput.test.tsx | 4 +-
 frontend/src/components/chat/ChatInput.tsx | 97 ++-
 .../components/chat/ChatInterface.test.tsx | 4 +-
 .../src/components/chat/ChatInterface.tsx | 8 +-
 .../src/components/chat/ChatMessage.test.tsx | 29 +-
 frontend/src/components/chat/ChatMessage.tsx | 12 +
 frontend/src/components/chat/message.d.ts | 1 +
 frontend/src/i18n/translation.json | 5 +
 frontend/src/services/actions.ts | 4 +-
 frontend/src/services/chatService.ts | 7 +-
 frontend/src/state/chatSlice.ts | 11 +-
 opendevin/controller/state/state.py | 6 +-
 opendevin/core/message.py | 59 ++
 opendevin/events/action/message.py | 4 +
 opendevin/events/serialization/event.py | 1 +
 opendevin/llm/llm.py | 50 +-
 opendevin/server/listen.py | 2 +-
 opendevin/server/session/session.py | 10 +-
 tests/integration/conftest.py | 4 +-
 .../test_browse_internet/prompt_001.log | 2 +-
 .../CodeActAgent/test_edits/prompt_001.log | 2 +-
 .../CodeActAgent/test_ipython/prompt_001.log | 2 +-
 .../test_ipython_module/prompt_001.log | 2 +-
 .../test_write_simple_script/prompt_001.log | 2 +-
 .../CodeActSWEAgent/test_edits/prompt_001.log | 4 +-
 .../CodeActSWEAgent/test_edits/prompt_002.log | 2 +-
 .../CodeActSWEAgent/test_edits/prompt_003.log | 2 +-
 .../CodeActSWEAgent/test_edits/prompt_004.log | 2 +-
 .../test_ipython/prompt_001.log | 4 +-
 .../test_ipython/prompt_002.log | 2 +-
 .../test_ipython_module/prompt_001.log | 4 +-
 .../test_ipython_module/prompt_002.log | 6 +-
 .../test_ipython_module/prompt_003.log | 6 +-
 .../test_ipython_module/prompt_004.log | 8 +-
 .../test_write_simple_script/prompt_001.log | 4 +-
 .../test_write_simple_script/prompt_002.log | 2 +-
 .../test_write_simple_script/prompt_003.log | 2 +-
 .../test_write_simple_script/prompt_004.log | 2 +-
 .../ManagerAgent/test_edits/prompt_001.log | 2 +-
 .../ManagerAgent/test_edits/prompt_004.log | 2 +-
 .../ManagerAgent/test_edits/prompt_005.log | 2 +-
 .../ManagerAgent/test_edits/prompt_006.log | 2 +-
 .../test_simple_task_rejection/prompt_001.log | 2 +-
 .../test_simple_task_rejection/prompt_008.log | 2 +-
 .../test_write_simple_script/prompt_001.log | 1 +
 .../test_write_simple_script/prompt_002.log | 1 +
 .../test_write_simple_script/prompt_003.log | 1 +
 .../test_write_simple_script/prompt_004.log | 1 +
 .../test_write_simple_script/prompt_005.log | 1 +
 .../test_write_simple_script/prompt_006.log | 1 +
 .../test_write_simple_script/prompt_007.log | 1 +
 .../test_write_simple_script/prompt_008.log | 1 +
 .../test_write_simple_script/prompt_009.log | 1 +
 .../test_write_simple_script/prompt_010.log | 1 +
 .../test_write_simple_script/prompt_011.log | 1 +
 .../test_browse_internet/prompt_005.log | 2 +-
 .../CodeActAgent/test_edits/prompt_002.log | 2 +-
 .../CodeActAgent/test_edits/prompt_003.log | 2 +-
 .../CodeActAgent/test_edits/prompt_004.log | 2 +-
 .../CodeActAgent/test_edits/prompt_005.log | 2 +-
 .../CodeActAgent/test_ipython/prompt_002.log | 2 +-
 .../CodeActAgent/test_ipython/prompt_003.log | 2 +-
 .../test_ipython_module/prompt_002.log | 2 +-
 .../test_ipython_module/prompt_003.log | 2 +-
 .../test_ipython_module/prompt_004.log | 2 +-
 .../test_write_simple_script/prompt_002.log | 2 +-
 .../test_write_simple_script/prompt_003.log | 2 +-
 .../test_write_simple_script/prompt_004.log | 2 +-
 .../CodeActSWEAgent/test_edits/prompt_001.log | 2 +-
 .../CodeActSWEAgent/test_edits/prompt_002.log | 4 +-
 .../CodeActSWEAgent/test_edits/prompt_003.log | 4 +-
 .../CodeActSWEAgent/test_edits/prompt_004.log | 4 +-
 .../CodeActSWEAgent/test_edits/prompt_005.log | 4 +-
 .../test_ipython/prompt_001.log | 2 +-
 .../test_ipython/prompt_002.log | 4 +-
 .../test_ipython_module/prompt_001.log | 2 +-
 .../test_ipython_module/prompt_002.log | 6 +-
 .../test_ipython_module/prompt_003.log | 25 +-
 .../test_ipython_module/prompt_004.log | 636 ++++++++++++++++++
 .../test_ipython_module/response_001.log | 2 +-
 .../test_ipython_module/response_002.log | 16 +-
 .../test_ipython_module/response_003.log | 5 +-
 .../test_ipython_module/response_004.log | 3 +
 .../test_write_simple_script/prompt_001.log | 2 +-
 .../test_write_simple_script/prompt_002.log | 4 +-
 .../test_write_simple_script/prompt_003.log | 4 +-
 .../test_write_simple_script/prompt_004.log | 4 +-
 .../ManagerAgent/test_edits/prompt_001.log | 2 +-
 .../ManagerAgent/test_edits/prompt_004.log | 2 +-
 .../ManagerAgent/test_edits/prompt_005.log | 2 +-
 .../ManagerAgent/test_edits/prompt_006.log | 2 +-
 .../test_simple_task_rejection/prompt_001.log | 2 +-
 .../test_simple_task_rejection/prompt_006.log | 2 +-
 .../test_write_simple_script/prompt_001.log | 1 +
 .../test_write_simple_script/prompt_002.log | 1 +
 .../test_write_simple_script/prompt_003.log | 1 +
 .../test_write_simple_script/prompt_004.log | 1 +
 .../test_write_simple_script/prompt_005.log | 1 +
 .../test_write_simple_script/prompt_006.log | 1 +
 .../test_write_simple_script/prompt_007.log | 1 +
 .../test_write_simple_script/prompt_008.log | 1 +
 .../test_write_simple_script/prompt_009.log | 1 +
 .../test_write_simple_script/prompt_010.log | 1 +
 .../test_write_simple_script/prompt_011.log | 1 +
 tests/unit/test_action_serialization.py | 2 +
 tests/unit/test_json.py | 2 +
 tests/unit/test_message_serialization.py | 64 ++
 tests/unit/test_micro_agents.py | 4 +-
 116 files changed, 1260 insertions(+), 212 deletions(-)
 create mode 100644
opendevin/core/message.py create mode 100644 tests/integration/mock/server_runtime/CodeActSWEAgent/test_ipython_module/prompt_004.log create mode 100644 tests/integration/mock/server_runtime/CodeActSWEAgent/test_ipython_module/response_004.log create mode 100644 tests/unit/test_message_serialization.py diff --git a/agenthub/browsing_agent/browsing_agent.py b/agenthub/browsing_agent/browsing_agent.py index 91a59cd282..448e8b1392 100644 --- a/agenthub/browsing_agent/browsing_agent.py +++ b/agenthub/browsing_agent/browsing_agent.py @@ -7,6 +7,7 @@ from agenthub.browsing_agent.response_parser import BrowsingResponseParser from opendevin.controller.agent import Agent from opendevin.controller.state.state import State from opendevin.core.logger import opendevin_logger as logger +from opendevin.core.message import Message, TextContent from opendevin.events.action import ( Action, AgentFinishAction, @@ -136,7 +137,7 @@ class BrowsingAgent(Agent): - MessageAction(content) - Message action to run (e.g. ask for clarification) - AgentFinishAction() - end the interaction """ - messages = [] + messages: list[Message] = [] prev_actions = [] cur_axtree_txt = '' error_prefix = '' @@ -191,20 +192,23 @@ class BrowsingAgent(Agent): ) return MessageAction('Error encountered when browsing.') - if (goal := state.get_current_user_intent()) is None: + goal, _ = state.get_current_user_intent() + + if goal is None: goal = state.inputs['task'] + system_msg = get_system_message( goal, self.action_space.describe(with_long_description=False, with_examples=True), ) - messages.append({'role': 'system', 'content': system_msg}) + messages.append(Message(role='system', content=[TextContent(text=system_msg)])) prompt = get_prompt(error_prefix, cur_axtree_txt, prev_action_str) - messages.append({'role': 'user', 'content': prompt}) + messages.append(Message(role='user', content=[TextContent(text=prompt)])) logger.debug(prompt) response = self.llm.completion( - messages=messages, + messages=[message.model_dump() for message in messages], temperature=0.0, stop=[')```', ')\n```'], ) diff --git a/agenthub/codeact_agent/codeact_agent.py b/agenthub/codeact_agent/codeact_agent.py index 08af8ca95a..56d1b5f5c9 100644 --- a/agenthub/codeact_agent/codeact_agent.py +++ b/agenthub/codeact_agent/codeact_agent.py @@ -8,6 +8,7 @@ from agenthub.codeact_agent.prompt import ( ) from opendevin.controller.agent import Agent from opendevin.controller.state.state import State +from opendevin.core.message import ImageContent, Message, TextContent from opendevin.events.action import ( Action, AgentDelegateAction, @@ -125,7 +126,7 @@ class CodeActAgent(Agent): return action.thought return '' - def get_action_message(self, action: Action) -> dict[str, str] | None: + def get_action_message(self, action: Action) -> Message | None: if ( isinstance(action, AgentDelegateAction) or isinstance(action, CmdRunAction) @@ -133,39 +134,41 @@ class CodeActAgent(Agent): or isinstance(action, MessageAction) or (isinstance(action, AgentFinishAction) and action.source == 'agent') ): - return { - 'role': 'user' if action.source == 'user' else 'assistant', - 'content': self.action_to_str(action), - } + content = [TextContent(text=self.action_to_str(action))] + + if isinstance(action, MessageAction) and action.images_urls: + content.append(ImageContent(image_urls=action.images_urls)) + + return Message( + role='user' if action.source == 'user' else 'assistant', content=content + ) return None - def get_observation_message(self, obs: Observation) -> dict[str, str] | None: + 
def get_observation_message(self, obs: Observation) -> Message | None: max_message_chars = self.llm.config.max_message_chars if isinstance(obs, CmdOutputObservation): - content = 'OBSERVATION:\n' + truncate_content( - obs.content, max_message_chars - ) - content += ( + text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) + text += ( f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]' ) - return {'role': 'user', 'content': content} + return Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, IPythonRunCellObservation): - content = 'OBSERVATION:\n' + obs.content + text = 'OBSERVATION:\n' + obs.content # replace base64 images with a placeholder - splitted = content.split('\n') + splitted = text.split('\n') for i, line in enumerate(splitted): if '![image](data:image/png;base64,' in line: splitted[i] = ( '![image](data:image/png;base64, ...) already displayed to user' ) - content = '\n'.join(splitted) - content = truncate_content(content, max_message_chars) - return {'role': 'user', 'content': content} + text = '\n'.join(splitted) + text = truncate_content(text, max_message_chars) + return Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, AgentDelegateObservation): - content = 'OBSERVATION:\n' + truncate_content( + text = 'OBSERVATION:\n' + truncate_content( str(obs.outputs), max_message_chars ) - return {'role': 'user', 'content': content} + return Message(role='user', content=[TextContent(text=text)]) return None def reset(self) -> None: @@ -192,10 +195,10 @@ class CodeActAgent(Agent): return AgentFinishAction() # prepare what we want to send to the LLM - messages: list[dict[str, str]] = self._get_messages(state) + messages = self._get_messages(state) response = self.llm.completion( - messages=messages, + messages=[message.model_dump() for message in messages], stop=[ '', '', @@ -205,10 +208,10 @@ class CodeActAgent(Agent): ) return self.action_parser.parse(response) - def _get_messages(self, state: State) -> list[dict[str, str]]: - messages = [ - {'role': 'system', 'content': self.system_message}, - {'role': 'user', 'content': self.in_context_example}, + def _get_messages(self, state: State) -> list[Message]: + messages: list[Message] = [ + Message(role='system', content=[TextContent(text=self.system_message)]), + Message(role='user', content=[TextContent(text=self.in_context_example)]), ] for event in state.history.get_events(): @@ -225,21 +228,41 @@ class CodeActAgent(Agent): # handle error if the message is the SAME role as the previous message # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'} # there should not have two consecutive messages from the same role - if messages and messages[-1]['role'] == message['role']: - messages[-1]['content'] += '\n\n' + message['content'] + if messages and messages[-1].role == message.role: + messages[-1].content.extend(message.content) else: messages.append(message) # the latest user message is important: # we want to remind the agent of the environment constraints latest_user_message = next( - (m for m in reversed(messages) if m['role'] == 'user'), None + ( + m + for m in reversed(messages) + if m.role == 'user' + and any(isinstance(c, TextContent) for c in m.content) + ), + None, ) - # add a reminder to the prompt + # Get the last user text inside content if latest_user_message: - latest_user_message['content'] += ( - f'\n\nENVIRONMENT REMINDER: You have 
{state.max_iterations - state.iteration} turns left to complete the task. When finished reply with ' + latest_user_message_text = next( + ( + t + for t in reversed(latest_user_message.content) + if isinstance(t, TextContent) + ) ) + # add a reminder to the prompt + reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with .' + + if latest_user_message_text: + latest_user_message_text.text = ( + latest_user_message_text.text + reminder_text + ) + else: + latest_user_message_text = TextContent(text=reminder_text) + latest_user_message.content.append(latest_user_message_text) return messages diff --git a/agenthub/codeact_swe_agent/codeact_swe_agent.py b/agenthub/codeact_swe_agent/codeact_swe_agent.py index 67ed646ee2..b1191ff672 100644 --- a/agenthub/codeact_swe_agent/codeact_swe_agent.py +++ b/agenthub/codeact_swe_agent/codeact_swe_agent.py @@ -7,6 +7,7 @@ from agenthub.codeact_swe_agent.prompt import ( from agenthub.codeact_swe_agent.response_parser import CodeActSWEResponseParser from opendevin.controller.agent import Agent from opendevin.controller.state.state import State +from opendevin.core.message import ImageContent, Message, TextContent from opendevin.events.action import ( Action, AgentFinishAction, @@ -84,40 +85,43 @@ class CodeActSWEAgent(Agent): return action.content return '' - def get_action_message(self, action: Action) -> dict[str, str] | None: + def get_action_message(self, action: Action) -> Message | None: if ( isinstance(action, CmdRunAction) or isinstance(action, IPythonRunCellAction) or isinstance(action, MessageAction) ): - return { - 'role': 'user' if action.source == 'user' else 'assistant', - 'content': self.action_to_str(action), - } + content = [TextContent(text=self.action_to_str(action))] + + if isinstance(action, MessageAction) and action.images_urls: + content.append(ImageContent(image_urls=action.images_urls)) + + return Message( + role='user' if action.source == 'user' else 'assistant', content=content + ) + return None - def get_observation_message(self, obs: Observation) -> dict[str, str] | None: + def get_observation_message(self, obs: Observation) -> Message | None: max_message_chars = self.llm.config.max_message_chars if isinstance(obs, CmdOutputObservation): - content = 'OBSERVATION:\n' + truncate_content( - obs.content, max_message_chars - ) - content += ( + text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) + text += ( f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]' ) - return {'role': 'user', 'content': content} + return Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, IPythonRunCellObservation): - content = 'OBSERVATION:\n' + obs.content + text = 'OBSERVATION:\n' + obs.content # replace base64 images with a placeholder - splitted = content.split('\n') + splitted = text.split('\n') for i, line in enumerate(splitted): if '![image](data:image/png;base64,' in line: splitted[i] = ( '![image](data:image/png;base64, ...) 
already displayed to user' ) - content = '\n'.join(splitted) - content = truncate_content(content, max_message_chars) - return {'role': 'user', 'content': content} + text = '\n'.join(splitted) + text = truncate_content(text, max_message_chars) + return Message(role='user', content=[TextContent(text=text)]) return None def reset(self) -> None: @@ -143,10 +147,10 @@ class CodeActSWEAgent(Agent): return AgentFinishAction() # prepare what we want to send to the LLM - messages: list[dict[str, str]] = self._get_messages(state) + messages: list[Message] = self._get_messages(state) response = self.llm.completion( - messages=messages, + messages=[message.model_dump() for message in messages], stop=[ '', '', @@ -156,10 +160,10 @@ class CodeActSWEAgent(Agent): return self.response_parser.parse(response) - def _get_messages(self, state: State) -> list[dict[str, str]]: - messages = [ - {'role': 'system', 'content': self.system_message}, - {'role': 'user', 'content': self.in_context_example}, + def _get_messages(self, state: State) -> list[Message]: + messages: list[Message] = [ + Message(role='system', content=[TextContent(text=self.system_message)]), + Message(role='user', content=[TextContent(text=self.in_context_example)]), ] for event in state.history.get_events(): @@ -176,21 +180,35 @@ class CodeActSWEAgent(Agent): # handle error if the message is the SAME role as the previous message # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'} # there should not have two consecutive messages from the same role - if messages and messages[-1]['role'] == message['role']: - messages[-1]['content'] += '\n\n' + message['content'] + if messages and messages[-1].role == message.role: + messages[-1].content.extend(message.content) else: messages.append(message) # the latest user message is important: # we want to remind the agent of the environment constraints latest_user_message = next( - (m for m in reversed(messages) if m['role'] == 'user'), None + (m for m in reversed(messages) if m.role == 'user'), None ) - # add a reminder to the prompt + # Get the last user text inside content if latest_user_message: - latest_user_message['content'] += ( - f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task.' + latest_user_message_text = next( + ( + t + for t in reversed(latest_user_message.content) + if isinstance(t, TextContent) + ) ) + # add a reminder to the prompt + reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with .' 
+ + if latest_user_message_text: + latest_user_message_text.text = ( + latest_user_message_text.text + reminder_text + ) + else: + latest_user_message_text = TextContent(text=reminder_text) + latest_user_message.content.append(latest_user_message_text) return messages diff --git a/agenthub/delegator_agent/agent.py b/agenthub/delegator_agent/agent.py index 5a07cfc68c..1e49603bf3 100644 --- a/agenthub/delegator_agent/agent.py +++ b/agenthub/delegator_agent/agent.py @@ -34,7 +34,7 @@ class DelegatorAgent(Agent): """ if self.current_delegate == '': self.current_delegate = 'study' - task = state.get_current_user_intent() + task, _ = state.get_current_user_intent() return AgentDelegateAction( agent='StudyRepoForTaskAgent', inputs={'task': task} ) @@ -45,7 +45,7 @@ class DelegatorAgent(Agent): if not isinstance(last_observation, AgentDelegateObservation): raise Exception('Last observation is not an AgentDelegateObservation') - goal = state.get_current_user_intent() + goal, _ = state.get_current_user_intent() if self.current_delegate == 'study': self.current_delegate = 'coder' return AgentDelegateAction( diff --git a/agenthub/micro/agent.py b/agenthub/micro/agent.py index 941bc9e903..0d305f70ec 100644 --- a/agenthub/micro/agent.py +++ b/agenthub/micro/agent.py @@ -2,6 +2,7 @@ from jinja2 import BaseLoader, Environment from opendevin.controller.agent import Agent from opendevin.controller.state.state import State +from opendevin.core.message import ImageContent, Message, TextContent from opendevin.core.utils import json from opendevin.events.action import Action from opendevin.events.serialization.action import action_from_dict @@ -62,16 +63,20 @@ class MicroAgent(Agent): del self.delegates[self.agent_definition['name']] def step(self, state: State) -> Action: + last_user_message, last_image_urls = state.get_current_user_intent() prompt = self.prompt_template.render( state=state, instructions=instructions, to_json=to_json, history_to_json=self.history_to_json, delegates=self.delegates, - latest_user_message=state.get_current_user_intent(), + latest_user_message=last_user_message, ) - messages = [{'content': prompt, 'role': 'user'}] - resp = self.llm.completion(messages=messages) + content = [TextContent(text=prompt)] + if last_image_urls: + content.append(ImageContent(image_urls=last_image_urls)) + message = Message(role='user', content=content) + resp = self.llm.completion(messages=[message.model_dump()]) action_resp = resp['choices'][0]['message']['content'] action = parse_response(action_resp) return action diff --git a/agenthub/planner_agent/agent.py b/agenthub/planner_agent/agent.py index f3225e6e0e..04e38c0b95 100644 --- a/agenthub/planner_agent/agent.py +++ b/agenthub/planner_agent/agent.py @@ -1,11 +1,12 @@ from agenthub.planner_agent.response_parser import PlannerResponseParser from opendevin.controller.agent import Agent from opendevin.controller.state.state import State +from opendevin.core.message import ImageContent, Message, TextContent from opendevin.events.action import Action, AgentFinishAction from opendevin.llm.llm import LLM from opendevin.runtime.tools import RuntimeTool -from .prompt import get_prompt +from .prompt import get_prompt_and_images class PlannerAgent(Agent): @@ -42,7 +43,13 @@ class PlannerAgent(Agent): 'abandoned', ]: return AgentFinishAction() - prompt = get_prompt(state, self.llm.config.max_message_chars) - messages = [{'content': prompt, 'role': 'user'}] - resp = self.llm.completion(messages=messages) + + prompt, image_urls = get_prompt_and_images( + state, 
self.llm.config.max_message_chars + ) + content = [TextContent(text=prompt)] + if image_urls: + content.append(ImageContent(image_urls=image_urls)) + message = Message(role='user', content=content) + resp = self.llm.completion(messages=[message.model_dump()]) return self.response_parser.parse(resp) diff --git a/agenthub/planner_agent/prompt.py b/agenthub/planner_agent/prompt.py index 9289ac045b..ad26b20cac 100644 --- a/agenthub/planner_agent/prompt.py +++ b/agenthub/planner_agent/prompt.py @@ -115,7 +115,9 @@ def get_hint(latest_action_id: str) -> str: return hints.get(latest_action_id, '') -def get_prompt(state: State, max_message_chars: int) -> str: +def get_prompt_and_images( + state: State, max_message_chars: int +) -> tuple[str, list[str]]: """Gets the prompt for the planner agent. Formatted with the most recent action-observation pairs, current task, and hint based on last action @@ -161,16 +163,16 @@ def get_prompt(state: State, max_message_chars: int) -> str: logger.info('HINT:\n' + hint, extra={'msg_type': 'DETAIL'}) # the last relevant user message (the task) - task = state.get_current_user_intent() + message, image_urls = state.get_current_user_intent() # finally, fill in the prompt return prompt % { - 'task': task, + 'task': message, 'plan': plan_str, 'history': history_str, 'hint': hint, 'plan_status': plan_status, - } + }, image_urls def parse_response(response: str) -> Action: diff --git a/frontend/src/components/chat/Chat.test.tsx b/frontend/src/components/chat/Chat.test.tsx index 5637f81947..6b912c7222 100644 --- a/frontend/src/components/chat/Chat.test.tsx +++ b/frontend/src/components/chat/Chat.test.tsx @@ -5,9 +5,9 @@ import { renderWithProviders } from "test-utils"; import Chat from "./Chat"; const MESSAGES: Message[] = [ - { sender: "assistant", content: "Hello!" }, - { sender: "user", content: "Hi!" }, - { sender: "assistant", content: "How can I help you today?" 
}, + { sender: "assistant", content: "Hello!", imageUrls: [] }, + { sender: "user", content: "Hi!", imageUrls: [] }, + { sender: "assistant", content: "How can I help you today?", imageUrls: [] }, ]; describe("Chat", () => { diff --git a/frontend/src/components/chat/ChatInput.test.tsx b/frontend/src/components/chat/ChatInput.test.tsx index 02b7863080..0947f282ac 100644 --- a/frontend/src/components/chat/ChatInput.test.tsx +++ b/frontend/src/components/chat/ChatInput.test.tsx @@ -55,7 +55,7 @@ describe("ChatInput", () => { await user.type(textarea, "Hello, world!"); await user.click(button); - expect(onSendMessage).toHaveBeenCalledWith("Hello, world!"); + expect(onSendMessage).toHaveBeenCalledWith("Hello, world!", []); // Additionally, check if it was called exactly once expect(onSendMessage).toHaveBeenCalledTimes(1); }); @@ -68,7 +68,7 @@ describe("ChatInput", () => { await user.type(textarea, "Hello, world!"); await user.keyboard("{Enter}"); - expect(onSendMessage).toHaveBeenCalledWith("Hello, world!"); + expect(onSendMessage).toHaveBeenCalledWith("Hello, world!", []); }); it("should NOT send a message when shift + enter is pressed", async () => { diff --git a/frontend/src/components/chat/ChatInput.tsx b/frontend/src/components/chat/ChatInput.tsx index c92dddd1c1..7019c18607 100644 --- a/frontend/src/components/chat/ChatInput.tsx +++ b/frontend/src/components/chat/ChatInput.tsx @@ -1,26 +1,44 @@ import { Textarea } from "@nextui-org/react"; import React from "react"; import { useTranslation } from "react-i18next"; -import { VscArrowUp } from "react-icons/vsc"; +import { VscArrowUp, VscFileMedia } from "react-icons/vsc"; import { twMerge } from "tailwind-merge"; import { I18nKey } from "#/i18n/declaration"; interface ChatInputProps { disabled?: boolean; - onSendMessage: (message: string) => void; + onSendMessage: (message: string, image_urls: string[]) => void; } function ChatInput({ disabled = false, onSendMessage }: ChatInputProps) { const { t } = useTranslation(); const [message, setMessage] = React.useState(""); + const [files, setFiles] = React.useState([]); // This is true when the user is typing in an IME (e.g., Chinese, Japanese) const [isComposing, setIsComposing] = React.useState(false); - const handleSendChatMessage = () => { + const convertImageToBase64 = (file: File): Promise => + new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onloadend = () => { + resolve(reader.result as string); + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); + + const handleSendChatMessage = async () => { if (message.trim()) { - onSendMessage(message); + let base64images: string[] = []; + if (files.length > 0) { + base64images = await Promise.all( + files.map((file) => convertImageToBase64(file)), + ); + } + onSendMessage(message, base64images); setMessage(""); + setFiles([]); } }; @@ -33,6 +51,33 @@ function ChatInput({ disabled = false, onSendMessage }: ChatInputProps) { } }; + const handleFileChange = (event: React.ChangeEvent) => { + if (event.target.files) { + setFiles((prev) => [...prev, ...Array.from(event.target.files!)]); + } + }; + + const removeFile = (index: number) => { + setFiles((prevFiles) => prevFiles.filter((_, i) => i !== index)); + }; + + const handlePaste = (event: React.ClipboardEvent) => { + const clipboardItems = Array.from(event.clipboardData.items); + const pastedFiles: File[] = []; + clipboardItems.forEach((item) => { + if (item.type.startsWith("image/")) { + const file = item.getAsFile(); + if (file) { + pastedFiles.push(file); 
+ } + } + }); + if (pastedFiles.length > 0) { + setFiles((prevFiles) => [...prevFiles, ...pastedFiles]); + event.preventDefault(); + } + }; + return (