Add extensive typing to openhands/security directory (#7732)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-04-08 09:51:05 -04:00 committed by GitHub
parent 60e8b5841c
commit 84e28234e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 57 additions and 51 deletions

View File

@ -13,7 +13,7 @@ from openhands.events.stream import EventStream, EventStreamSubscriber
class SecurityAnalyzer:
"""Security analyzer that receives all events and analyzes agent actions for security risks."""
def __init__(self, event_stream: EventStream):
def __init__(self, event_stream: EventStream) -> None:
"""Initializes a new instance of the SecurityAnalyzer class.
Args:
@ -36,6 +36,7 @@ class SecurityAnalyzer:
return
try:
# Set the security_risk attribute on the event
event.security_risk = await self.security_risk(event) # type: ignore [attr-defined]
await self.act(event)
except Exception as e:

View File

@ -1,7 +1,7 @@
import ast
import re
import uuid
from typing import Any
from typing import Any, cast
import docker
from fastapi import HTTPException, Request
@ -32,12 +32,12 @@ class InvariantAnalyzer(SecurityAnalyzer):
"""Security analyzer based on Invariant."""
trace: list[TraceElement]
input: list[dict]
input: list[dict[str, Any]]
container_name: str = 'openhands-invariant-server'
image_name: str = 'ghcr.io/invariantlabs-ai/server:openhands'
api_host: str = 'http://localhost'
timeout: int = 180
settings: dict = {}
settings: dict[str, Any] = {}
check_browsing_alignment: bool = False
guardrail_llm: LLM | None = None
@ -47,7 +47,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
event_stream: EventStream,
policy: str | None = None,
sid: str | None = None,
):
) -> None:
"""Initializes a new instance of the InvariantAnalzyer class."""
super().__init__(event_stream)
self.trace = []
@ -108,14 +108,16 @@ class InvariantAnalyzer(SecurityAnalyzer):
policy = ''
self.monitor = self.client.Monitor.from_string(policy)
async def close(self):
async def close(self) -> None:
self.container.stop()
async def log_event(self, event: Event) -> None:
if isinstance(event, Observation):
element = parse_element(self.trace, event)
self.trace.extend(element)
self.input.extend([e.model_dump(exclude_none=True) for e in element]) # type: ignore [call-overload]
self.input.extend(
[cast(dict[str, Any], e.model_dump(exclude_none=True)) for e in element]
)
else:
logger.debug('Invariant skipping element: event')
@ -126,7 +128,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
'low': ActionSecurityRisk.LOW,
}
regex = r'(?<=risk=)\w+'
risks = []
risks: list[ActionSecurityRisk] = []
for result in results:
m = re.search(regex, result)
if m and m.group() in mapping:
@ -148,7 +150,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
await self.check_usertask()
await self.check_fillaction()
async def check_usertask(self):
async def check_usertask(self) -> None:
"""Looks at the most recent trace element. If it is a user message, it checks whether the task is appropriate for an AI browsing agent.
Ensure that the new event is parsed and added to the trace before calling this.
@ -198,23 +200,24 @@ class InvariantAnalyzer(SecurityAnalyzer):
self.event_stream.add_event, new_event, event_source
)
def parse_browser_action(self, browser_action):
def parse_browser_action(
self, browser_action: str
) -> list[tuple[str | None, list[str]]]:
assert browser_action[-1] == ')'
tree = ast.parse(browser_action, mode='exec')
function_calls = []
function_calls: list[tuple[str | None, list[str]]] = []
for node in tree.body:
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
call_node = node.value # This contains the actual function call
# Extract function name
func_name: str | None = None
if isinstance(call_node.func, ast.Name):
func_name = call_node.func.id
elif isinstance(call_node.func, ast.Attribute):
func_name = (
f'{ast.unparse(call_node.func.value)}.{call_node.func.attr}'
)
else:
func_name = None
# Extract positional arguments
args = [ast.unparse(arg) for arg in call_node.args]
@ -223,7 +226,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
raise ValueError('The code does not represent a function call.')
return function_calls
async def check_fillaction(self):
async def check_fillaction(self) -> None:
"""Looks at the most recent trace element. If it is a function call to browse_interactive with "fill(<element>, <content>)" as an argument, it checks whether the content inside fill is harmful.
Ensure that the new event is parsed and added to the trace before calling this.
@ -285,7 +288,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
break
async def should_confirm(self, event: Event) -> bool:
risk = event.security_risk # type: ignore [attr-defined]
risk = event.security_risk if hasattr(event, 'security_risk') else None # type: ignore [attr-defined]
return (
risk is not None
and risk < self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM)
@ -305,24 +308,21 @@ class InvariantAnalyzer(SecurityAnalyzer):
async def security_risk(self, event: Action) -> ActionSecurityRisk:
logger.debug('Calling security_risk on InvariantAnalyzer')
new_elements = parse_element(self.trace, event)
input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload]
input_data = [
cast(dict[str, Any], e.model_dump(exclude_none=True)) for e in new_elements
]
self.trace.extend(new_elements)
check_result = self.monitor.check(self.input, input)
self.input.extend(input)
check_result = self.monitor.check(self.input, input_data)
self.input.extend(input_data)
risk = ActionSecurityRisk.UNKNOWN
if isinstance(check_result, tuple):
result, err = check_result
if err:
logger.warning(f'Error checking policy: {err}')
return risk
else:
logger.warning(f'Error checking policy: {check_result}')
# Process check_result
result, err = check_result
if err:
logger.warning(f'Error checking policy: {err}')
return risk
risk = self.get_risk(result)
return risk
return self.get_risk(result)
### Handle API requests
async def handle_api_request(self, request: Request) -> Any:
@ -343,23 +343,23 @@ class InvariantAnalyzer(SecurityAnalyzer):
return await self.update_settings(request)
raise HTTPException(status_code=405, detail='Method Not Allowed')
async def export_trace(self, request: Request) -> Any:
async def export_trace(self, request: Request) -> JSONResponse:
return JSONResponse(content=self.input)
async def get_policy(self, request: Request) -> Any:
async def get_policy(self, request: Request) -> JSONResponse:
return JSONResponse(content={'policy': self.monitor.policy})
async def update_policy(self, request: Request) -> Any:
async def update_policy(self, request: Request) -> JSONResponse:
data = await request.json()
policy = data.get('policy')
new_monitor = self.client.Monitor.from_string(policy)
self.monitor = new_monitor
return JSONResponse(content={'policy': policy})
async def get_settings(self, request: Request) -> Any:
async def get_settings(self, request: Request) -> JSONResponse:
return JSONResponse(content=self.settings)
async def update_settings(self, request: Request) -> Any:
async def update_settings(self, request: Request) -> JSONResponse:
settings = await request.json()
self.settings = settings
return JSONResponse(content=self.settings)

View File

@ -1,5 +1,5 @@
import time
from typing import Any, Union
from typing import Any
import httpx
@ -7,7 +7,7 @@ import httpx
class InvariantClient:
timeout: int = 120
def __init__(self, server_url: str, session_id: str | None = None):
def __init__(self, server_url: str, session_id: str | None = None) -> None:
self.server = server_url
self.session_id, err = self._create_session(session_id)
if err:
@ -38,7 +38,7 @@ class InvariantClient:
return None, err
return None, ConnectionError('Connection timed out')
def close_session(self) -> Union[None, Exception]:
def close_session(self) -> Exception | None:
try:
response = httpx.delete(
f'{self.server}/session/?session_id={self.session_id}', timeout=60
@ -52,6 +52,7 @@ class InvariantClient:
def __init__(self, invariant: 'InvariantClient') -> None:
self.server = invariant.server
self.session_id = invariant.session_id
self.policy_id: str | None = None
def _create_policy(self, rule: str) -> tuple[str | None, Exception | None]:
try:
@ -83,7 +84,7 @@ class InvariantClient:
self.policy_id = policy_id
return self
def analyze(self, trace: list[dict]) -> Union[Any, Exception]:
def analyze(self, trace: list[dict[str, Any]]) -> tuple[Any, Exception | None]:
try:
response = httpx.post(
f'{self.server}/policy/{self.policy_id}/analyze?session_id={self.session_id}',
@ -100,6 +101,7 @@ class InvariantClient:
self.server = invariant.server
self.session_id = invariant.session_id
self.policy = ''
self.monitor_id: str | None = None
def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]:
try:
@ -122,8 +124,10 @@ class InvariantClient:
return self
def check(
self, past_events: list[dict], pending_events: list[dict]
) -> Union[Any, Exception]:
self,
past_events: list[dict[str, Any]],
pending_events: list[dict[str, Any]],
) -> tuple[Any, Exception | None]:
try:
response = httpx.post(
f'{self.server}/monitor/{self.monitor_id}/check?session_id={self.session_id}',

View File

@ -11,14 +11,14 @@ class LLM:
class Event(BaseModel):
metadata: dict | None = Field(
metadata: dict[str, Any] | None = Field(
default_factory=lambda: dict(), description='Metadata associated with the event'
)
class Function(BaseModel):
name: str
arguments: dict
arguments: dict[str, Any]
class ToolCall(Event):

View File

@ -1,5 +1,3 @@
from typing import Union
from pydantic import BaseModel, Field
from openhands.core.logger import openhands_logger as logger
@ -18,7 +16,7 @@ from openhands.events.observation import (
from openhands.events.serialization.event import event_to_dict
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput
TraceElement = Union[Message, ToolCall, ToolOutput, Function]
TraceElement = Message | ToolCall | ToolOutput | Function
def get_next_id(trace: list[TraceElement]) -> str:
@ -40,7 +38,7 @@ def get_last_id(
def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
next_id = get_next_id(trace)
inv_trace = [] # type: list[TraceElement]
inv_trace: list[TraceElement] = []
if type(action) == MessageAction:
if action.source == EventSource.USER:
inv_trace.append(Message(role='user', content=action.content))
@ -82,8 +80,8 @@ def parse_element(
return parse_observation(trace, element)
def parse_trace(trace: list[tuple[Action, Observation]]):
inv_trace = [] # type: list[TraceElement]
def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]:
inv_trace: list[TraceElement] = []
for action, obs in trace:
inv_trace.extend(parse_action(inv_trace, action))
inv_trace.extend(parse_observation(inv_trace, obs))
@ -93,11 +91,11 @@ def parse_trace(trace: list[tuple[Action, Observation]]):
class InvariantState(BaseModel):
trace: list[TraceElement] = Field(default_factory=list)
def add_action(self, action: Action):
def add_action(self, action: Action) -> None:
self.trace.extend(parse_action(self.trace, action))
def add_observation(self, obs: Observation):
def add_observation(self, obs: Observation) -> None:
self.trace.extend(parse_observation(self.trace, obs))
def concatenate(self, other: 'InvariantState'):
def concatenate(self, other: 'InvariantState') -> None:
self.trace.extend(other.trace)

View File

@ -1,5 +1,8 @@
from typing import Type
from openhands.security.analyzer import SecurityAnalyzer
from openhands.security.invariant.analyzer import InvariantAnalyzer
SecurityAnalyzers = {
SecurityAnalyzers: dict[str, Type[SecurityAnalyzer]] = {
'invariant': InvariantAnalyzer,
}