Fix issue #5478: Add color to the line next to "Ran a XXX Command" based on return value (#5483)

Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
OpenHands 2024-12-11 18:20:29 -05:00 committed by GitHub
parent 907c65cc00
commit 6a6ce5f3ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 447 additions and 137 deletions

View File

@ -0,0 +1,60 @@
import { describe, expect, it } from "vitest";
import { screen } from "@testing-library/react";
import { renderWithProviders } from "test-utils";
import { ExpandableMessage } from "#/components/features/chat/expandable-message";
describe("ExpandableMessage", () => {
it("should render with neutral border for non-action messages", () => {
renderWithProviders(<ExpandableMessage message="Hello" type="thought" />);
const element = screen.getByText("Hello");
const container = element.closest("div.flex.gap-2.items-center.justify-between");
expect(container).toHaveClass("border-neutral-300");
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
});
it("should render with neutral border for error messages", () => {
renderWithProviders(<ExpandableMessage message="Error occurred" type="error" />);
const element = screen.getByText("Error occurred");
const container = element.closest("div.flex.gap-2.items-center.justify-between");
expect(container).toHaveClass("border-neutral-300");
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
});
it("should render with success icon for successful action messages", () => {
renderWithProviders(
<ExpandableMessage
message="Command executed successfully"
type="action"
success={true}
/>
);
const element = screen.getByText("Command executed successfully");
const container = element.closest("div.flex.gap-2.items-center.justify-between");
expect(container).toHaveClass("border-neutral-300");
const icon = screen.getByTestId("status-icon");
expect(icon).toHaveClass("fill-success");
});
it("should render with error icon for failed action messages", () => {
renderWithProviders(
<ExpandableMessage
message="Command failed"
type="action"
success={false}
/>
);
const element = screen.getByText("Command failed");
const container = element.closest("div.flex.gap-2.items-center.justify-between");
expect(container).toHaveClass("border-neutral-300");
const icon = screen.getByTestId("status-icon");
expect(icon).toHaveClass("fill-danger");
});
it("should render with neutral border and no icon for action messages without success prop", () => {
renderWithProviders(<ExpandableMessage message="Running command" type="action" />);
const element = screen.getByText("Running command");
const container = element.closest("div.flex.gap-2.items-center.justify-between");
expect(container).toHaveClass("border-neutral-300");
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
});
});

View File

