diff --git a/frontend/__tests__/components/chat/expandable-message.test.tsx b/frontend/__tests__/components/chat/expandable-message.test.tsx new file mode 100644 index 0000000000..8eab988339 --- /dev/null +++ b/frontend/__tests__/components/chat/expandable-message.test.tsx @@ -0,0 +1,60 @@ +import { describe, expect, it } from "vitest"; +import { screen } from "@testing-library/react"; +import { renderWithProviders } from "test-utils"; +import { ExpandableMessage } from "#/components/features/chat/expandable-message"; + +describe("ExpandableMessage", () => { + it("should render with neutral border for non-action messages", () => { + renderWithProviders(); + const element = screen.getByText("Hello"); + const container = element.closest("div.flex.gap-2.items-center.justify-between"); + expect(container).toHaveClass("border-neutral-300"); + expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument(); + }); + + it("should render with neutral border for error messages", () => { + renderWithProviders(); + const element = screen.getByText("Error occurred"); + const container = element.closest("div.flex.gap-2.items-center.justify-between"); + expect(container).toHaveClass("border-neutral-300"); + expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument(); + }); + + it("should render with success icon for successful action messages", () => { + renderWithProviders( + + ); + const element = screen.getByText("Command executed successfully"); + const container = element.closest("div.flex.gap-2.items-center.justify-between"); + expect(container).toHaveClass("border-neutral-300"); + const icon = screen.getByTestId("status-icon"); + expect(icon).toHaveClass("fill-success"); + }); + + it("should render with error icon for failed action messages", () => { + renderWithProviders( + + ); + const element = screen.getByText("Command failed"); + const container = element.closest("div.flex.gap-2.items-center.justify-between"); + expect(container).toHaveClass("border-neutral-300"); + const icon = screen.getByTestId("status-icon"); + expect(icon).toHaveClass("fill-danger"); + }); + + it("should render with neutral border and no icon for action messages without success prop", () => { + renderWithProviders(); + const element = screen.getByText("Running command"); + const container = element.closest("div.flex.gap-2.items-center.justify-between"); + expect(container).toHaveClass("border-neutral-300"); + expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/features/chat/expandable-message.tsx b/frontend/src/components/features/chat/expandable-message.tsx index f42b3f0b13..6ebcaa3aee 100644 --- a/frontend/src/components/features/chat/expandable-message.tsx +++ b/frontend/src/components/features/chat/expandable-message.tsx @@ -6,17 +6,21 @@ import { code } from "../markdown/code"; import { ol, ul } from "../markdown/list"; import ArrowUp from "#/icons/angle-up-solid.svg?react"; import ArrowDown from "#/icons/angle-down-solid.svg?react"; +import CheckCircle from "#/icons/check-circle-solid.svg?react"; +import XCircle from "#/icons/x-circle-solid.svg?react"; interface ExpandableMessageProps { id?: string; message: string; type: string; + success?: boolean; } export function ExpandableMessage({ id, message, type, + success, }: ExpandableMessageProps) { const { t, i18n } = useTranslation(); const [showDetails, setShowDetails] = useState(true); @@ -31,22 +35,14 @@ export function ExpandableMessage({ } }, [id, message, i18n.language]); - const border = type === "error" ? "border-danger" : "border-neutral-300"; - const textColor = type === "error" ? "text-danger" : "text-neutral-300"; - let arrowClasses = "h-4 w-4 ml-2 inline"; - if (type === "error") { - arrowClasses += " fill-danger"; - } else { - arrowClasses += " fill-neutral-300"; - } + const arrowClasses = "h-4 w-4 ml-2 inline fill-neutral-300"; + const statusIconClasses = "h-4 w-4 ml-2 inline"; return ( -
+
{headline && ( -

+

{headline}

+ {type === "action" && success !== undefined && ( +
+ {success ? ( + + ) : ( + + )} +
+ )}
); } diff --git a/frontend/src/components/features/chat/messages.tsx b/frontend/src/components/features/chat/messages.tsx index 8b0d703b75..e1bd346374 100644 --- a/frontend/src/components/features/chat/messages.tsx +++ b/frontend/src/components/features/chat/messages.tsx @@ -20,6 +20,7 @@ export function Messages({ type={message.type} id={message.translationID} message={message.content} + success={message.success} /> ); } diff --git a/frontend/src/icons/check-circle-solid.svg b/frontend/src/icons/check-circle-solid.svg new file mode 100644 index 0000000000..a07362b4ab --- /dev/null +++ b/frontend/src/icons/check-circle-solid.svg @@ -0,0 +1,4 @@ + + + + diff --git a/frontend/src/icons/x-circle-solid.svg b/frontend/src/icons/x-circle-solid.svg new file mode 100644 index 0000000000..f673bbf0b1 --- /dev/null +++ b/frontend/src/icons/x-circle-solid.svg @@ -0,0 +1,4 @@ + + + + diff --git a/frontend/src/message.d.ts b/frontend/src/message.d.ts index 5b70e39c8f..65bd7e0cb1 100644 --- a/frontend/src/message.d.ts +++ b/frontend/src/message.d.ts @@ -4,6 +4,7 @@ type Message = { timestamp: string; imageUrls?: string[]; type?: "thought" | "error" | "action"; + success?: boolean; pending?: boolean; translationID?: string; eventID?: number; diff --git a/frontend/src/state/chat-slice.ts b/frontend/src/state/chat-slice.ts index df24236c2c..47d2b65175 100644 --- a/frontend/src/state/chat-slice.ts +++ b/frontend/src/state/chat-slice.ts @@ -1,14 +1,25 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit"; import { ActionSecurityRisk } from "#/state/security-analyzer-slice"; -import { OpenHandsObservation } from "#/types/core/observations"; +import { + OpenHandsObservation, + CommandObservation, + IPythonObservation, +} from "#/types/core/observations"; import { OpenHandsAction } from "#/types/core/actions"; +import { OpenHandsEventType } from "#/types/core/base"; type SliceState = { messages: Message[] }; const MAX_CONTENT_LENGTH = 1000; -const HANDLED_ACTIONS = ["run", "run_ipython", "write", "read", "browse"]; +const HANDLED_ACTIONS: OpenHandsEventType[] = [ + "run", + "run_ipython", + "write", + "read", + "browse", +]; function getRiskText(risk: ActionSecurityRisk) { switch (risk) { @@ -131,6 +142,18 @@ export const chatSlice = createSlice({ return; } causeMessage.translationID = translationID; + // Set success property based on observation type + if (observationID === "run") { + const commandObs = observation.payload as CommandObservation; + causeMessage.success = commandObs.extras.exit_code === 0; + } else if (observationID === "run_ipython") { + // For IPython, we consider it successful if there's no error message + const ipythonObs = observation.payload as IPythonObservation; + causeMessage.success = !ipythonObs.message + .toLowerCase() + .includes("error"); + } + if (observationID === "run" || observationID === "run_ipython") { let { content } = observation.payload; if (content.length > MAX_CONTENT_LENGTH) { diff --git a/frontend/src/types/core/observations.ts b/frontend/src/types/core/observations.ts index 0b95099a83..7ddc3f05dd 100644 --- a/frontend/src/types/core/observations.ts +++ b/frontend/src/types/core/observations.ts @@ -52,6 +52,21 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> { }; } +export interface WriteObservation extends OpenHandsObservationEvent<"write"> { + source: "agent"; + extras: { + path: string; + content: string; + }; +} + +export interface ReadObservation extends OpenHandsObservationEvent<"read"> { + source: "agent"; + extras: { + path: string; + }; +} + export interface ErrorObservation extends OpenHandsObservationEvent<"error"> { source: "user"; extras: { @@ -65,4 +80,6 @@ export type OpenHandsObservation = | IPythonObservation | DelegateObservation | BrowseObservation + | WriteObservation + | ReadObservation | ErrorObservation; diff --git a/frontend/tailwind.config.js b/frontend/tailwind.config.js index 1e665d23fd..1a57ebcd8b 100644 --- a/frontend/tailwind.config.js +++ b/frontend/tailwind.config.js @@ -14,6 +14,7 @@ export default { 'root-secondary': '#262626', 'hyperlink': '#007AFF', 'danger': '#EF3744', + 'success': '#4CAF50', }, }, }, diff --git a/frontend/test-utils.tsx b/frontend/test-utils.tsx index 4b336602fb..6739e3be6e 100644 --- a/frontend/test-utils.tsx +++ b/frontend/test-utils.tsx @@ -6,10 +6,31 @@ import { configureStore } from "@reduxjs/toolkit"; // eslint-disable-next-line import/no-extraneous-dependencies import { RenderOptions, render } from "@testing-library/react"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { I18nextProvider } from "react-i18next"; +import i18n from "i18next"; +import { initReactI18next } from "react-i18next"; import { AppStore, RootState, rootReducer } from "./src/store"; import { AuthProvider } from "#/context/auth-context"; import { UserPrefsProvider } from "#/context/user-prefs-context"; +// Initialize i18n for tests +i18n + .use(initReactI18next) + .init({ + lng: "en", + fallbackLng: "en", + ns: ["translation"], + defaultNS: "translation", + resources: { + en: { + translation: {}, + }, + }, + interpolation: { + escapeValue: false, + }, + }); + const setupStore = (preloadedState?: Partial): AppStore => configureStore({ reducer: rootReducer, @@ -40,7 +61,9 @@ export function renderWithProviders( - {children} + + {children} + diff --git a/frontend/vitest.setup.ts b/frontend/vitest.setup.ts index 105337e75e..e9a89c8677 100644 --- a/frontend/vitest.setup.ts +++ b/frontend/vitest.setup.ts @@ -12,7 +12,13 @@ HTMLElement.prototype.scrollTo = vi.fn(); // Mock the i18n provider vi.mock("react-i18next", async (importOriginal) => ({ ...(await importOriginal()), - useTranslation: () => ({ t: (key: string) => key }), + useTranslation: () => ({ + t: (key: string) => key, + i18n: { + language: "en", + exists: () => false, + }, + }), })); // Mock requests during tests diff --git a/openhands/agenthub/codeact_agent/README.md b/openhands/agenthub/codeact_agent/README.md index 9a5093820e..0e15939cdf 100644 --- a/openhands/agenthub/codeact_agent/README.md +++ b/openhands/agenthub/codeact_agent/README.md @@ -110,4 +110,4 @@ The agent is implemented in two main files: 2. `function_calling.py`: Tool definitions and function calling interface with: - Tool parameter specifications - Tool descriptions and examples - - Function calling response parsing \ No newline at end of file + - Function calling response parsing diff --git a/openhands/events/observation/commands.py b/openhands/events/observation/commands.py index a182168e69..b522b5c472 100644 --- a/openhands/events/observation/commands.py +++ b/openhands/events/observation/commands.py @@ -23,6 +23,10 @@ class CmdOutputObservation(Observation): def message(self) -> str: return f'Command `{self.command}` executed with exit code {self.exit_code}.' + @property + def success(self) -> bool: + return not self.error + def __str__(self) -> str: return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}' @@ -42,5 +46,9 @@ class IPythonRunCellObservation(Observation): def message(self) -> str: return 'Code executed in IPython cell.' + @property + def success(self) -> bool: + return True # IPython cells are always considered successful + def __str__(self) -> str: return f'**IPythonRunCellObservation**\n{self.content}' diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index 78f7940626..6ee82a1cc8 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -83,6 +83,9 @@ def event_to_dict(event: 'Event') -> dict: elif 'observation' in d: d['content'] = props.pop('content', '') d['extras'] = props + # Include success field for CmdOutputObservation + if hasattr(event, 'success'): + d['success'] = event.success else: raise ValueError('Event must be either action or observation') return d diff --git a/openhands/events/serialization/observation.py b/openhands/events/serialization/observation.py index 9030ccb1e1..d9d8dc51ad 100644 --- a/openhands/events/serialization/observation.py +++ b/openhands/events/serialization/observation.py @@ -50,4 +50,5 @@ def observation_from_dict(observation: dict) -> Observation: observation.pop('message', None) content = observation.pop('content', '') extras = observation.pop('extras', {}) + return observation_class(content=content, **extras) diff --git a/pyproject.toml b/pyproject.toml index acae091bb6..aa81db8e55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ reportlab = "*" [tool.coverage.run] concurrency = ["gevent"] + [tool.poetry.group.runtime.dependencies] jupyterlab = "*" notebook = "*" @@ -128,6 +129,7 @@ ignore = ["D1"] [tool.ruff.lint.pydocstyle] convention = "google" + [tool.poetry.group.evaluation.dependencies] streamlit = "*" whatthepatch = "*" diff --git a/tests/unit/test_command_success.py b/tests/unit/test_command_success.py new file mode 100644 index 0000000000..b52ceb4815 --- /dev/null +++ b/tests/unit/test_command_success.py @@ -0,0 +1,27 @@ +from openhands.events.observation.commands import ( + CmdOutputObservation, + IPythonRunCellObservation, +) + + +def test_cmd_output_success(): + # Test successful command + obs = CmdOutputObservation( + command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0 + ) + assert obs.success is True + assert obs.error is False + + # Test failed command + obs = CmdOutputObservation( + command_id=2, command='ls', content='No such file or directory', exit_code=1 + ) + assert obs.success is False + assert obs.error is True + + +def test_ipython_cell_success(): + # IPython cells are always successful + obs = IPythonRunCellObservation(code='print("Hello")', content='Hello') + assert obs.success is True + assert obs.error is False diff --git a/tests/unit/test_event_serialization.py b/tests/unit/test_event_serialization.py new file mode 100644 index 0000000000..d1989a30bb --- /dev/null +++ b/tests/unit/test_event_serialization.py @@ -0,0 +1,18 @@ +from openhands.events.observation import CmdOutputObservation +from openhands.events.serialization import event_to_dict + + +def test_command_output_success_serialization(): + # Test successful command + obs = CmdOutputObservation( + command_id=1, command='ls', content='file1.txt\nfile2.txt', exit_code=0 + ) + serialized = event_to_dict(obs) + assert serialized['success'] is True + + # Test failed command + obs = CmdOutputObservation( + command_id=2, command='ls', content='No such file or directory', exit_code=1 + ) + serialized = event_to_dict(obs) + assert serialized['success'] is False diff --git a/tests/unit/test_observation_serialization.py b/tests/unit/test_observation_serialization.py index ae636ddf56..67a95449b7 100644 --- a/tests/unit/test_observation_serialization.py +++ b/tests/unit/test_observation_serialization.py @@ -40,36 +40,23 @@ def serialization_deserialization( # Additional tests for various observation subclasses can be included here -def test_observation_event_props_serialization_deserialization(): - original_observation_dict = { - 'id': 42, - 'source': 'agent', - 'timestamp': '2021-08-01T12:00:00', - 'observation': 'run', - 'message': 'Command `ls -l` executed with exit code 0.', - 'extras': { - 'exit_code': 0, - 'command': 'ls -l', - 'command_id': 3, - 'hidden': False, - 'interpreter_details': '', - }, - 'content': 'foo.txt', - } - serialization_deserialization(original_observation_dict, CmdOutputObservation) +def test_success_field_serialization(): + # Test success=True + obs = CmdOutputObservation( + content='Command succeeded', + exit_code=0, + command='ls -l', + command_id=3, + ) + serialized = event_to_dict(obs) + assert serialized['success'] is True - -def test_command_output_observation_serialization_deserialization(): - original_observation_dict = { - 'observation': 'run', - 'extras': { - 'exit_code': 0, - 'command': 'ls -l', - 'command_id': 3, - 'hidden': False, - 'interpreter_details': '', - }, - 'message': 'Command `ls -l` executed with exit code 0.', - 'content': 'foo.txt', - } - serialization_deserialization(original_observation_dict, CmdOutputObservation) + # Test success=False + obs = CmdOutputObservation( + content='No such file or directory', + exit_code=1, + command='ls -l', + command_id=3, + ) + serialized = event_to_dict(obs) + assert serialized['success'] is False diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py index fab27a3ec2..a36c66104f 100644 --- a/tests/unit/test_security.py +++ b/tests/unit/test_security.py @@ -51,24 +51,48 @@ def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]) def test_msg(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('main', file_store) - policy = """ - raise "Disallow ABC [risk=medium]" if: - (msg: Message) - "ABC" in msg.content - """ - InvariantAnalyzer(event_stream, policy) - data = [ - (MessageAction('Hello world!'), EventSource.USER), - (MessageAction('AB!'), EventSource.AGENT), - (MessageAction('Hello world!'), EventSource.USER), - (MessageAction('ABC!'), EventSource.AGENT), + mock_container = MagicMock() + mock_container.status = 'running' + mock_container.attrs = { + 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}} + } + mock_docker = MagicMock() + mock_docker.from_env().containers.list.return_value = [mock_container] + + mock_requests = MagicMock() + mock_requests.get().json.return_value = {'id': 'mock-session-id'} + mock_requests.post().json.side_effect = [ + {'monitor_id': 'mock-monitor-id'}, + [], # First check + [], # Second check + [], # Third check + [ + 'PolicyViolation(Disallow ABC [risk=medium], ranges=[<2 ranges>])' + ], # Fourth check ] - add_events(event_stream, data) - for i in range(3): - assert data[i][0].security_risk == ActionSecurityRisk.LOW - assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM + + with ( + patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker), + patch(f'{InvariantClient.__module__}.requests', mock_requests), + ): + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('main', file_store) + policy = """ + raise "Disallow ABC [risk=medium]" if: + (msg: Message) + "ABC" in msg.content + """ + InvariantAnalyzer(event_stream, policy) + data = [ + (MessageAction('Hello world!'), EventSource.USER), + (MessageAction('AB!'), EventSource.AGENT), + (MessageAction('Hello world!'), EventSource.USER), + (MessageAction('ABC!'), EventSource.AGENT), + ] + add_events(event_stream, data) + for i in range(3): + assert data[i][0].security_risk == ActionSecurityRisk.LOW + assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM @pytest.mark.parametrize( @@ -76,22 +100,44 @@ def test_msg(temp_dir: str): [('rm -rf root_dir', ActionSecurityRisk.MEDIUM), ['ls', ActionSecurityRisk.LOW]], ) def test_cmd(cmd, expected_risk, temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('main', file_store) - policy = """ - raise "Disallow rm -rf [risk=medium]" if: - (call: ToolCall) - call is tool:run - match("rm -rf", call.function.arguments.command) - """ - InvariantAnalyzer(event_stream, policy) - data = [ - (MessageAction('Hello world!'), EventSource.USER), - (CmdRunAction(cmd), EventSource.USER), + mock_container = MagicMock() + mock_container.status = 'running' + mock_container.attrs = { + 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}} + } + mock_docker = MagicMock() + mock_docker.from_env().containers.list.return_value = [mock_container] + + mock_requests = MagicMock() + mock_requests.get().json.return_value = {'id': 'mock-session-id'} + mock_requests.post().json.side_effect = [ + {'monitor_id': 'mock-monitor-id'}, + [], # First check + ['PolicyViolation(Disallow rm -rf [risk=medium], ranges=[<2 ranges>])'] + if expected_risk == ActionSecurityRisk.MEDIUM + else [], # Second check ] - add_events(event_stream, data) - assert data[0][0].security_risk == ActionSecurityRisk.LOW - assert data[1][0].security_risk == expected_risk + + with ( + patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker), + patch(f'{InvariantClient.__module__}.requests', mock_requests), + ): + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('main', file_store) + policy = """ + raise "Disallow rm -rf [risk=medium]" if: + (call: ToolCall) + call is tool:run + match("rm -rf", call.function.arguments.command) + """ + InvariantAnalyzer(event_stream, policy) + data = [ + (MessageAction('Hello world!'), EventSource.USER), + (CmdRunAction(cmd), EventSource.USER), + ] + add_events(event_stream, data) + assert data[0][0].security_risk == ActionSecurityRisk.LOW + assert data[1][0].security_risk == expected_risk @pytest.mark.parametrize( @@ -102,26 +148,49 @@ def test_cmd(cmd, expected_risk, temp_dir: str): ], ) def test_leak_secrets(code, expected_risk, temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('main', file_store) - policy = """ - from invariant.detectors import secrets + mock_container = MagicMock() + mock_container.status = 'running' + mock_container.attrs = { + 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}} + } + mock_docker = MagicMock() + mock_docker.from_env().containers.list.return_value = [mock_container] - raise "Disallow writing secrets [risk=medium]" if: - (call: ToolCall) - call is tool:run_ipython - any(secrets(call.function.arguments.code)) - """ - InvariantAnalyzer(event_stream, policy) - data = [ - (MessageAction('Hello world!'), EventSource.USER), - (IPythonRunCellAction(code), EventSource.AGENT), - (IPythonRunCellAction('hello'), EventSource.AGENT), + mock_requests = MagicMock() + mock_requests.get().json.return_value = {'id': 'mock-session-id'} + mock_requests.post().json.side_effect = [ + {'monitor_id': 'mock-monitor-id'}, + [], # First check + ['PolicyViolation(Disallow writing secrets [risk=medium], ranges=[<2 ranges>])'] + if expected_risk == ActionSecurityRisk.MEDIUM + else [], # Second check + [], # Third check ] - add_events(event_stream, data) - assert data[0][0].security_risk == ActionSecurityRisk.LOW - assert data[1][0].security_risk == expected_risk - assert data[2][0].security_risk == ActionSecurityRisk.LOW + + with ( + patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker), + patch(f'{InvariantClient.__module__}.requests', mock_requests), + ): + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('main', file_store) + policy = """ + from invariant.detectors import secrets + + raise "Disallow writing secrets [risk=medium]" if: + (call: ToolCall) + call is tool:run_ipython + any(secrets(call.function.arguments.code)) + """ + InvariantAnalyzer(event_stream, policy) + data = [ + (MessageAction('Hello world!'), EventSource.USER), + (IPythonRunCellAction(code), EventSource.AGENT), + (IPythonRunCellAction('hello'), EventSource.AGENT), + ] + add_events(event_stream, data) + assert data[0][0].security_risk == ActionSecurityRisk.LOW + assert data[1][0].security_risk == expected_risk + assert data[2][0].security_risk == ActionSecurityRisk.LOW def test_unsafe_python_code(temp_dir: str): @@ -458,26 +527,48 @@ def default_config(): def test_check_usertask( mock_litellm_completion, usertask, is_appropriate, default_config, temp_dir: str ): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('main', file_store) - analyzer = InvariantAnalyzer(event_stream) - mock_response = {'choices': [{'message': {'content': is_appropriate}}]} - mock_litellm_completion.return_value = mock_response - analyzer.guardrail_llm = LLM(config=default_config) - analyzer.check_browsing_alignment = True - data = [ - (MessageAction(usertask), EventSource.USER), - ] - add_events(event_stream, data) - event_list = list(event_stream.get_events()) + mock_container = MagicMock() + mock_container.status = 'running' + mock_container.attrs = { + 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}} + } + mock_docker = MagicMock() + mock_docker.from_env().containers.list.return_value = [mock_container] - if is_appropriate == 'No': - assert len(event_list) == 2 - assert type(event_list[0]) == MessageAction - assert type(event_list[1]) == ChangeAgentStateAction - elif is_appropriate == 'Yes': - assert len(event_list) == 1 - assert type(event_list[0]) == MessageAction + mock_requests = MagicMock() + mock_requests.get().json.return_value = {'id': 'mock-session-id'} + mock_requests.post().json.side_effect = [ + {'monitor_id': 'mock-monitor-id'}, + [], + [ + 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])' + ], + ] + + with ( + patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker), + patch(f'{InvariantClient.__module__}.requests', mock_requests), + ): + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('main', file_store) + analyzer = InvariantAnalyzer(event_stream) + mock_response = {'choices': [{'message': {'content': is_appropriate}}]} + mock_litellm_completion.return_value = mock_response + analyzer.guardrail_llm = LLM(config=default_config) + analyzer.check_browsing_alignment = True + data = [ + (MessageAction(usertask), EventSource.USER), + ] + add_events(event_stream, data) + event_list = list(event_stream.get_events()) + + if is_appropriate == 'No': + assert len(event_list) == 2 + assert type(event_list[0]) == MessageAction + assert type(event_list[1]) == ChangeAgentStateAction + elif is_appropriate == 'Yes': + assert len(event_list) == 1 + assert type(event_list[0]) == MessageAction @pytest.mark.parametrize( @@ -491,23 +582,45 @@ def test_check_usertask( def test_check_fillaction( mock_litellm_completion, fillaction, is_harmful, default_config, temp_dir: str ): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('main', file_store) - analyzer = InvariantAnalyzer(event_stream) - mock_response = {'choices': [{'message': {'content': is_harmful}}]} - mock_litellm_completion.return_value = mock_response - analyzer.guardrail_llm = LLM(config=default_config) - analyzer.check_browsing_alignment = True - data = [ - (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT), - ] - add_events(event_stream, data) - event_list = list(event_stream.get_events()) + mock_container = MagicMock() + mock_container.status = 'running' + mock_container.attrs = { + 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}} + } + mock_docker = MagicMock() + mock_docker.from_env().containers.list.return_value = [mock_container] - if is_harmful == 'Yes': - assert len(event_list) == 2 - assert type(event_list[0]) == BrowseInteractiveAction - assert type(event_list[1]) == ChangeAgentStateAction - elif is_harmful == 'No': - assert len(event_list) == 1 - assert type(event_list[0]) == BrowseInteractiveAction + mock_requests = MagicMock() + mock_requests.get().json.return_value = {'id': 'mock-session-id'} + mock_requests.post().json.side_effect = [ + {'monitor_id': 'mock-monitor-id'}, + [], + [ + 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])' + ], + ] + + with ( + patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker), + patch(f'{InvariantClient.__module__}.requests', mock_requests), + ): + file_store = get_file_store('local', temp_dir) + event_stream = EventStream('main', file_store) + analyzer = InvariantAnalyzer(event_stream) + mock_response = {'choices': [{'message': {'content': is_harmful}}]} + mock_litellm_completion.return_value = mock_response + analyzer.guardrail_llm = LLM(config=default_config) + analyzer.check_browsing_alignment = True + data = [ + (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT), + ] + add_events(event_stream, data) + event_list = list(event_stream.get_events()) + + if is_harmful == 'Yes': + assert len(event_list) == 2 + assert type(event_list[0]) == BrowseInteractiveAction + assert type(event_list[1]) == ChangeAgentStateAction + elif is_harmful == 'No': + assert len(event_list) == 1 + assert type(event_list[0]) == BrowseInteractiveAction