mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
parent
907c65cc00
commit
6a6ce5f3ee
@ -0,0 +1,60 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { ExpandableMessage } from "#/components/features/chat/expandable-message";
|
||||
|
||||
describe("ExpandableMessage", () => {
|
||||
it("should render with neutral border for non-action messages", () => {
|
||||
renderWithProviders(<ExpandableMessage message="Hello" type="thought" />);
|
||||
const element = screen.getByText("Hello");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-between");
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render with neutral border for error messages", () => {
|
||||
renderWithProviders(<ExpandableMessage message="Error occurred" type="error" />);
|
||||
const element = screen.getByText("Error occurred");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-between");
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render with success icon for successful action messages", () => {
|
||||
renderWithProviders(
|
||||
<ExpandableMessage
|
||||
message="Command executed successfully"
|
||||
type="action"
|
||||
success={true}
|
||||
/>
|
||||
);
|
||||
const element = screen.getByText("Command executed successfully");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-between");
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
const icon = screen.getByTestId("status-icon");
|
||||
expect(icon).toHaveClass("fill-success");
|
||||
});
|
||||
|
||||
it("should render with error icon for failed action messages", () => {
|
||||
renderWithProviders(
|
||||
<ExpandableMessage
|
||||
message="Command failed"
|
||||
type="action"
|
||||
success={false}
|
||||
/>
|
||||
);
|
||||
const element = screen.getByText("Command failed");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-between");
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
const icon = screen.getByTestId("status-icon");
|
||||
expect(icon).toHaveClass("fill-danger");
|
||||
});
|
||||
|
||||
it("should render with neutral border and no icon for action messages without success prop", () => {
|
||||
renderWithProviders(<ExpandableMessage message="Running command" type="action" />);
|
||||
const element = screen.getByText("Running command");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-between");
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@ -6,17 +6,21 @@ import { code } from "../markdown/code";
|
||||
import { ol, ul } from "../markdown/list";
|
||||
import ArrowUp from "#/icons/angle-up-solid.svg?react";
|
||||
import ArrowDown from "#/icons/angle-down-solid.svg?react";
|
||||
import CheckCircle from "#/icons/check-circle-solid.svg?react";
|
||||
import XCircle from "#/icons/x-circle-solid.svg?react";
|
||||
|
||||
interface ExpandableMessageProps {
|
||||
id?: string;
|
||||
message: string;
|
||||
type: string;
|
||||
success?: boolean;
|
||||
}
|
||||
|
||||
export function ExpandableMessage({
|
||||
id,
|
||||
message,
|
||||
type,
|
||||
success,
|
||||
}: ExpandableMessageProps) {
|
||||
const { t, i18n } = useTranslation();
|
||||
const [showDetails, setShowDetails] = useState(true);
|
||||
@ -31,22 +35,14 @@ export function ExpandableMessage({
|
||||
}
|
||||
}, [id, message, i18n.language]);
|
||||
|
||||
const border = type === "error" ? "border-danger" : "border-neutral-300";
|
||||
const textColor = type === "error" ? "text-danger" : "text-neutral-300";
|
||||
let arrowClasses = "h-4 w-4 ml-2 inline";
|
||||
if (type === "error") {
|
||||
arrowClasses += " fill-danger";
|
||||
} else {
|
||||
arrowClasses += " fill-neutral-300";
|
||||
}
|
||||
const arrowClasses = "h-4 w-4 ml-2 inline fill-neutral-300";
|
||||
const statusIconClasses = "h-4 w-4 ml-2 inline";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`flex gap-2 items-center justify-start border-l-2 pl-2 my-2 py-2 ${border}`}
|
||||
>
|
||||
<div className="flex gap-2 items-center justify-between border-l-2 border-neutral-300 pl-2 my-2 py-2">
|
||||
<div className="text-sm leading-4 flex flex-col gap-2 max-w-full">
|
||||
{headline && (
|
||||
<p className={`${textColor} font-bold`}>
|
||||
<p className="text-neutral-300 font-bold">
|
||||
{headline}
|
||||
<button
|
||||
type="button"
|
||||
@ -75,6 +71,21 @@ export function ExpandableMessage({
|
||||
</Markdown>
|
||||
)}
|
||||
</div>
|
||||
{type === "action" && success !== undefined && (
|
||||
<div className="flex-shrink-0">
|
||||
{success ? (
|
||||
<CheckCircle
|
||||
data-testid="status-icon"
|
||||
className={`${statusIconClasses} fill-success`}
|
||||
/>
|
||||
) : (
|
||||
<XCircle
|
||||
data-testid="status-icon"
|
||||
className={`${statusIconClasses} fill-danger`}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ export function Messages({
|
||||
type={message.type}
|
||||
id={message.translationID}
|
||||
message={message.content}
|
||||
success={message.success}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
4
frontend/src/icons/check-circle-solid.svg
Normal file
4
frontend/src/icons/check-circle-solid.svg
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
|
||||
<path d="M256 512A256 256 0 1 0 256 0a256 256 0 1 0 0 512zM369 209L241 337c-9.4 9.4-24.6 9.4-33.9 0l-64-64c-9.4-9.4-9.4-24.6 0-33.9s24.6-9.4 33.9 0l47 47L335 175c9.4-9.4 24.6-9.4 33.9 0s9.4 24.6 0 33.9z"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 317 B |
4
frontend/src/icons/x-circle-solid.svg
Normal file
4
frontend/src/icons/x-circle-solid.svg
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
|
||||
<path d="M256 512A256 256 0 1 0 256 0a256 256 0 1 0 0 512zM175 175c9.4-9.4 24.6-9.4 33.9 0l47 47 47-47c9.4-9.4 24.6-9.4 33.9 0s9.4 24.6 0 33.9l-47 47 47 47c9.4 9.4 9.4 24.6 0 33.9s-24.6 9.4-33.9 0l-47-47-47 47c-9.4 9.4-24.6 9.4-33.9 0s-9.4-24.6 0-33.9l47-47-47-47c-9.4-9.4-9.4-24.6 0-33.9z"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 404 B |
1
frontend/src/message.d.ts
vendored
1
frontend/src/message.d.ts
vendored
@ -4,6 +4,7 @@ type Message = {
|
||||
timestamp: string;
|
||||
imageUrls?: string[];
|
||||
type?: "thought" | "error" | "action";
|
||||
success?: boolean;
|
||||
pending?: boolean;
|
||||
translationID?: string;
|
||||
eventID?: number;
|
||||
|
||||
@ -1,14 +1,25 @@
|
||||
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
|
||||
|
||||
import { ActionSecurityRisk } from "#/state/security-analyzer-slice";
|
||||
import { OpenHandsObservation } from "#/types/core/observations";
|
||||
import {
|
||||
OpenHandsObservation,
|
||||
CommandObservation,
|
||||
IPythonObservation,
|
||||
} from "#/types/core/observations";
|
||||
import { OpenHandsAction } from "#/types/core/actions";
|
||||
import { OpenHandsEventType } from "#/types/core/base";
|
||||
|
||||
type SliceState = { messages: Message[] };
|
||||
|
||||
const MAX_CONTENT_LENGTH = 1000;
|
||||
|
||||
const HANDLED_ACTIONS = ["run", "run_ipython", "write", "read", "browse"];
|
||||
const HANDLED_ACTIONS: OpenHandsEventType[] = [
|
||||
"run",
|
||||
"run_ipython",
|
||||
"write",
|
||||
"read",
|
||||
"browse",
|
||||
];
|
||||
|
||||
function getRiskText(risk: ActionSecurityRisk) {
|
||||
switch (risk) {
|
||||
@ -131,6 +142,18 @@ export const chatSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
causeMessage.translationID = translationID;
|
||||
// Set success property based on observation type
|
||||
if (observationID === "run") {
|
||||
const commandObs = observation.payload as CommandObservation;
|
||||
causeMessage.success = commandObs.extras.exit_code === 0;
|
||||
} else if (observationID === "run_ipython") {
|
||||
// For IPython, we consider it successful if there's no error message
|
||||
const ipythonObs = observation.payload as IPythonObservation;
|
||||
causeMessage.success = !ipythonObs.message
|
||||
.toLowerCase()
|
||||
.includes("error");
|
||||
}
|
||||
|
||||
if (observationID === "run" || observationID === "run_ipython") {
|
||||
let { content } = observation.payload;
|
||||
if (content.length > MAX_CONTENT_LENGTH) {
|
||||
|
||||
@ -52,6 +52,21 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> {
|
||||
};
|
||||
}
|
||||
|
||||
export interface WriteObservation extends OpenHandsObservationEvent<"write"> {
|
||||
source: "agent";
|
||||
extras: {
|
||||
path: string;
|
||||
content: string;
|
||||
};
|
||||
}
|
||||
|
||||
export interface ReadObservation extends OpenHandsObservationEvent<"read"> {
|
||||
source: "agent";
|
||||
extras: {
|
||||
path: string;
|
||||
};
|
||||
}
|
||||
|
||||
export interface ErrorObservation extends OpenHandsObservationEvent<"error"> {
|
||||
source: "user";
|
||||
extras: {
|
||||
@ -65,4 +80,6 @@ export type OpenHandsObservation =
|
||||
| IPythonObservation
|
||||
| DelegateObservation
|
||||
| BrowseObservation
|
||||
| WriteObservation
|
||||
| ReadObservation
|
||||
| ErrorObservation;
|
||||
|
||||
@ -14,6 +14,7 @@ export default {
|
||||
'root-secondary': '#262626',
|
||||
'hyperlink': '#007AFF',
|
||||
'danger': '#EF3744',
|
||||
'success': '#4CAF50',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@ -6,10 +6,31 @@ import { configureStore } from "@reduxjs/toolkit";
|
||||
// eslint-disable-next-line import/no-extraneous-dependencies
|
||||
import { RenderOptions, render } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { I18nextProvider } from "react-i18next";
|
||||
import i18n from "i18next";
|
||||
import { initReactI18next } from "react-i18next";
|
||||
import { AppStore, RootState, rootReducer } from "./src/store";
|
||||
import { AuthProvider } from "#/context/auth-context";
|
||||
import { UserPrefsProvider } from "#/context/user-prefs-context";
|
||||
|
||||
// Initialize i18n for tests
|
||||
i18n
|
||||
.use(initReactI18next)
|
||||
.init({
|
||||
lng: "en",
|
||||
fallbackLng: "en",
|
||||
ns: ["translation"],
|
||||
defaultNS: "translation",
|
||||
resources: {
|
||||
en: {
|
||||
translation: {},
|
||||
},
|
||||
},
|
||||
interpolation: {
|
||||
escapeValue: false,
|
||||
},
|
||||
});
|
||||
|
||||
const setupStore = (preloadedState?: Partial<RootState>): AppStore =>
|
||||
configureStore({
|
||||
reducer: rootReducer,
|
||||
@ -40,7 +61,9 @@ export function renderWithProviders(
|
||||
<UserPrefsProvider>
|
||||
<AuthProvider>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
<I18nextProvider i18n={i18n}>
|
||||
{children}
|
||||
</I18nextProvider>
|
||||
</QueryClientProvider>
|
||||
</AuthProvider>
|
||||
</UserPrefsProvider>
|
||||
|
||||
@ -12,7 +12,13 @@ HTMLElement.prototype.scrollTo = vi.fn();
|
||||
// Mock the i18n provider
|
||||
vi.mock("react-i18next", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-i18next")>()),
|
||||
useTranslation: () => ({ t: (key: string) => key }),
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
i18n: {
|
||||
language: "en",
|
||||
exists: () => false,
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock requests during tests
|
||||
|
||||
@ -110,4 +110,4 @@ The agent is implemented in two main files:
|
||||
2. `function_calling.py`: Tool definitions and function calling interface with:
|
||||
- Tool parameter specifications
|
||||
- Tool descriptions and examples
|
||||
- Function calling response parsing
|
||||
- Function calling response parsing
|
||||
|
||||
@ -23,6 +23,10 @@ class CmdOutputObservation(Observation):
|
||||
def message(self) -> str:
|
||||
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return not self.error
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}'
|
||||
|
||||
@ -42,5 +46,9 @@ class IPythonRunCellObservation(Observation):
|
||||
def message(self) -> str:
|
||||
return 'Code executed in IPython cell.'
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True # IPython cells are always considered successful
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'**IPythonRunCellObservation**\n{self.content}'
|
||||
|
||||
@ -83,6 +83,9 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
elif 'observation' in d:
|
||||
d['content'] = props.pop('content', '')
|
||||
d['extras'] = props
|
||||
# Include success field for CmdOutputObservation
|
||||
if hasattr(event, 'success'):
|
||||
d['success'] = event.success
|
||||
else:
|
||||
raise ValueError('Event must be either action or observation')
|
||||
return d
|
||||
|
||||
@ -50,4 +50,5 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
observation.pop('message', None)
|
||||
content = observation.pop('content', '')
|
||||
extras = observation.pop('extras', {})
|
||||
|
||||
return observation_class(content=content, **extras)
|
||||
|
||||
@ -98,6 +98,7 @@ reportlab = "*"
|
||||
[tool.coverage.run]
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@ -128,6 +129,7 @@ ignore = ["D1"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
27
tests/unit/test_command_success.py
Normal file
27
tests/unit/test_command_success.py
Normal file
@ -0,0 +1,27 @@
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
|
||||
|
||||
def test_cmd_output_success():
|
||||
# Test successful command
|
||||
obs = CmdOutputObservation(
|
||||
command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
|
||||
)
|
||||
assert obs.success is True
|
||||
assert obs.error is False
|
||||
|
||||
# Test failed command
|
||||
obs = CmdOutputObservation(
|
||||
command_id=2, command='ls', content='No such file or directory', exit_code=1
|
||||
)
|
||||
assert obs.success is False
|
||||
assert obs.error is True
|
||||
|
||||
|
||||
def test_ipython_cell_success():
|
||||
# IPython cells are always successful
|
||||
obs = IPythonRunCellObservation(code='print("Hello")', content='Hello')
|
||||
assert obs.success is True
|
||||
assert obs.error is False
|
||||
18
tests/unit/test_event_serialization.py
Normal file
18
tests/unit/test_event_serialization.py
Normal file
@ -0,0 +1,18 @@
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
|
||||
|
||||
def test_command_output_success_serialization():
|
||||
# Test successful command
|
||||
obs = CmdOutputObservation(
|
||||
command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
|
||||
)
|
||||
serialized = event_to_dict(obs)
|
||||
assert serialized['success'] is True
|
||||
|
||||
# Test failed command
|
||||
obs = CmdOutputObservation(
|
||||
command_id=2, command='ls', content='No such file or directory', exit_code=1
|
||||
)
|
||||
serialized = event_to_dict(obs)
|
||||
assert serialized['success'] is False
|
||||
@ -40,36 +40,23 @@ def serialization_deserialization(
|
||||
|
||||
|
||||
# Additional tests for various observation subclasses can be included here
|
||||
def test_observation_event_props_serialization_deserialization():
|
||||
original_observation_dict = {
|
||||
'id': 42,
|
||||
'source': 'agent',
|
||||
'timestamp': '2021-08-01T12:00:00',
|
||||
'observation': 'run',
|
||||
'message': 'Command `ls -l` executed with exit code 0.',
|
||||
'extras': {
|
||||
'exit_code': 0,
|
||||
'command': 'ls -l',
|
||||
'command_id': 3,
|
||||
'hidden': False,
|
||||
'interpreter_details': '',
|
||||
},
|
||||
'content': 'foo.txt',
|
||||
}
|
||||
serialization_deserialization(original_observation_dict, CmdOutputObservation)
|
||||
def test_success_field_serialization():
|
||||
# Test success=True
|
||||
obs = CmdOutputObservation(
|
||||
content='Command succeeded',
|
||||
exit_code=0,
|
||||
command='ls -l',
|
||||
command_id=3,
|
||||
)
|
||||
serialized = event_to_dict(obs)
|
||||
assert serialized['success'] is True
|
||||
|
||||
|
||||
def test_command_output_observation_serialization_deserialization():
|
||||
original_observation_dict = {
|
||||
'observation': 'run',
|
||||
'extras': {
|
||||
'exit_code': 0,
|
||||
'command': 'ls -l',
|
||||
'command_id': 3,
|
||||
'hidden': False,
|
||||
'interpreter_details': '',
|
||||
},
|
||||
'message': 'Command `ls -l` executed with exit code 0.',
|
||||
'content': 'foo.txt',
|
||||
}
|
||||
serialization_deserialization(original_observation_dict, CmdOutputObservation)
|
||||
# Test success=False
|
||||
obs = CmdOutputObservation(
|
||||
content='No such file or directory',
|
||||
exit_code=1,
|
||||
command='ls -l',
|
||||
command_id=3,
|
||||
)
|
||||
serialized = event_to_dict(obs)
|
||||
assert serialized['success'] is False
|
||||
|
||||
@ -51,24 +51,48 @@ def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]])
|
||||
|
||||
|
||||
def test_msg(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
policy = """
|
||||
raise "Disallow ABC [risk=medium]" if:
|
||||
(msg: Message)
|
||||
"ABC" in msg.content
|
||||
"""
|
||||
InvariantAnalyzer(event_stream, policy)
|
||||
data = [
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(MessageAction('AB!'), EventSource.AGENT),
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(MessageAction('ABC!'), EventSource.AGENT),
|
||||
mock_container = MagicMock()
|
||||
mock_container.status = 'running'
|
||||
mock_container.attrs = {
|
||||
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
||||
}
|
||||
mock_docker = MagicMock()
|
||||
mock_docker.from_env().containers.list.return_value = [mock_container]
|
||||
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
||||
mock_requests.post().json.side_effect = [
|
||||
{'monitor_id': 'mock-monitor-id'},
|
||||
[], # First check
|
||||
[], # Second check
|
||||
[], # Third check
|
||||
[
|
||||
'PolicyViolation(Disallow ABC [risk=medium], ranges=[<2 ranges>])'
|
||||
], # Fourth check
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
for i in range(3):
|
||||
assert data[i][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
|
||||
|
||||
with (
|
||||
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
||||
patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
||||
):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
policy = """
|
||||
raise "Disallow ABC [risk=medium]" if:
|
||||
(msg: Message)
|
||||
"ABC" in msg.content
|
||||
"""
|
||||
InvariantAnalyzer(event_stream, policy)
|
||||
data = [
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(MessageAction('AB!'), EventSource.AGENT),
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(MessageAction('ABC!'), EventSource.AGENT),
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
for i in range(3):
|
||||
assert data[i][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -76,22 +100,44 @@ def test_msg(temp_dir: str):
|
||||
[('rm -rf root_dir', ActionSecurityRisk.MEDIUM), ['ls', ActionSecurityRisk.LOW]],
|
||||
)
|
||||
def test_cmd(cmd, expected_risk, temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
policy = """
|
||||
raise "Disallow rm -rf [risk=medium]" if:
|
||||
(call: ToolCall)
|
||||
call is tool:run
|
||||
match("rm -rf", call.function.arguments.command)
|
||||
"""
|
||||
InvariantAnalyzer(event_stream, policy)
|
||||
data = [
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(CmdRunAction(cmd), EventSource.USER),
|
||||
mock_container = MagicMock()
|
||||
mock_container.status = 'running'
|
||||
mock_container.attrs = {
|
||||
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
||||
}
|
||||
mock_docker = MagicMock()
|
||||
mock_docker.from_env().containers.list.return_value = [mock_container]
|
||||
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
||||
mock_requests.post().json.side_effect = [
|
||||
{'monitor_id': 'mock-monitor-id'},
|
||||
[], # First check
|
||||
['PolicyViolation(Disallow rm -rf [risk=medium], ranges=[<2 ranges>])']
|
||||
if expected_risk == ActionSecurityRisk.MEDIUM
|
||||
else [], # Second check
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[1][0].security_risk == expected_risk
|
||||
|
||||
with (
|
||||
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
||||
patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
||||
):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
policy = """
|
||||
raise "Disallow rm -rf [risk=medium]" if:
|
||||
(call: ToolCall)
|
||||
call is tool:run
|
||||
match("rm -rf", call.function.arguments.command)
|
||||
"""
|
||||
InvariantAnalyzer(event_stream, policy)
|
||||
data = [
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(CmdRunAction(cmd), EventSource.USER),
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[1][0].security_risk == expected_risk
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -102,26 +148,49 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
|
||||
],
|
||||
)
|
||||
def test_leak_secrets(code, expected_risk, temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
policy = """
|
||||
from invariant.detectors import secrets
|
||||
mock_container = MagicMock()
|
||||
mock_container.status = 'running'
|
||||
mock_container.attrs = {
|
||||
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
||||
}
|
||||
mock_docker = MagicMock()
|
||||
mock_docker.from_env().containers.list.return_value = [mock_container]
|
||||
|
||||
raise "Disallow writing secrets [risk=medium]" if:
|
||||
(call: ToolCall)
|
||||
call is tool:run_ipython
|
||||
any(secrets(call.function.arguments.code))
|
||||
"""
|
||||
InvariantAnalyzer(event_stream, policy)
|
||||
data = [
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(IPythonRunCellAction(code), EventSource.AGENT),
|
||||
(IPythonRunCellAction('hello'), EventSource.AGENT),
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
||||
mock_requests.post().json.side_effect = [
|
||||
{'monitor_id': 'mock-monitor-id'},
|
||||
[], # First check
|
||||
['PolicyViolation(Disallow writing secrets [risk=medium], ranges=[<2 ranges>])']
|
||||
if expected_risk == ActionSecurityRisk.MEDIUM
|
||||
else [], # Second check
|
||||
[], # Third check
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[1][0].security_risk == expected_risk
|
||||
assert data[2][0].security_risk == ActionSecurityRisk.LOW
|
||||
|
||||
with (
|
||||
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
||||
patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
||||
):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
policy = """
|
||||
from invariant.detectors import secrets
|
||||
|
||||
raise "Disallow writing secrets [risk=medium]" if:
|
||||
(call: ToolCall)
|
||||
call is tool:run_ipython
|
||||
any(secrets(call.function.arguments.code))
|
||||
"""
|
||||
InvariantAnalyzer(event_stream, policy)
|
||||
data = [
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(IPythonRunCellAction(code), EventSource.AGENT),
|
||||
(IPythonRunCellAction('hello'), EventSource.AGENT),
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[1][0].security_risk == expected_risk
|
||||
assert data[2][0].security_risk == ActionSecurityRisk.LOW
|
||||
|
||||
|
||||
def test_unsafe_python_code(temp_dir: str):
|
||||
@ -458,26 +527,48 @@ def default_config():
|
||||
def test_check_usertask(
|
||||
mock_litellm_completion, usertask, is_appropriate, default_config, temp_dir: str
|
||||
):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
analyzer = InvariantAnalyzer(event_stream)
|
||||
mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
|
||||
mock_litellm_completion.return_value = mock_response
|
||||
analyzer.guardrail_llm = LLM(config=default_config)
|
||||
analyzer.check_browsing_alignment = True
|
||||
data = [
|
||||
(MessageAction(usertask), EventSource.USER),
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
event_list = list(event_stream.get_events())
|
||||
mock_container = MagicMock()
|
||||
mock_container.status = 'running'
|
||||
mock_container.attrs = {
|
||||
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
||||
}
|
||||
mock_docker = MagicMock()
|
||||
mock_docker.from_env().containers.list.return_value = [mock_container]
|
||||
|
||||
if is_appropriate == 'No':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == MessageAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
elif is_appropriate == 'Yes':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == MessageAction
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
||||
mock_requests.post().json.side_effect = [
|
||||
{'monitor_id': 'mock-monitor-id'},
|
||||
[],
|
||||
[
|
||||
'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
|
||||
],
|
||||
]
|
||||
|
||||
with (
|
||||
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
||||
patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
||||
):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
analyzer = InvariantAnalyzer(event_stream)
|
||||
mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
|
||||
mock_litellm_completion.return_value = mock_response
|
||||
analyzer.guardrail_llm = LLM(config=default_config)
|
||||
analyzer.check_browsing_alignment = True
|
||||
data = [
|
||||
(MessageAction(usertask), EventSource.USER),
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
event_list = list(event_stream.get_events())
|
||||
|
||||
if is_appropriate == 'No':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == MessageAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
elif is_appropriate == 'Yes':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == MessageAction
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -491,23 +582,45 @@ def test_check_usertask(
|
||||
def test_check_fillaction(
|
||||
mock_litellm_completion, fillaction, is_harmful, default_config, temp_dir: str
|
||||
):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
analyzer = InvariantAnalyzer(event_stream)
|
||||
mock_response = {'choices': [{'message': {'content': is_harmful}}]}
|
||||
mock_litellm_completion.return_value = mock_response
|
||||
analyzer.guardrail_llm = LLM(config=default_config)
|
||||
analyzer.check_browsing_alignment = True
|
||||
data = [
|
||||
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
event_list = list(event_stream.get_events())
|
||||
mock_container = MagicMock()
|
||||
mock_container.status = 'running'
|
||||
mock_container.attrs = {
|
||||
'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
||||
}
|
||||
mock_docker = MagicMock()
|
||||
mock_docker.from_env().containers.list.return_value = [mock_container]
|
||||
|
||||
if is_harmful == 'Yes':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
elif is_harmful == 'No':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
mock_requests = MagicMock()
|
||||
mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
||||
mock_requests.post().json.side_effect = [
|
||||
{'monitor_id': 'mock-monitor-id'},
|
||||
[],
|
||||
[
|
||||
'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
|
||||
],
|
||||
]
|
||||
|
||||
with (
|
||||
patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
||||
patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
||||
):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('main', file_store)
|
||||
analyzer = InvariantAnalyzer(event_stream)
|
||||
mock_response = {'choices': [{'message': {'content': is_harmful}}]}
|
||||
mock_litellm_completion.return_value = mock_response
|
||||
analyzer.guardrail_llm = LLM(config=default_config)
|
||||
analyzer.check_browsing_alignment = True
|
||||
data = [
|
||||
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
|
||||
]
|
||||
add_events(event_stream, data)
|
||||
event_list = list(event_stream.get_events())
|
||||
|
||||
if is_harmful == 'Yes':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
elif is_harmful == 'No':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user