@ -6,17 +6,21 @@ import { code } from "../markdown/code";
import { ol, ul } from "../markdown/list";
import ArrowUp from "#/icons/angle-up-solid.svg?react";
import ArrowDown from "#/icons/angle-down-solid.svg?react";
import CheckCircle from "#/icons/check-circle-solid.svg?react";
import XCircle from "#/icons/x-circle-solid.svg?react";
interface ExpandableMessageProps {
id?: string;
message: string;
type: string;
success?: boolean;
}
export function ExpandableMessage({
id,
message,
type,
success,
}: ExpandableMessageProps) {
const { t, i18n } = useTranslation();
const [showDetails, setShowDetails] = useState(true);
@ -31,22 +35,14 @@ export function ExpandableMessage({
}
}, [id, message, i18n.language]);
const border = type === "error" ? "border-danger" : "border-neutral-300";
const textColor = type === "error" ? "text-danger" : "text-neutral-300";
let arrowClasses = "h-4 w-4 ml-2 inline";
if (type === "error") {
arrowClasses += " fill-danger";
} else {
arrowClasses += " fill-neutral-300";
}
const arrowClasses = "h-4 w-4 ml-2 inline fill-neutral-300";
const statusIconClasses = "h-4 w-4 ml-2 inline";
return (
<div
className={`flex gap-2 items-center justify-start border-l-2 pl-2 my-2 py-2 ${border}`}
>
<div className="flex gap-2 items-center justify-between border-l-2 border-neutral-300 pl-2 my-2 py-2">
<div className="text-sm leading-4 flex flex-col gap-2 max-w-full">
{headline && (
<p className={`${textColor} font-bold`}>
<p className="text-neutral-300 font-bold">
{headline}
<button
type="button"
@ -75,6 +71,21 @@ export function ExpandableMessage({
</Markdown>
)}
</div>
{type === "action" && success !== undefined && (
<div className="flex-shrink-0">
{success ? (
<CheckCircle
data-testid="status-icon"
className={`${statusIconClasses} fill-success`}
/>
) : (
<XCircle
data-testid="status-icon"
className={`${statusIconClasses} fill-danger`}
/>
)}
</div>
)}
</div>
);
}

View File

@ -20,6 +20,7 @@ export function Messages({
type={message.type}
id={message.translationID}
message={message.content}
success={message.success}
/>
);
}

View File

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
<path d="M256 512A256 256 0 1 0 256 0a256 256 0 1 0 0 512zM369 209L241 337c-9.4 9.4-24.6 9.4-33.9 0l-64-64c-9.4-9.4-9.4-24.6 0-33.9s24.6-9.4 33.9 0l47 47L335 175c9.4-9.4 24.6-9.4 33.9 0s9.4 24.6 0 33.9z"/>
</svg>

After

Width:  |  Height:  |  Size: 317 B

View File

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
<path d="M256 512A256 256 0 1 0 256 0a256 256 0 1 0 0 512zM175 175c9.4-9.4 24.6-9.4 33.9 0l47 47 47-47c9.4-9.4 24.6-9.4 33.9 0s9.4 24.6 0 33.9l-47 47 47 47c9.4 9.4 9.4 24.6 0 33.9s-24.6 9.4-33.9 0l-47-47-47 47c-9.4 9.4-24.6 9.4-33.9 0s-9.4-24.6 0-33.9l47-47-47-47c-9.4-9.4-9.4-24.6 0-33.9z"/>
</svg>

After

Width:  |  Height:  |  Size: 404 B

View File

@ -4,6 +4,7 @@ type Message = {
timestamp: string;
imageUrls?: string[];
type?: "thought" | "error" | "action";
success?: boolean;
pending?: boolean;
translationID?: string;
eventID?: number;

View File

@ -1,14 +1,25 @@
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
import { ActionSecurityRisk } from "#/state/security-analyzer-slice";
import { OpenHandsObservation } from "#/types/core/observations";
import {
OpenHandsObservation,
CommandObservation,
IPythonObservation,
} from "#/types/core/observations";
import { OpenHandsAction } from "#/types/core/actions";
import { OpenHandsEventType } from "#/types/core/base";
type SliceState = { messages: Message[] };
const MAX_CONTENT_LENGTH = 1000;
const HANDLED_ACTIONS = ["run", "run_ipython", "write", "read", "browse"];
const HANDLED_ACTIONS: OpenHandsEventType[] = [
"run",
"run_ipython",
"write",
"read",
"browse",
];
function getRiskText(risk: ActionSecurityRisk) {
switch (risk) {
@ -131,6 +142,18 @@ export const chatSlice = createSlice({
return;
}
causeMessage.translationID = translationID;
// Set success property based on observation type
if (observationID === "run") {
const commandObs = observation.payload as CommandObservation;
causeMessage.success = commandObs.extras.exit_code === 0;
} else if (observationID === "run_ipython") {
// For IPython, we consider it successful if there's no error message
const ipythonObs = observation.payload as IPythonObservation;
causeMessage.success = !ipythonObs.message
.toLowerCase()
.includes("error");
}
if (observationID === "run" || observationID === "run_ipython") {
let { content } = observation.payload;
if (content.length > MAX_CONTENT_LENGTH) {

View File

@ -52,6 +52,21 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> {
};
}
export interface WriteObservation extends OpenHandsObservationEvent<"write"> {
source: "agent";
extras: {
path: string;
content: string;
};
}
export interface ReadObservation extends OpenHandsObservationEvent<"read"> {
source: "agent";
extras: {
path: string;
};
}
export interface ErrorObservation extends OpenHandsObservationEvent<"error"> {
source: "user";
extras: {
@ -65,4 +80,6 @@ export type OpenHandsObservation =
| IPythonObservation
| DelegateObservation
| BrowseObservation
| WriteObservation
| ReadObservation
| ErrorObservation;

View File

@ -14,6 +14,7 @@ export default {
'root-secondary': '#262626',
'hyperlink': '#007AFF',
'danger': '#EF3744',
'success': '#4CAF50',
},
},
},

View File

@ -6,10 +6,31 @@ import { configureStore } from "@reduxjs/toolkit";
// eslint-disable-next-line import/no-extraneous-dependencies
import { RenderOptions, render } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { I18nextProvider } from "react-i18next";
import i18n from "i18next";
import { initReactI18next } from "react-i18next";
import { AppStore, RootState, rootReducer } from "./src/store";
import { AuthProvider } from "#/context/auth-context";
import { UserPrefsProvider } from "#/context/user-prefs-context";
// Initialize i18n for tests
i18n
.use(initReactI18next)
.init({
lng: "en",
fallbackLng: "en",
ns: ["translation"],
defaultNS: "translation",
resources: {
en: {
translation: {},
},
},
interpolation: {
escapeValue: false,
},
});
const setupStore = (preloadedState?: Partial<RootState>): AppStore =>
configureStore({
reducer: rootReducer,
@ -40,7 +61,9 @@ export function renderWithProviders(
<UserPrefsProvider>
<AuthProvider>
<QueryClientProvider client={new QueryClient()}>
{children}
<I18nextProvider i18n={i18n}>
{children}
</I18nextProvider>
</QueryClientProvider>
</AuthProvider>
</UserPrefsProvider>

View File

@ -12,7 +12,13 @@ HTMLElement.prototype.scrollTo = vi.fn();
// Mock the i18n provider
vi.mock("react-i18next", async (importOriginal) => ({
...(await importOriginal<typeof import("react-i18next")>()),
useTranslation: () => ({ t: (key: string) => key }),
useTranslation: () => ({
t: (key: string) => key,
i18n: {
language: "en",
exists: () => false,
},
}),
}));
// Mock requests during tests

View File

@ -110,4 +110,4 @@ The agent is implemented in two main files:
2. `function_calling.py`: Tool definitions and function calling interface with:
- Tool parameter specifications
- Tool descriptions and examples
- Function calling response parsing
- Function calling response parsing

View File

@ -23,6 +23,10 @@ class CmdOutputObservation(Observation):
def message(self) -> str:
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
@property
def success(self) -> bool:
return not self.error
def __str__(self) -> str:
return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}'
@ -42,5 +46,9 @@ class IPythonRunCellObservation(Observation):
def message(self) -> str:
return 'Code executed in IPython cell.'
@property
def success(self) -> bool:
return True # IPython cells are always considered successful
def __str__(self) -> str:
return f'**IPythonRunCellObservation**\n{self.content}'

View File

@ -83,6 +83,9 @@ def event_to_dict(event: 'Event') -> dict:
elif 'observation' in d:
d['content'] = props.pop('content', '')
d['extras'] = props
# Include success field for CmdOutputObservation
if hasattr(event, 'success'):
d['success'] = event.success
else:
raise ValueError('Event must be either action or observation')
return d

View File

@ -50,4 +50,5 @@ def observation_from_dict(observation: dict) -> Observation:
observation.pop('message', None)
content = observation.pop('content', '')
extras = observation.pop('extras', {})
return observation_class(content=content, **extras)

View File

@ -98,6 +98,7 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@ -128,6 +129,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"

View File

@ -0,0 +1,27 @@
from openhands.events.observation.commands import (
CmdOutputObservation,
IPythonRunCellObservation,
)
def test_cmd_output_success():
# Test successful command
obs = CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
)
assert obs.success is True
assert obs.error is False
# Test failed command
obs = CmdOutputObservation(
command_id=2, command='ls', content='No such file or directory', exit_code=1
)
assert obs.success is False
assert obs.error is True
def test_ipython_cell_success():
# IPython cells are always successful
obs = IPythonRunCellObservation(code='print("Hello")', content='Hello')
assert obs.success is True
assert obs.error is False

View File

@ -0,0 +1,18 @@
from openhands.events.observation import CmdOutputObservation
from openhands.events.serialization import event_to_dict
def test_command_output_success_serialization():
# Test successful command
obs = CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
)
serialized = event_to_dict(obs)
assert serialized['success'] is True
# Test failed command
obs = CmdOutputObservation(
command_id=2, command='ls', content='No such file or directory', exit_code=1
)
serialized = event_to_dict(obs)
assert serialized['success'] is False

View File

@ -40,36 +40,23 @@ def serialization_deserialization(
# Additional tests for various observation subclasses can be included here
def test_observation_event_props_serialization_deserialization():
original_observation_dict = {
'id': 42,
'source': 'agent',
'timestamp': '2021-08-01T12:00:00',
'observation': 'run',
'message': 'Command `ls -l` executed with exit code 0.',
'extras': {
'exit_code': 0,
'command': 'ls -l',
'command_id': 3,
'hidden': False,
'interpreter_details': '',
},
'content': 'foo.txt',
}
serialization_deserialization(original_observation_dict, CmdOutputObservation)
def test_success_field_serialization():
# Test success=True
obs = CmdOutputObservation(
content='Command succeeded',
exit_code=0,
command='ls -l',
command_id=3,
)
serialized = event_to_dict(obs)
assert serialized['success'] is True
def test_command_output_observation_serialization_deserialization():
original_observation_dict = {
'observation': 'run',
'extras': {
'exit_code': 0,
'command': 'ls -l',
'command_id': 3,
'hidden': False,
'interpreter_details': '',
},
'message': 'Command `ls -l` executed with exit code 0.',
'content': 'foo.txt',
}
serialization_deserialization(original_observation_dict, CmdOutputObservation)
# Test success=False
obs = CmdOutputObservation(
content='No such file or directory',
exit_code=1,
command='ls -l',
command_id=3,
)
serialized = event_to_dict(obs)
assert serialized['success'] is False

View File

@ -51,24 +51,48 @@ def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]])
def test_msg(temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
policy = """
raise "Disallow ABC [risk=medium]" if:
(msg: Message)
"ABC" in msg.content
"""
InvariantAnalyzer(event_stream, policy)
data = [
(MessageAction('Hello world!'), EventSource.USER),
(MessageAction('AB!'), EventSource.AGENT),
(MessageAction('Hello world!'), EventSource.USER),
(MessageAction('ABC!'), EventSource.AGENT),
mock_container = MagicMock()
mock_container.status = 'running'
mock_container.attrs = {
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
}
mock_docker = MagicMock()
mock_docker.from_env().containers.list.return_value = [mock_container]
mock_requests = MagicMock()
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
mock_requests.post().json.side_effect = [
{'monitor_id': 'mock-monitor-id'},
[], # First check
[], # Second check
[], # Third check
[
'PolicyViolation(Disallow ABC [risk=medium], ranges=[<2 ranges>])'
], # Fourth check
]
add_events(event_stream, data)
for i in range(3):
assert data[i][0].security_risk == ActionSecurityRisk.LOW
assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
with (
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
patch(f'{InvariantClient.__module__}.requests', mock_requests),
):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
policy = """
raise "Disallow ABC [risk=medium]" if:
(msg: Message)
"ABC" in msg.content
"""
InvariantAnalyzer(event_stream, policy)
data = [
(MessageAction('Hello world!'), EventSource.USER),
(MessageAction('AB!'), EventSource.AGENT),
(MessageAction('Hello world!'), EventSource.USER),
(MessageAction('ABC!'), EventSource.AGENT),
]
add_events(event_stream, data)
for i in range(3):
assert data[i][0].security_risk == ActionSecurityRisk.LOW
assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
@pytest.mark.parametrize(
@ -76,22 +100,44 @@ def test_msg(temp_dir: str):
[('rm -rf root_dir', ActionSecurityRisk.MEDIUM), ['ls', ActionSecurityRisk.LOW]],
)
def test_cmd(cmd, expected_risk, temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
policy = """
raise "Disallow rm -rf [risk=medium]" if:
(call: ToolCall)
call is tool:run
match("rm -rf", call.function.arguments.command)
"""
InvariantAnalyzer(event_stream, policy)
data = [
(MessageAction('Hello world!'), EventSource.USER),
(CmdRunAction(cmd), EventSource.USER),
mock_container = MagicMock()
mock_container.status = 'running'
mock_container.attrs = {
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
}
mock_docker = MagicMock()
mock_docker.from_env().containers.list.return_value = [mock_container]
mock_requests = MagicMock()
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
mock_requests.post().json.side_effect = [
{'monitor_id': 'mock-monitor-id'},
[], # First check
['PolicyViolation(Disallow rm -rf [risk=medium], ranges=[<2 ranges>])']
if expected_risk == ActionSecurityRisk.MEDIUM
else [], # Second check
]
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
assert data[1][0].security_risk == expected_risk
with (
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
patch(f'{InvariantClient.__module__}.requests', mock_requests),
):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
policy = """
raise "Disallow rm -rf [risk=medium]" if:
(call: ToolCall)
call is tool:run
match("rm -rf", call.function.arguments.command)
"""
InvariantAnalyzer(event_stream, policy)
data = [
(MessageAction('Hello world!'), EventSource.USER),
(CmdRunAction(cmd), EventSource.USER),
]
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
assert data[1][0].security_risk == expected_risk
@pytest.mark.parametrize(
@ -102,26 +148,49 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
],
)
def test_leak_secrets(code, expected_risk, temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
policy = """
from invariant.detectors import secrets
mock_container = MagicMock()
mock_container.status = 'running'
mock_container.attrs = {
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
}
mock_docker = MagicMock()
mock_docker.from_env().containers.list.return_value = [mock_container]
raise "Disallow writing secrets [risk=medium]" if:
(call: ToolCall)
call is tool:run_ipython
any(secrets(call.function.arguments.code))
"""
InvariantAnalyzer(event_stream, policy)
data = [
(MessageAction('Hello world!'), EventSource.USER),
(IPythonRunCellAction(code), EventSource.AGENT),
(IPythonRunCellAction('hello'), EventSource.AGENT),
mock_requests = MagicMock()
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
mock_requests.post().json.side_effect = [
{'monitor_id': 'mock-monitor-id'},
[], # First check
['PolicyViolation(Disallow writing secrets [risk=medium], ranges=[<2 ranges>])']
if expected_risk == ActionSecurityRisk.MEDIUM
else [], # Second check
[], # Third check
]
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
assert data[1][0].security_risk == expected_risk
assert data[2][0].security_risk == ActionSecurityRisk.LOW
with (
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
patch(f'{InvariantClient.__module__}.requests', mock_requests),
):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
policy = """
from invariant.detectors import secrets
raise "Disallow writing secrets [risk=medium]" if:
(call: ToolCall)
call is tool:run_ipython
any(secrets(call.function.arguments.code))
"""
InvariantAnalyzer(event_stream, policy)
data = [
(MessageAction('Hello world!'), EventSource.USER),
(IPythonRunCellAction(code), EventSource.AGENT),
(IPythonRunCellAction('hello'), EventSource.AGENT),
]
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
assert data[1][0].security_risk == expected_risk
assert data[2][0].security_risk == ActionSecurityRisk.LOW
def test_unsafe_python_code(temp_dir: str):
@ -458,26 +527,48 @@ def default_config():
def test_check_usertask(
mock_litellm_completion, usertask, is_appropriate, default_config, temp_dir: str
):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.check_browsing_alignment = True
data = [
(MessageAction(usertask), EventSource.USER),
]
add_events(event_stream, data)
event_list = list(event_stream.get_events())
mock_container = MagicMock()
mock_container.status = 'running'
mock_container.attrs = {
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
}
mock_docker = MagicMock()
mock_docker.from_env().containers.list.return_value = [mock_container]
if is_appropriate == 'No':
assert len(event_list) == 2
assert type(event_list[0]) == MessageAction
assert type(event_list[1]) == ChangeAgentStateAction
elif is_appropriate == 'Yes':
assert len(event_list) == 1
assert type(event_list[0]) == MessageAction
mock_requests = MagicMock()
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
mock_requests.post().json.side_effect = [
{'monitor_id': 'mock-monitor-id'},
[],
[
'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
],
]
with (
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
patch(f'{InvariantClient.__module__}.requests', mock_requests),
):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.check_browsing_alignment = True
data = [
(MessageAction(usertask), EventSource.USER),
]
add_events(event_stream, data)
event_list = list(event_stream.get_events())
if is_appropriate == 'No':
assert len(event_list) == 2
assert type(event_list[0]) == MessageAction
assert type(event_list[1]) == ChangeAgentStateAction
elif is_appropriate == 'Yes':
assert len(event_list) == 1
assert type(event_list[0]) == MessageAction
@pytest.mark.parametrize(
@ -491,23 +582,45 @@ def test_check_usertask(
def test_check_fillaction(
mock_litellm_completion, fillaction, is_harmful, default_config, temp_dir: str
):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_harmful}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.check_browsing_alignment = True
data = [
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
]
add_events(event_stream, data)
event_list = list(event_stream.get_events())
mock_container = MagicMock()
mock_container.status = 'running'
mock_container.attrs = {
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
}
mock_docker = MagicMock()
mock_docker.from_env().containers.list.return_value = [mock_container]
if is_harmful == 'Yes':
assert len(event_list) == 2
assert type(event_list[0]) == BrowseInteractiveAction
assert type(event_list[1]) == ChangeAgentStateAction
elif is_harmful == 'No':
assert len(event_list) == 1
assert type(event_list[0]) == BrowseInteractiveAction
mock_requests = MagicMock()
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
mock_requests.post().json.side_effect = [
{'monitor_id': 'mock-monitor-id'},
[],
[
'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
],
]
with (
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
patch(f'{InvariantClient.__module__}.requests', mock_requests),
):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('main', file_store)
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_harmful}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.check_browsing_alignment = True
data = [
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
]
add_events(event_stream, data)
event_list = list(event_stream.get_events())
if is_harmful == 'Yes':
assert len(event_list) == 2
assert type(event_list[0]) == BrowseInteractiveAction
assert type(event_list[1]) == ChangeAgentStateAction
elif is_harmful == 'No':
assert len(event_list) == 1
assert type(event_list[0]) == BrowseInteractiveAction