diff --git a/openhands/security/analyzer.py b/openhands/security/analyzer.py index ffabca14ff..5dd4db5224 100644 --- a/openhands/security/analyzer.py +++ b/openhands/security/analyzer.py @@ -13,7 +13,7 @@ from openhands.events.stream import EventStream, EventStreamSubscriber class SecurityAnalyzer: """Security analyzer that receives all events and analyzes agent actions for security risks.""" - def __init__(self, event_stream: EventStream): + def __init__(self, event_stream: EventStream) -> None: """Initializes a new instance of the SecurityAnalyzer class. Args: @@ -36,6 +36,7 @@ class SecurityAnalyzer: return try: + # Set the security_risk attribute on the event event.security_risk = await self.security_risk(event) # type: ignore [attr-defined] await self.act(event) except Exception as e: diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index 25afcbec51..7251d01aa3 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -1,7 +1,7 @@ import ast import re import uuid -from typing import Any +from typing import Any, cast import docker from fastapi import HTTPException, Request @@ -32,12 +32,12 @@ class InvariantAnalyzer(SecurityAnalyzer): """Security analyzer based on Invariant.""" trace: list[TraceElement] - input: list[dict] + input: list[dict[str, Any]] container_name: str = 'openhands-invariant-server' image_name: str = 'ghcr.io/invariantlabs-ai/server:openhands' api_host: str = 'http://localhost' timeout: int = 180 - settings: dict = {} + settings: dict[str, Any] = {} check_browsing_alignment: bool = False guardrail_llm: LLM | None = None @@ -47,7 +47,7 @@ class InvariantAnalyzer(SecurityAnalyzer): event_stream: EventStream, policy: str | None = None, sid: str | None = None, - ): + ) -> None: """Initializes a new instance of the InvariantAnalzyer class.""" super().__init__(event_stream) self.trace = [] @@ -108,14 +108,16 @@ class InvariantAnalyzer(SecurityAnalyzer): policy = '' self.monitor = self.client.Monitor.from_string(policy) - async def close(self): + async def close(self) -> None: self.container.stop() async def log_event(self, event: Event) -> None: if isinstance(event, Observation): element = parse_element(self.trace, event) self.trace.extend(element) - self.input.extend([e.model_dump(exclude_none=True) for e in element]) # type: ignore [call-overload] + self.input.extend( + [cast(dict[str, Any], e.model_dump(exclude_none=True)) for e in element] + ) else: logger.debug('Invariant skipping element: event') @@ -126,7 +128,7 @@ class InvariantAnalyzer(SecurityAnalyzer): 'low': ActionSecurityRisk.LOW, } regex = r'(?<=risk=)\w+' - risks = [] + risks: list[ActionSecurityRisk] = [] for result in results: m = re.search(regex, result) if m and m.group() in mapping: @@ -148,7 +150,7 @@ class InvariantAnalyzer(SecurityAnalyzer): await self.check_usertask() await self.check_fillaction() - async def check_usertask(self): + async def check_usertask(self) -> None: """Looks at the most recent trace element. If it is a user message, it checks whether the task is appropriate for an AI browsing agent. Ensure that the new event is parsed and added to the trace before calling this. @@ -198,23 +200,24 @@ class InvariantAnalyzer(SecurityAnalyzer): self.event_stream.add_event, new_event, event_source ) - def parse_browser_action(self, browser_action): + def parse_browser_action( + self, browser_action: str + ) -> list[tuple[str | None, list[str]]]: assert browser_action[-1] == ')' tree = ast.parse(browser_action, mode='exec') - function_calls = [] + function_calls: list[tuple[str | None, list[str]]] = [] for node in tree.body: if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call): call_node = node.value # This contains the actual function call # Extract function name + func_name: str | None = None if isinstance(call_node.func, ast.Name): func_name = call_node.func.id elif isinstance(call_node.func, ast.Attribute): func_name = ( f'{ast.unparse(call_node.func.value)}.{call_node.func.attr}' ) - else: - func_name = None # Extract positional arguments args = [ast.unparse(arg) for arg in call_node.args] @@ -223,7 +226,7 @@ class InvariantAnalyzer(SecurityAnalyzer): raise ValueError('The code does not represent a function call.') return function_calls - async def check_fillaction(self): + async def check_fillaction(self) -> None: """Looks at the most recent trace element. If it is a function call to browse_interactive with "fill(, )" as an argument, it checks whether the content inside fill is harmful. Ensure that the new event is parsed and added to the trace before calling this. @@ -285,7 +288,7 @@ class InvariantAnalyzer(SecurityAnalyzer): break async def should_confirm(self, event: Event) -> bool: - risk = event.security_risk # type: ignore [attr-defined] + risk = event.security_risk if hasattr(event, 'security_risk') else None # type: ignore [attr-defined] return ( risk is not None and risk < self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM) @@ -305,24 +308,21 @@ class InvariantAnalyzer(SecurityAnalyzer): async def security_risk(self, event: Action) -> ActionSecurityRisk: logger.debug('Calling security_risk on InvariantAnalyzer') new_elements = parse_element(self.trace, event) - input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload] + input_data = [ + cast(dict[str, Any], e.model_dump(exclude_none=True)) for e in new_elements + ] self.trace.extend(new_elements) - check_result = self.monitor.check(self.input, input) - self.input.extend(input) + check_result = self.monitor.check(self.input, input_data) + self.input.extend(input_data) risk = ActionSecurityRisk.UNKNOWN - if isinstance(check_result, tuple): - result, err = check_result - if err: - logger.warning(f'Error checking policy: {err}') - return risk - else: - logger.warning(f'Error checking policy: {check_result}') + # Process check_result + result, err = check_result + if err: + logger.warning(f'Error checking policy: {err}') return risk - risk = self.get_risk(result) - - return risk + return self.get_risk(result) ### Handle API requests async def handle_api_request(self, request: Request) -> Any: @@ -343,23 +343,23 @@ class InvariantAnalyzer(SecurityAnalyzer): return await self.update_settings(request) raise HTTPException(status_code=405, detail='Method Not Allowed') - async def export_trace(self, request: Request) -> Any: + async def export_trace(self, request: Request) -> JSONResponse: return JSONResponse(content=self.input) - async def get_policy(self, request: Request) -> Any: + async def get_policy(self, request: Request) -> JSONResponse: return JSONResponse(content={'policy': self.monitor.policy}) - async def update_policy(self, request: Request) -> Any: + async def update_policy(self, request: Request) -> JSONResponse: data = await request.json() policy = data.get('policy') new_monitor = self.client.Monitor.from_string(policy) self.monitor = new_monitor return JSONResponse(content={'policy': policy}) - async def get_settings(self, request: Request) -> Any: + async def get_settings(self, request: Request) -> JSONResponse: return JSONResponse(content=self.settings) - async def update_settings(self, request: Request) -> Any: + async def update_settings(self, request: Request) -> JSONResponse: settings = await request.json() self.settings = settings return JSONResponse(content=self.settings) diff --git a/openhands/security/invariant/client.py b/openhands/security/invariant/client.py index 077a02b92c..64fcecce3d 100644 --- a/openhands/security/invariant/client.py +++ b/openhands/security/invariant/client.py @@ -1,5 +1,5 @@ import time -from typing import Any, Union +from typing import Any import httpx @@ -7,7 +7,7 @@ import httpx class InvariantClient: timeout: int = 120 - def __init__(self, server_url: str, session_id: str | None = None): + def __init__(self, server_url: str, session_id: str | None = None) -> None: self.server = server_url self.session_id, err = self._create_session(session_id) if err: @@ -38,7 +38,7 @@ class InvariantClient: return None, err return None, ConnectionError('Connection timed out') - def close_session(self) -> Union[None, Exception]: + def close_session(self) -> Exception | None: try: response = httpx.delete( f'{self.server}/session/?session_id={self.session_id}', timeout=60 @@ -52,6 +52,7 @@ class InvariantClient: def __init__(self, invariant: 'InvariantClient') -> None: self.server = invariant.server self.session_id = invariant.session_id + self.policy_id: str | None = None def _create_policy(self, rule: str) -> tuple[str | None, Exception | None]: try: @@ -83,7 +84,7 @@ class InvariantClient: self.policy_id = policy_id return self - def analyze(self, trace: list[dict]) -> Union[Any, Exception]: + def analyze(self, trace: list[dict[str, Any]]) -> tuple[Any, Exception | None]: try: response = httpx.post( f'{self.server}/policy/{self.policy_id}/analyze?session_id={self.session_id}', @@ -100,6 +101,7 @@ class InvariantClient: self.server = invariant.server self.session_id = invariant.session_id self.policy = '' + self.monitor_id: str | None = None def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]: try: @@ -122,8 +124,10 @@ class InvariantClient: return self def check( - self, past_events: list[dict], pending_events: list[dict] - ) -> Union[Any, Exception]: + self, + past_events: list[dict[str, Any]], + pending_events: list[dict[str, Any]], + ) -> tuple[Any, Exception | None]: try: response = httpx.post( f'{self.server}/monitor/{self.monitor_id}/check?session_id={self.session_id}', diff --git a/openhands/security/invariant/nodes.py b/openhands/security/invariant/nodes.py index ac294622fb..9ffe5b22bb 100644 --- a/openhands/security/invariant/nodes.py +++ b/openhands/security/invariant/nodes.py @@ -11,14 +11,14 @@ class LLM: class Event(BaseModel): - metadata: dict | None = Field( + metadata: dict[str, Any] | None = Field( default_factory=lambda: dict(), description='Metadata associated with the event' ) class Function(BaseModel): name: str - arguments: dict + arguments: dict[str, Any] class ToolCall(Event): diff --git a/openhands/security/invariant/parser.py b/openhands/security/invariant/parser.py index dea1286924..bcdb8c56e8 100644 --- a/openhands/security/invariant/parser.py +++ b/openhands/security/invariant/parser.py @@ -1,5 +1,3 @@ -from typing import Union - from pydantic import BaseModel, Field from openhands.core.logger import openhands_logger as logger @@ -18,7 +16,7 @@ from openhands.events.observation import ( from openhands.events.serialization.event import event_to_dict from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput -TraceElement = Union[Message, ToolCall, ToolOutput, Function] +TraceElement = Message | ToolCall | ToolOutput | Function def get_next_id(trace: list[TraceElement]) -> str: @@ -40,7 +38,7 @@ def get_last_id( def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]: next_id = get_next_id(trace) - inv_trace = [] # type: list[TraceElement] + inv_trace: list[TraceElement] = [] if type(action) == MessageAction: if action.source == EventSource.USER: inv_trace.append(Message(role='user', content=action.content)) @@ -82,8 +80,8 @@ def parse_element( return parse_observation(trace, element) -def parse_trace(trace: list[tuple[Action, Observation]]): - inv_trace = [] # type: list[TraceElement] +def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]: + inv_trace: list[TraceElement] = [] for action, obs in trace: inv_trace.extend(parse_action(inv_trace, action)) inv_trace.extend(parse_observation(inv_trace, obs)) @@ -93,11 +91,11 @@ def parse_trace(trace: list[tuple[Action, Observation]]): class InvariantState(BaseModel): trace: list[TraceElement] = Field(default_factory=list) - def add_action(self, action: Action): + def add_action(self, action: Action) -> None: self.trace.extend(parse_action(self.trace, action)) - def add_observation(self, obs: Observation): + def add_observation(self, obs: Observation) -> None: self.trace.extend(parse_observation(self.trace, obs)) - def concatenate(self, other: 'InvariantState'): + def concatenate(self, other: 'InvariantState') -> None: self.trace.extend(other.trace) diff --git a/openhands/security/options.py b/openhands/security/options.py index d02a9efc2a..309686b7f0 100644 --- a/openhands/security/options.py +++ b/openhands/security/options.py @@ -1,5 +1,8 @@ +from typing import Type + +from openhands.security.analyzer import SecurityAnalyzer from openhands.security.invariant.analyzer import InvariantAnalyzer -SecurityAnalyzers = { +SecurityAnalyzers: dict[str, Type[SecurityAnalyzer]] = { 'invariant': InvariantAnalyzer, }