diff --git a/frontend/__tests__/components/chat/expandable-message.test.tsx b/frontend/__tests__/components/chat/expandable-message.test.tsx
new file mode 100644
index 0000000000..8eab988339
--- /dev/null
+++ b/frontend/__tests__/components/chat/expandable-message.test.tsx
@@ -0,0 +1,60 @@
+import { describe, expect, it } from "vitest";
+import { screen } from "@testing-library/react";
+import { renderWithProviders } from "test-utils";
+import { ExpandableMessage } from "#/components/features/chat/expandable-message";
+
+describe("ExpandableMessage", () => {
+ it("should render with neutral border for non-action messages", () => {
+ renderWithProviders();
+ const element = screen.getByText("Hello");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
+ });
+
+ it("should render with neutral border for error messages", () => {
+ renderWithProviders();
+ const element = screen.getByText("Error occurred");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
+ });
+
+ it("should render with success icon for successful action messages", () => {
+ renderWithProviders(
+
+ );
+ const element = screen.getByText("Command executed successfully");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ const icon = screen.getByTestId("status-icon");
+ expect(icon).toHaveClass("fill-success");
+ });
+
+ it("should render with error icon for failed action messages", () => {
+ renderWithProviders(
+
+ );
+ const element = screen.getByText("Command failed");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ const icon = screen.getByTestId("status-icon");
+ expect(icon).toHaveClass("fill-danger");
+ });
+
+ it("should render with neutral border and no icon for action messages without success prop", () => {
+ renderWithProviders();
+ const element = screen.getByText("Running command");
+ const container = element.closest("div.flex.gap-2.items-center.justify-between");
+ expect(container).toHaveClass("border-neutral-300");
+ expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
+ });
+});
diff --git a/frontend/src/components/features/chat/expandable-message.tsx b/frontend/src/components/features/chat/expandable-message.tsx
index f42b3f0b13..6ebcaa3aee 100644
--- a/frontend/src/components/features/chat/expandable-message.tsx
+++ b/frontend/src/components/features/chat/expandable-message.tsx
@@ -6,17 +6,21 @@ import { code } from "../markdown/code";
import { ol, ul } from "../markdown/list";
import ArrowUp from "#/icons/angle-up-solid.svg?react";
import ArrowDown from "#/icons/angle-down-solid.svg?react";
+import CheckCircle from "#/icons/check-circle-solid.svg?react";
+import XCircle from "#/icons/x-circle-solid.svg?react";
interface ExpandableMessageProps {
id?: string;
message: string;
type: string;
+ success?: boolean;
}
export function ExpandableMessage({
id,
message,
type,
+ success,
}: ExpandableMessageProps) {
const { t, i18n } = useTranslation();
const [showDetails, setShowDetails] = useState(true);
@@ -31,22 +35,14 @@ export function ExpandableMessage({
}
}, [id, message, i18n.language]);
- const border = type === "error" ? "border-danger" : "border-neutral-300";
- const textColor = type === "error" ? "text-danger" : "text-neutral-300";
- let arrowClasses = "h-4 w-4 ml-2 inline";
- if (type === "error") {
- arrowClasses += " fill-danger";
- } else {
- arrowClasses += " fill-neutral-300";
- }
+ const arrowClasses = "h-4 w-4 ml-2 inline fill-neutral-300";
+ const statusIconClasses = "h-4 w-4 ml-2 inline";
return (
-
+
{headline && (
-
+
{headline}
+ {type === "action" && success !== undefined && (
+
+ {success ? (
+
+ ) : (
+
+ )}
+
+ )}
);
}
diff --git a/frontend/src/components/features/chat/messages.tsx b/frontend/src/components/features/chat/messages.tsx
index 8b0d703b75..e1bd346374 100644
--- a/frontend/src/components/features/chat/messages.tsx
+++ b/frontend/src/components/features/chat/messages.tsx
@@ -20,6 +20,7 @@ export function Messages({
type={message.type}
id={message.translationID}
message={message.content}
+ success={message.success}
/>
);
}
diff --git a/frontend/src/icons/check-circle-solid.svg b/frontend/src/icons/check-circle-solid.svg
new file mode 100644
index 0000000000..a07362b4ab
--- /dev/null
+++ b/frontend/src/icons/check-circle-solid.svg
@@ -0,0 +1,4 @@
+
+
diff --git a/frontend/src/icons/x-circle-solid.svg b/frontend/src/icons/x-circle-solid.svg
new file mode 100644
index 0000000000..f673bbf0b1
--- /dev/null
+++ b/frontend/src/icons/x-circle-solid.svg
@@ -0,0 +1,4 @@
+
+
diff --git a/frontend/src/message.d.ts b/frontend/src/message.d.ts
index 5b70e39c8f..65bd7e0cb1 100644
--- a/frontend/src/message.d.ts
+++ b/frontend/src/message.d.ts
@@ -4,6 +4,7 @@ type Message = {
timestamp: string;
imageUrls?: string[];
type?: "thought" | "error" | "action";
+ success?: boolean;
pending?: boolean;
translationID?: string;
eventID?: number;
diff --git a/frontend/src/state/chat-slice.ts b/frontend/src/state/chat-slice.ts
index df24236c2c..47d2b65175 100644
--- a/frontend/src/state/chat-slice.ts
+++ b/frontend/src/state/chat-slice.ts
@@ -1,14 +1,25 @@
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
import { ActionSecurityRisk } from "#/state/security-analyzer-slice";
-import { OpenHandsObservation } from "#/types/core/observations";
+import {
+ OpenHandsObservation,
+ CommandObservation,
+ IPythonObservation,
+} from "#/types/core/observations";
import { OpenHandsAction } from "#/types/core/actions";
+import { OpenHandsEventType } from "#/types/core/base";
type SliceState = { messages: Message[] };
const MAX_CONTENT_LENGTH = 1000;
-const HANDLED_ACTIONS = ["run", "run_ipython", "write", "read", "browse"];
+const HANDLED_ACTIONS: OpenHandsEventType[] = [
+ "run",
+ "run_ipython",
+ "write",
+ "read",
+ "browse",
+];
function getRiskText(risk: ActionSecurityRisk) {
switch (risk) {
@@ -131,6 +142,18 @@ export const chatSlice = createSlice({
return;
}
causeMessage.translationID = translationID;
+ // Set success property based on observation type
+ if (observationID === "run") {
+ const commandObs = observation.payload as CommandObservation;
+ causeMessage.success = commandObs.extras.exit_code === 0;
+ } else if (observationID === "run_ipython") {
+ // For IPython, treat the run as successful unless its message contains "error" (case-insensitive)
+ const ipythonObs = observation.payload as IPythonObservation;
+ causeMessage.success = !ipythonObs.message
+ .toLowerCase()
+ .includes("error");
+ }
+
if (observationID === "run" || observationID === "run_ipython") {
let { content } = observation.payload;
if (content.length > MAX_CONTENT_LENGTH) {
diff --git a/frontend/src/types/core/observations.ts b/frontend/src/types/core/observations.ts
index 0b95099a83..7ddc3f05dd 100644
--- a/frontend/src/types/core/observations.ts
+++ b/frontend/src/types/core/observations.ts
@@ -52,6 +52,21 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> {
};
}
+export interface WriteObservation extends OpenHandsObservationEvent<"write"> {
+ source: "agent";
+ extras: {
+ path: string;
+ content: string;
+ };
+}
+
+export interface ReadObservation extends OpenHandsObservationEvent<"read"> {
+ source: "agent";
+ extras: {
+ path: string;
+ };
+}
+
export interface ErrorObservation extends OpenHandsObservationEvent<"error"> {
source: "user";
extras: {
@@ -65,4 +80,6 @@ export type OpenHandsObservation =
| IPythonObservation
| DelegateObservation
| BrowseObservation
+ | WriteObservation
+ | ReadObservation
| ErrorObservation;
diff --git a/frontend/tailwind.config.js b/frontend/tailwind.config.js
index 1e665d23fd..1a57ebcd8b 100644
--- a/frontend/tailwind.config.js
+++ b/frontend/tailwind.config.js
@@ -14,6 +14,7 @@ export default {
'root-secondary': '#262626',
'hyperlink': '#007AFF',
'danger': '#EF3744',
+ 'success': '#4CAF50',
},
},
},
diff --git a/frontend/test-utils.tsx b/frontend/test-utils.tsx
index 4b336602fb..6739e3be6e 100644
--- a/frontend/test-utils.tsx
+++ b/frontend/test-utils.tsx
@@ -6,10 +6,31 @@ import { configureStore } from "@reduxjs/toolkit";
// eslint-disable-next-line import/no-extraneous-dependencies
import { RenderOptions, render } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
+import { I18nextProvider } from "react-i18next";
+import i18n from "i18next";
+import { initReactI18next } from "react-i18next";
import { AppStore, RootState, rootReducer } from "./src/store";
import { AuthProvider } from "#/context/auth-context";
import { UserPrefsProvider } from "#/context/user-prefs-context";
+// Initialize i18n for tests
+i18n
+ .use(initReactI18next)
+ .init({
+ lng: "en",
+ fallbackLng: "en",
+ ns: ["translation"],
+ defaultNS: "translation",
+ resources: {
+ en: {
+ translation: {},
+ },
+ },
+ interpolation: {
+ escapeValue: false,
+ },
+ });
+
const setupStore = (preloadedState?: Partial
): AppStore =>
configureStore({
reducer: rootReducer,
@@ -40,7 +61,9 @@ export function renderWithProviders(
- {children}
+
+ {children}
+
diff --git a/frontend/vitest.setup.ts b/frontend/vitest.setup.ts
index 105337e75e..e9a89c8677 100644
--- a/frontend/vitest.setup.ts
+++ b/frontend/vitest.setup.ts
@@ -12,7 +12,13 @@ HTMLElement.prototype.scrollTo = vi.fn();
// Mock the i18n provider
vi.mock("react-i18next", async (importOriginal) => ({
...(await importOriginal()),
- useTranslation: () => ({ t: (key: string) => key }),
+ useTranslation: () => ({
+ t: (key: string) => key,
+ i18n: {
+ language: "en",
+ exists: () => false,
+ },
+ }),
}));
// Mock requests during tests
diff --git a/openhands/agenthub/codeact_agent/README.md b/openhands/agenthub/codeact_agent/README.md
index 9a5093820e..0e15939cdf 100644
--- a/openhands/agenthub/codeact_agent/README.md
+++ b/openhands/agenthub/codeact_agent/README.md
@@ -110,4 +110,4 @@ The agent is implemented in two main files:
2. `function_calling.py`: Tool definitions and function calling interface with:
- Tool parameter specifications
- Tool descriptions and examples
- - Function calling response parsing
\ No newline at end of file
+ - Function calling response parsing
diff --git a/openhands/events/observation/commands.py b/openhands/events/observation/commands.py
index a182168e69..b522b5c472 100644
--- a/openhands/events/observation/commands.py
+++ b/openhands/events/observation/commands.py
@@ -23,6 +23,10 @@ class CmdOutputObservation(Observation):
def message(self) -> str:
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
+ @property
+ def success(self) -> bool:
+ return not self.error
+
def __str__(self) -> str:
return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}'
@@ -42,5 +46,9 @@ class IPythonRunCellObservation(Observation):
def message(self) -> str:
return 'Code executed in IPython cell.'
+ @property
+ def success(self) -> bool:
+ return True # IPython cells are always considered successful
+
def __str__(self) -> str:
return f'**IPythonRunCellObservation**\n{self.content}'
diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py
index 78f7940626..6ee82a1cc8 100644
--- a/openhands/events/serialization/event.py
+++ b/openhands/events/serialization/event.py
@@ -83,6 +83,9 @@ def event_to_dict(event: 'Event') -> dict:
elif 'observation' in d:
d['content'] = props.pop('content', '')
d['extras'] = props
+ # Include the success field for any observation that defines one
+ if hasattr(event, 'success'):
+ d['success'] = event.success
else:
raise ValueError('Event must be either action or observation')
return d
diff --git a/openhands/events/serialization/observation.py b/openhands/events/serialization/observation.py
index 9030ccb1e1..d9d8dc51ad 100644
--- a/openhands/events/serialization/observation.py
+++ b/openhands/events/serialization/observation.py
@@ -50,4 +50,5 @@ def observation_from_dict(observation: dict) -> Observation:
observation.pop('message', None)
content = observation.pop('content', '')
extras = observation.pop('extras', {})
+
return observation_class(content=content, **extras)
diff --git a/pyproject.toml b/pyproject.toml
index acae091bb6..aa81db8e55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,6 +98,7 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
+
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -128,6 +129,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
+
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"
diff --git a/tests/unit/test_command_success.py b/tests/unit/test_command_success.py
new file mode 100644
index 0000000000..b52ceb4815
--- /dev/null
+++ b/tests/unit/test_command_success.py
@@ -0,0 +1,27 @@
+from openhands.events.observation.commands import (
+ CmdOutputObservation,
+ IPythonRunCellObservation,
+)
+
+
+def test_cmd_output_success():
+ # Test successful command
+ obs = CmdOutputObservation(
+ command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
+ )
+ assert obs.success is True
+ assert obs.error is False
+
+ # Test failed command
+ obs = CmdOutputObservation(
+ command_id=2, command='ls', content='No such file or directory', exit_code=1
+ )
+ assert obs.success is False
+ assert obs.error is True
+
+
+def test_ipython_cell_success():
+ # IPython cells are always successful
+ obs = IPythonRunCellObservation(code='print("Hello")', content='Hello')
+ assert obs.success is True
+ assert obs.error is False
diff --git a/tests/unit/test_event_serialization.py b/tests/unit/test_event_serialization.py
new file mode 100644
index 0000000000..d1989a30bb
--- /dev/null
+++ b/tests/unit/test_event_serialization.py
@@ -0,0 +1,18 @@
+from openhands.events.observation import CmdOutputObservation
+from openhands.events.serialization import event_to_dict
+
+
+def test_command_output_success_serialization():
+ # Test successful command
+ obs = CmdOutputObservation(
+ command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is True
+
+ # Test failed command
+ obs = CmdOutputObservation(
+ command_id=2, command='ls', content='No such file or directory', exit_code=1
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is False
diff --git a/tests/unit/test_observation_serialization.py b/tests/unit/test_observation_serialization.py
index ae636ddf56..67a95449b7 100644
--- a/tests/unit/test_observation_serialization.py
+++ b/tests/unit/test_observation_serialization.py
@@ -40,36 +40,23 @@ def serialization_deserialization(
# Additional tests for various observation subclasses can be included here
-def test_observation_event_props_serialization_deserialization():
- original_observation_dict = {
- 'id': 42,
- 'source': 'agent',
- 'timestamp': '2021-08-01T12:00:00',
- 'observation': 'run',
- 'message': 'Command `ls -l` executed with exit code 0.',
- 'extras': {
- 'exit_code': 0,
- 'command': 'ls -l',
- 'command_id': 3,
- 'hidden': False,
- 'interpreter_details': '',
- },
- 'content': 'foo.txt',
- }
- serialization_deserialization(original_observation_dict, CmdOutputObservation)
+def test_success_field_serialization():
+ # Test success=True
+ obs = CmdOutputObservation(
+ content='Command succeeded',
+ exit_code=0,
+ command='ls -l',
+ command_id=3,
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is True
-
-def test_command_output_observation_serialization_deserialization():
- original_observation_dict = {
- 'observation': 'run',
- 'extras': {
- 'exit_code': 0,
- 'command': 'ls -l',
- 'command_id': 3,
- 'hidden': False,
- 'interpreter_details': '',
- },
- 'message': 'Command `ls -l` executed with exit code 0.',
- 'content': 'foo.txt',
- }
- serialization_deserialization(original_observation_dict, CmdOutputObservation)
+ # Test success=False
+ obs = CmdOutputObservation(
+ content='No such file or directory',
+ exit_code=1,
+ command='ls -l',
+ command_id=3,
+ )
+ serialized = event_to_dict(obs)
+ assert serialized['success'] is False
diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py
index fab27a3ec2..a36c66104f 100644
--- a/tests/unit/test_security.py
+++ b/tests/unit/test_security.py
@@ -51,24 +51,48 @@ def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]])
def test_msg(temp_dir: str):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- policy = """
- raise "Disallow ABC [risk=medium]" if:
- (msg: Message)
- "ABC" in msg.content
- """
- InvariantAnalyzer(event_stream, policy)
- data = [
- (MessageAction('Hello world!'), EventSource.USER),
- (MessageAction('AB!'), EventSource.AGENT),
- (MessageAction('Hello world!'), EventSource.USER),
- (MessageAction('ABC!'), EventSource.AGENT),
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
+
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [], # First check
+ [], # Second check
+ [], # Third check
+ [
+ 'PolicyViolation(Disallow ABC [risk=medium], ranges=[<2 ranges>])'
+ ], # Fourth check
]
- add_events(event_stream, data)
- for i in range(3):
- assert data[i][0].security_risk == ActionSecurityRisk.LOW
- assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ policy = """
+ raise "Disallow ABC [risk=medium]" if:
+ (msg: Message)
+ "ABC" in msg.content
+ """
+ InvariantAnalyzer(event_stream, policy)
+ data = [
+ (MessageAction('Hello world!'), EventSource.USER),
+ (MessageAction('AB!'), EventSource.AGENT),
+ (MessageAction('Hello world!'), EventSource.USER),
+ (MessageAction('ABC!'), EventSource.AGENT),
+ ]
+ add_events(event_stream, data)
+ for i in range(3):
+ assert data[i][0].security_risk == ActionSecurityRisk.LOW
+ assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
@pytest.mark.parametrize(
@@ -76,22 +100,44 @@ def test_msg(temp_dir: str):
[('rm -rf root_dir', ActionSecurityRisk.MEDIUM), ['ls', ActionSecurityRisk.LOW]],
)
def test_cmd(cmd, expected_risk, temp_dir: str):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- policy = """
- raise "Disallow rm -rf [risk=medium]" if:
- (call: ToolCall)
- call is tool:run
- match("rm -rf", call.function.arguments.command)
- """
- InvariantAnalyzer(event_stream, policy)
- data = [
- (MessageAction('Hello world!'), EventSource.USER),
- (CmdRunAction(cmd), EventSource.USER),
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
+
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [], # First check
+ ['PolicyViolation(Disallow rm -rf [risk=medium], ranges=[<2 ranges>])']
+ if expected_risk == ActionSecurityRisk.MEDIUM
+ else [], # Second check
]
- add_events(event_stream, data)
- assert data[0][0].security_risk == ActionSecurityRisk.LOW
- assert data[1][0].security_risk == expected_risk
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ policy = """
+ raise "Disallow rm -rf [risk=medium]" if:
+ (call: ToolCall)
+ call is tool:run
+ match("rm -rf", call.function.arguments.command)
+ """
+ InvariantAnalyzer(event_stream, policy)
+ data = [
+ (MessageAction('Hello world!'), EventSource.USER),
+ (CmdRunAction(cmd), EventSource.USER),
+ ]
+ add_events(event_stream, data)
+ assert data[0][0].security_risk == ActionSecurityRisk.LOW
+ assert data[1][0].security_risk == expected_risk
@pytest.mark.parametrize(
@@ -102,26 +148,49 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
],
)
def test_leak_secrets(code, expected_risk, temp_dir: str):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- policy = """
- from invariant.detectors import secrets
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
- raise "Disallow writing secrets [risk=medium]" if:
- (call: ToolCall)
- call is tool:run_ipython
- any(secrets(call.function.arguments.code))
- """
- InvariantAnalyzer(event_stream, policy)
- data = [
- (MessageAction('Hello world!'), EventSource.USER),
- (IPythonRunCellAction(code), EventSource.AGENT),
- (IPythonRunCellAction('hello'), EventSource.AGENT),
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [], # First check
+ ['PolicyViolation(Disallow writing secrets [risk=medium], ranges=[<2 ranges>])']
+ if expected_risk == ActionSecurityRisk.MEDIUM
+ else [], # Second check
+ [], # Third check
]
- add_events(event_stream, data)
- assert data[0][0].security_risk == ActionSecurityRisk.LOW
- assert data[1][0].security_risk == expected_risk
- assert data[2][0].security_risk == ActionSecurityRisk.LOW
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ policy = """
+ from invariant.detectors import secrets
+
+ raise "Disallow writing secrets [risk=medium]" if:
+ (call: ToolCall)
+ call is tool:run_ipython
+ any(secrets(call.function.arguments.code))
+ """
+ InvariantAnalyzer(event_stream, policy)
+ data = [
+ (MessageAction('Hello world!'), EventSource.USER),
+ (IPythonRunCellAction(code), EventSource.AGENT),
+ (IPythonRunCellAction('hello'), EventSource.AGENT),
+ ]
+ add_events(event_stream, data)
+ assert data[0][0].security_risk == ActionSecurityRisk.LOW
+ assert data[1][0].security_risk == expected_risk
+ assert data[2][0].security_risk == ActionSecurityRisk.LOW
def test_unsafe_python_code(temp_dir: str):
@@ -458,26 +527,48 @@ def default_config():
def test_check_usertask(
mock_litellm_completion, usertask, is_appropriate, default_config, temp_dir: str
):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- analyzer = InvariantAnalyzer(event_stream)
- mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
- mock_litellm_completion.return_value = mock_response
- analyzer.guardrail_llm = LLM(config=default_config)
- analyzer.check_browsing_alignment = True
- data = [
- (MessageAction(usertask), EventSource.USER),
- ]
- add_events(event_stream, data)
- event_list = list(event_stream.get_events())
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
- if is_appropriate == 'No':
- assert len(event_list) == 2
- assert type(event_list[0]) == MessageAction
- assert type(event_list[1]) == ChangeAgentStateAction
- elif is_appropriate == 'Yes':
- assert len(event_list) == 1
- assert type(event_list[0]) == MessageAction
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [],
+ [
+ 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
+ ],
+ ]
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ analyzer = InvariantAnalyzer(event_stream)
+ mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
+ mock_litellm_completion.return_value = mock_response
+ analyzer.guardrail_llm = LLM(config=default_config)
+ analyzer.check_browsing_alignment = True
+ data = [
+ (MessageAction(usertask), EventSource.USER),
+ ]
+ add_events(event_stream, data)
+ event_list = list(event_stream.get_events())
+
+ if is_appropriate == 'No':
+ assert len(event_list) == 2
+ assert type(event_list[0]) == MessageAction
+ assert type(event_list[1]) == ChangeAgentStateAction
+ elif is_appropriate == 'Yes':
+ assert len(event_list) == 1
+ assert type(event_list[0]) == MessageAction
@pytest.mark.parametrize(
@@ -491,23 +582,45 @@ def test_check_usertask(
def test_check_fillaction(
mock_litellm_completion, fillaction, is_harmful, default_config, temp_dir: str
):
- file_store = get_file_store('local', temp_dir)
- event_stream = EventStream('main', file_store)
- analyzer = InvariantAnalyzer(event_stream)
- mock_response = {'choices': [{'message': {'content': is_harmful}}]}
- mock_litellm_completion.return_value = mock_response
- analyzer.guardrail_llm = LLM(config=default_config)
- analyzer.check_browsing_alignment = True
- data = [
- (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
- ]
- add_events(event_stream, data)
- event_list = list(event_stream.get_events())
+ mock_container = MagicMock()
+ mock_container.status = 'running'
+ mock_container.attrs = {
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
+ }
+ mock_docker = MagicMock()
+ mock_docker.from_env().containers.list.return_value = [mock_container]
- if is_harmful == 'Yes':
- assert len(event_list) == 2
- assert type(event_list[0]) == BrowseInteractiveAction
- assert type(event_list[1]) == ChangeAgentStateAction
- elif is_harmful == 'No':
- assert len(event_list) == 1
- assert type(event_list[0]) == BrowseInteractiveAction
+ mock_requests = MagicMock()
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
+ mock_requests.post().json.side_effect = [
+ {'monitor_id': 'mock-monitor-id'},
+ [],
+ [
+ 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
+ ],
+ ]
+
+ with (
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
+ ):
+ file_store = get_file_store('local', temp_dir)
+ event_stream = EventStream('main', file_store)
+ analyzer = InvariantAnalyzer(event_stream)
+ mock_response = {'choices': [{'message': {'content': is_harmful}}]}
+ mock_litellm_completion.return_value = mock_response
+ analyzer.guardrail_llm = LLM(config=default_config)
+ analyzer.check_browsing_alignment = True
+ data = [
+ (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
+ ]
+ add_events(event_stream, data)
+ event_list = list(event_stream.get_events())
+
+ if is_harmful == 'Yes':
+ assert len(event_list) == 2
+ assert type(event_list[0]) == BrowseInteractiveAction
+ assert type(event_list[1]) == ChangeAgentStateAction
+ elif is_harmful == 'No':
+ assert len(event_list) == 1
+ assert type(event_list[0]) == BrowseInteractiveAction