Fix #8551: Show images produced in Jupyter Notebook to LLM directly (#8552)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Xingyao Wang 2025-05-19 22:14:00 +08:00 committed by GitHub
parent 1a3cb16ba6
commit 4a3d2e6859
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 171 additions and 42 deletions

View File

@ -24,6 +24,10 @@ export function JupyterCellOutput({ lines }: JupyterCellOutputProps) {
{/* display the lines as plaintext or image */}
{lines.map((line, index) => {
if (line.type === "image") {
// Use markdown to display the image
const imageMarkdown = line.url
? `![image](${line.url})`
: line.content;
return (
<div key={index}>
<Markdown
@ -32,7 +36,7 @@ export function JupyterCellOutput({ lines }: JupyterCellOutputProps) {
}}
urlTransform={(value: string) => value}
>
{line.content}
{imageMarkdown}
</Markdown>
</div>
);

View File

@ -12,8 +12,8 @@ export function JupyterCell({ cell }: JupyterCellProps) {
const [lines, setLines] = React.useState<JupyterLine[]>([]);
React.useEffect(() => {
setLines(parseCellContent(cell.content));
}, [cell.content]);
setLines(parseCellContent(cell.content, cell.imageUrls));
}, [cell.content, cell.imageUrls]);
if (cell.type === "input") {
return <JupytrerCellInput code={cell.content} />;

View File

@ -26,8 +26,14 @@ export function handleObservationMessage(message: ObservationMessage) {
break;
}
case ObservationType.RUN_IPYTHON:
// FIXME: render this as markdown
store.dispatch(appendJupyterOutput(message.content));
store.dispatch(
appendJupyterOutput({
content: message.content,
imageUrls: Array.isArray(message.extras?.image_urls)
? message.extras.image_urls
: undefined,
}),
);
break;
case ObservationType.BROWSE:
case ObservationType.BROWSE_INTERACTIVE:
@ -139,6 +145,9 @@ export function handleObservationMessage(message: ObservationMessage) {
observation: "run_ipython" as const,
extras: {
code: String(message.extras.code || ""),
image_urls: Array.isArray(message.extras.image_urls)
? message.extras.image_urls
: [],
},
}),
);

View File

@ -3,6 +3,7 @@ import { createSlice } from "@reduxjs/toolkit";
export type Cell = {
content: string;
type: "input" | "output";
imageUrls?: string[];
};
const initialCells: Cell[] = [];
@ -17,7 +18,11 @@ export const jupyterSlice = createSlice({
state.cells.push({ content: action.payload, type: "input" });
},
appendJupyterOutput: (state, action) => {
state.cells.push({ content: action.payload, type: "output" });
state.cells.push({
content: action.payload.content,
type: "output",
imageUrls: action.payload.imageUrls,
});
},
clearJupyter: (state) => {
state.cells = [];

View File

@ -23,6 +23,7 @@ export interface IPythonObservation
source: "agent";
extras: {
code: string;
image_urls?: string[];
};
}

View File

@ -1,26 +1,32 @@
export type JupyterLine = { type: "plaintext" | "image"; content: string };
export type JupyterLine = {
type: "plaintext" | "image";
content: string;
url?: string;
};
const IMAGE_PREFIX = "![image](data:image/png;base64,";
export const parseCellContent = (content: string) => {
export const parseCellContent = (content: string, imageUrls?: string[]) => {
const lines: JupyterLine[] = [];
let currentText = "";
// First, process the text content
for (const line of content.split("\n")) {
if (line.startsWith(IMAGE_PREFIX)) {
if (currentText) {
lines.push({ type: "plaintext", content: currentText });
currentText = ""; // Reset after pushing plaintext
}
lines.push({ type: "image", content: line });
} else {
currentText += `${line}\n`;
}
currentText += `${line}\n`;
}
if (currentText) {
lines.push({ type: "plaintext", content: currentText });
}
// Then, add image lines if we have image URLs
if (imageUrls && imageUrls.length > 0) {
imageUrls.forEach((url) => {
lines.push({
type: "image",
content: `![image](${url})`,
url,
});
});
}
return lines;
};

View File

@ -170,6 +170,7 @@ class IPythonRunCellObservation(Observation):
code: str
observation: str = ObservationType.RUN_IPYTHON
image_urls: list[str] | None = None
@property
def error(self) -> bool:
@ -184,4 +185,7 @@ class IPythonRunCellObservation(Observation):
return True # IPython cells are always considered successful
def __str__(self) -> str:
return f'**IPythonRunCellObservation**\n{self.content}'
result = f'**IPythonRunCellObservation**\n{self.content}'
if self.image_urls:
result += f'\nImages: {len(self.image_urls)}'
return result

View File

@ -360,7 +360,7 @@ class ConversationMemory:
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, IPythonRunCellObservation):
text = obs.content
# replace base64 images with a placeholder
# Clean up any remaining base64 images in text content
splitted = text.split('\n')
for i, line in enumerate(splitted):
if '![image](data:image/png;base64,' in line:
@ -369,7 +369,15 @@ class ConversationMemory:
)
text = '\n'.join(splitted)
text = truncate_content(text, max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
# Create message content with text
content = [TextContent(text=text)]
# Add image URLs if available and vision is active
if vision_is_active and obs.image_urls:
content.append(ImageContent(image_urls=obs.image_urls))
message = Message(role='user', content=content)
elif isinstance(obs, FileEditObservation):
text = truncate_content(str(obs), max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])

View File

@ -153,10 +153,18 @@ class JupyterPlugin(Plugin):
if not self.kernel.initialized:
await self.kernel.initialize()
# Execute the code and get structured output
output = await self.kernel.execute(action.code, timeout=action.timeout)
# Extract text content and image URLs from the structured output
text_content = output.get('text', '')
image_urls = output.get('images', [])
return IPythonRunCellObservation(
content=output,
content=text_content,
code=action.code,
image_urls=image_urls if image_urls else None,
)
async def run(self, action: Action) -> IPythonRunCellObservation:

View File

@ -139,7 +139,9 @@ class JupyterKernel:
stop=stop_after_attempt(3),
wait=wait_fixed(2),
) # type: ignore
async def execute(self, code: str, timeout: int = 120) -> str:
async def execute(
self, code: str, timeout: int = 120
) -> dict[str, list[str] | str]:
if not self.ws or self.ws.stream.closed():
await self._connect()
@ -171,7 +173,7 @@ class JupyterKernel:
)
logging.info(f'Executed code in jupyter kernel:\n{res}')
outputs: list[str] = []
outputs: list[dict] = []
async def wait_for_messages() -> bool:
execution_done = False
@ -194,17 +196,23 @@ class JupyterKernel:
if msg_type == 'error':
traceback = '\n'.join(msg_dict['content']['traceback'])
outputs.append(traceback)
outputs.append({'type': 'text', 'content': traceback})
execution_done = True
elif msg_type == 'stream':
outputs.append(msg_dict['content']['text'])
outputs.append(
{'type': 'text', 'content': msg_dict['content']['text']}
)
elif msg_type in ['execute_result', 'display_data']:
outputs.append(msg_dict['content']['data']['text/plain'])
outputs.append(
{
'type': 'text',
'content': msg_dict['content']['data']['text/plain'],
}
)
if 'image/png' in msg_dict['content']['data']:
# use markdone to display image (in case of large image)
outputs.append(
f'\n![image](data:image/png;base64,{msg_dict["content"]["data"]["image/png"]})\n'
)
# Store image data in structured format
image_url = f'data:image/png;base64,{msg_dict["content"]["data"]["image/png"]}'
outputs.append({'type': 'image', 'content': image_url})
elif msg_type == 'execute_reply':
execution_done = True
@ -225,19 +233,28 @@ class JupyterKernel:
execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
except asyncio.TimeoutError:
await interrupt_kernel()
return f'[Execution timed out ({timeout} seconds).]'
return {'text': f'[Execution timed out ({timeout} seconds).]', 'images': []}
if not outputs and execution_done:
ret = '[Code executed successfully with no output]'
# Process structured outputs
text_outputs = []
image_outputs = []
for output in outputs:
if output['type'] == 'text':
text_outputs.append(output['content'])
elif output['type'] == 'image':
image_outputs.append(output['content'])
if not text_outputs and execution_done:
text_content = '[Code executed successfully with no output]'
else:
ret = ''.join(outputs)
text_content = ''.join(text_outputs)
# Remove ANSI
ret = strip_ansi(ret)
# Remove ANSI from text content
text_content = strip_ansi(text_content)
if os.environ.get('DEBUG'):
logging.info(f'OUTPUT:\n{ret}')
return ret
# Return a dictionary with text content and image URLs
return {'text': text_content, 'images': image_outputs}
async def shutdown_async(self) -> None:
if self.kernel_id:
@ -267,7 +284,9 @@ class ExecuteHandler(tornado.web.RequestHandler):
output = await self.jupyter_kernel.execute(code)
self.write(output)
# Set content type to JSON and return the structured output
self.set_header('Content-Type', 'application/json')
self.write(json_encode(output))
def make_app() -> tornado.web.Application:

View File

@ -85,6 +85,11 @@ def mock_state():
return state
@pytest.fixture
def mock_prompt_manager():
return MagicMock()
def test_process_events_with_message_action(conversation_memory):
"""Test that MessageAction is processed correctly."""
# Create a system message action
@ -1514,3 +1519,63 @@ def test_process_events_partial_history(conversation_memory):
messages_partial_obs_only[1].role == 'user'
) # Added by _ensure_initial_user_message
assert messages_partial_obs_only[1].content[0].text == 'Initial user query'
def test_process_ipython_observation_with_vision_enabled(
agent_config, mock_prompt_manager
):
"""Test that _process_observation correctly handles IPythonRunCellObservation with image_urls when vision is enabled."""
# Create a ConversationMemory instance
memory = ConversationMemory(agent_config, mock_prompt_manager)
# Create an observation with image URLs
obs = IPythonRunCellObservation(
content='Test output',
code="print('test')",
image_urls=[''],
)
# Process the observation with vision enabled
messages = memory._process_observation(
obs=obs,
tool_call_id_to_message={},
max_message_chars=None,
vision_is_active=True,
)
# Check that the message contains both text and image content
assert len(messages) == 1
message = messages[0]
assert len(message.content) == 2
assert isinstance(message.content[0], TextContent)
assert isinstance(message.content[1], ImageContent)
assert message.content[1].image_urls == ['']
def test_process_ipython_observation_with_vision_disabled(
agent_config, mock_prompt_manager
):
"""Test that _process_observation correctly handles IPythonRunCellObservation with image_urls when vision is disabled."""
# Create a ConversationMemory instance
memory = ConversationMemory(agent_config, mock_prompt_manager)
# Create an observation with image URLs
obs = IPythonRunCellObservation(
content='Test output',
code="print('test')",
image_urls=[''],
)
# Process the observation with vision disabled
messages = memory._process_observation(
obs=obs,
tool_call_id_to_message={},
max_message_chars=None,
vision_is_active=False,
)
# Check that the message contains only text content
assert len(messages) == 1
message = messages[0]
assert len(message.content) == 1
assert isinstance(message.content[0], TextContent)