mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add extensive typing to openhands/security directory (#7732)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
60e8b5841c
commit
84e28234e5
@ -13,7 +13,7 @@ from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
class SecurityAnalyzer:
|
||||
"""Security analyzer that receives all events and analyzes agent actions for security risks."""
|
||||
|
||||
def __init__(self, event_stream: EventStream):
|
||||
def __init__(self, event_stream: EventStream) -> None:
|
||||
"""Initializes a new instance of the SecurityAnalyzer class.
|
||||
|
||||
Args:
|
||||
@ -36,6 +36,7 @@ class SecurityAnalyzer:
|
||||
return
|
||||
|
||||
try:
|
||||
# Set the security_risk attribute on the event
|
||||
event.security_risk = await self.security_risk(event) # type: ignore [attr-defined]
|
||||
await self.act(event)
|
||||
except Exception as e:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import ast
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import docker
|
||||
from fastapi import HTTPException, Request
|
||||
@ -32,12 +32,12 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
"""Security analyzer based on Invariant."""
|
||||
|
||||
trace: list[TraceElement]
|
||||
input: list[dict]
|
||||
input: list[dict[str, Any]]
|
||||
container_name: str = 'openhands-invariant-server'
|
||||
image_name: str = 'ghcr.io/invariantlabs-ai/server:openhands'
|
||||
api_host: str = 'http://localhost'
|
||||
timeout: int = 180
|
||||
settings: dict = {}
|
||||
settings: dict[str, Any] = {}
|
||||
|
||||
check_browsing_alignment: bool = False
|
||||
guardrail_llm: LLM | None = None
|
||||
@ -47,7 +47,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
event_stream: EventStream,
|
||||
policy: str | None = None,
|
||||
sid: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initializes a new instance of the InvariantAnalyzer class."""
|
||||
super().__init__(event_stream)
|
||||
self.trace = []
|
||||
@ -108,14 +108,16 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
policy = ''
|
||||
self.monitor = self.client.Monitor.from_string(policy)
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
self.container.stop()
|
||||
|
||||
async def log_event(self, event: Event) -> None:
|
||||
if isinstance(event, Observation):
|
||||
element = parse_element(self.trace, event)
|
||||
self.trace.extend(element)
|
||||
self.input.extend([e.model_dump(exclude_none=True) for e in element]) # type: ignore [call-overload]
|
||||
self.input.extend(
|
||||
[cast(dict[str, Any], e.model_dump(exclude_none=True)) for e in element]
|
||||
)
|
||||
else:
|
||||
logger.debug('Invariant skipping element: event')
|
||||
|
||||
@ -126,7 +128,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
'low': ActionSecurityRisk.LOW,
|
||||
}
|
||||
regex = r'(?<=risk=)\w+'
|
||||
risks = []
|
||||
risks: list[ActionSecurityRisk] = []
|
||||
for result in results:
|
||||
m = re.search(regex, result)
|
||||
if m and m.group() in mapping:
|
||||
@ -148,7 +150,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
await self.check_usertask()
|
||||
await self.check_fillaction()
|
||||
|
||||
async def check_usertask(self):
|
||||
async def check_usertask(self) -> None:
|
||||
"""Looks at the most recent trace element. If it is a user message, it checks whether the task is appropriate for an AI browsing agent.
|
||||
|
||||
Ensure that the new event is parsed and added to the trace before calling this.
|
||||
@ -198,23 +200,24 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
self.event_stream.add_event, new_event, event_source
|
||||
)
|
||||
|
||||
def parse_browser_action(self, browser_action):
|
||||
def parse_browser_action(
|
||||
self, browser_action: str
|
||||
) -> list[tuple[str | None, list[str]]]:
|
||||
assert browser_action[-1] == ')'
|
||||
tree = ast.parse(browser_action, mode='exec')
|
||||
function_calls = []
|
||||
function_calls: list[tuple[str | None, list[str]]] = []
|
||||
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
|
||||
call_node = node.value # This contains the actual function call
|
||||
# Extract function name
|
||||
func_name: str | None = None
|
||||
if isinstance(call_node.func, ast.Name):
|
||||
func_name = call_node.func.id
|
||||
elif isinstance(call_node.func, ast.Attribute):
|
||||
func_name = (
|
||||
f'{ast.unparse(call_node.func.value)}.{call_node.func.attr}'
|
||||
)
|
||||
else:
|
||||
func_name = None
|
||||
|
||||
# Extract positional arguments
|
||||
args = [ast.unparse(arg) for arg in call_node.args]
|
||||
@ -223,7 +226,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
raise ValueError('The code does not represent a function call.')
|
||||
return function_calls
|
||||
|
||||
async def check_fillaction(self):
|
||||
async def check_fillaction(self) -> None:
|
||||
"""Looks at the most recent trace element. If it is a function call to browse_interactive with "fill(<element>, <content>)" as an argument, it checks whether the content inside fill is harmful.
|
||||
|
||||
Ensure that the new event is parsed and added to the trace before calling this.
|
||||
@ -285,7 +288,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
break
|
||||
|
||||
async def should_confirm(self, event: Event) -> bool:
|
||||
risk = event.security_risk # type: ignore [attr-defined]
|
||||
risk = event.security_risk if hasattr(event, 'security_risk') else None # type: ignore [attr-defined]
|
||||
return (
|
||||
risk is not None
|
||||
and risk < self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM)
|
||||
@ -305,24 +308,21 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
async def security_risk(self, event: Action) -> ActionSecurityRisk:
|
||||
logger.debug('Calling security_risk on InvariantAnalyzer')
|
||||
new_elements = parse_element(self.trace, event)
|
||||
input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload]
|
||||
input_data = [
|
||||
cast(dict[str, Any], e.model_dump(exclude_none=True)) for e in new_elements
|
||||
]
|
||||
self.trace.extend(new_elements)
|
||||
check_result = self.monitor.check(self.input, input)
|
||||
self.input.extend(input)
|
||||
check_result = self.monitor.check(self.input, input_data)
|
||||
self.input.extend(input_data)
|
||||
risk = ActionSecurityRisk.UNKNOWN
|
||||
|
||||
if isinstance(check_result, tuple):
|
||||
result, err = check_result
|
||||
if err:
|
||||
logger.warning(f'Error checking policy: {err}')
|
||||
return risk
|
||||
else:
|
||||
logger.warning(f'Error checking policy: {check_result}')
|
||||
# Process check_result
|
||||
result, err = check_result
|
||||
if err:
|
||||
logger.warning(f'Error checking policy: {err}')
|
||||
return risk
|
||||
|
||||
risk = self.get_risk(result)
|
||||
|
||||
return risk
|
||||
return self.get_risk(result)
|
||||
|
||||
### Handle API requests
|
||||
async def handle_api_request(self, request: Request) -> Any:
|
||||
@ -343,23 +343,23 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
return await self.update_settings(request)
|
||||
raise HTTPException(status_code=405, detail='Method Not Allowed')
|
||||
|
||||
async def export_trace(self, request: Request) -> Any:
|
||||
async def export_trace(self, request: Request) -> JSONResponse:
|
||||
return JSONResponse(content=self.input)
|
||||
|
||||
async def get_policy(self, request: Request) -> Any:
|
||||
async def get_policy(self, request: Request) -> JSONResponse:
|
||||
return JSONResponse(content={'policy': self.monitor.policy})
|
||||
|
||||
async def update_policy(self, request: Request) -> Any:
|
||||
async def update_policy(self, request: Request) -> JSONResponse:
|
||||
data = await request.json()
|
||||
policy = data.get('policy')
|
||||
new_monitor = self.client.Monitor.from_string(policy)
|
||||
self.monitor = new_monitor
|
||||
return JSONResponse(content={'policy': policy})
|
||||
|
||||
async def get_settings(self, request: Request) -> Any:
|
||||
async def get_settings(self, request: Request) -> JSONResponse:
|
||||
return JSONResponse(content=self.settings)
|
||||
|
||||
async def update_settings(self, request: Request) -> Any:
|
||||
async def update_settings(self, request: Request) -> JSONResponse:
|
||||
settings = await request.json()
|
||||
self.settings = settings
|
||||
return JSONResponse(content=self.settings)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@ -7,7 +7,7 @@ import httpx
|
||||
class InvariantClient:
|
||||
timeout: int = 120
|
||||
|
||||
def __init__(self, server_url: str, session_id: str | None = None):
|
||||
def __init__(self, server_url: str, session_id: str | None = None) -> None:
|
||||
self.server = server_url
|
||||
self.session_id, err = self._create_session(session_id)
|
||||
if err:
|
||||
@ -38,7 +38,7 @@ class InvariantClient:
|
||||
return None, err
|
||||
return None, ConnectionError('Connection timed out')
|
||||
|
||||
def close_session(self) -> Union[None, Exception]:
|
||||
def close_session(self) -> Exception | None:
|
||||
try:
|
||||
response = httpx.delete(
|
||||
f'{self.server}/session/?session_id={self.session_id}', timeout=60
|
||||
@ -52,6 +52,7 @@ class InvariantClient:
|
||||
def __init__(self, invariant: 'InvariantClient') -> None:
|
||||
self.server = invariant.server
|
||||
self.session_id = invariant.session_id
|
||||
self.policy_id: str | None = None
|
||||
|
||||
def _create_policy(self, rule: str) -> tuple[str | None, Exception | None]:
|
||||
try:
|
||||
@ -83,7 +84,7 @@ class InvariantClient:
|
||||
self.policy_id = policy_id
|
||||
return self
|
||||
|
||||
def analyze(self, trace: list[dict]) -> Union[Any, Exception]:
|
||||
def analyze(self, trace: list[dict[str, Any]]) -> tuple[Any, Exception | None]:
|
||||
try:
|
||||
response = httpx.post(
|
||||
f'{self.server}/policy/{self.policy_id}/analyze?session_id={self.session_id}',
|
||||
@ -100,6 +101,7 @@ class InvariantClient:
|
||||
self.server = invariant.server
|
||||
self.session_id = invariant.session_id
|
||||
self.policy = ''
|
||||
self.monitor_id: str | None = None
|
||||
|
||||
def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]:
|
||||
try:
|
||||
@ -122,8 +124,10 @@ class InvariantClient:
|
||||
return self
|
||||
|
||||
def check(
|
||||
self, past_events: list[dict], pending_events: list[dict]
|
||||
) -> Union[Any, Exception]:
|
||||
self,
|
||||
past_events: list[dict[str, Any]],
|
||||
pending_events: list[dict[str, Any]],
|
||||
) -> tuple[Any, Exception | None]:
|
||||
try:
|
||||
response = httpx.post(
|
||||
f'{self.server}/monitor/{self.monitor_id}/check?session_id={self.session_id}',
|
||||
|
||||
@ -11,14 +11,14 @@ class LLM:
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
metadata: dict | None = Field(
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default_factory=lambda: dict(), description='Metadata associated with the event'
|
||||
)
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
arguments: dict
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
class ToolCall(Event):
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -18,7 +16,7 @@ from openhands.events.observation import (
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput
|
||||
|
||||
TraceElement = Union[Message, ToolCall, ToolOutput, Function]
|
||||
TraceElement = Message | ToolCall | ToolOutput | Function
|
||||
|
||||
|
||||
def get_next_id(trace: list[TraceElement]) -> str:
|
||||
@ -40,7 +38,7 @@ def get_last_id(
|
||||
|
||||
def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
|
||||
next_id = get_next_id(trace)
|
||||
inv_trace = [] # type: list[TraceElement]
|
||||
inv_trace: list[TraceElement] = []
|
||||
if type(action) == MessageAction:
|
||||
if action.source == EventSource.USER:
|
||||
inv_trace.append(Message(role='user', content=action.content))
|
||||
@ -82,8 +80,8 @@ def parse_element(
|
||||
return parse_observation(trace, element)
|
||||
|
||||
|
||||
def parse_trace(trace: list[tuple[Action, Observation]]):
|
||||
inv_trace = [] # type: list[TraceElement]
|
||||
def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]:
|
||||
inv_trace: list[TraceElement] = []
|
||||
for action, obs in trace:
|
||||
inv_trace.extend(parse_action(inv_trace, action))
|
||||
inv_trace.extend(parse_observation(inv_trace, obs))
|
||||
@ -93,11 +91,11 @@ def parse_trace(trace: list[tuple[Action, Observation]]):
|
||||
class InvariantState(BaseModel):
|
||||
trace: list[TraceElement] = Field(default_factory=list)
|
||||
|
||||
def add_action(self, action: Action):
|
||||
def add_action(self, action: Action) -> None:
|
||||
self.trace.extend(parse_action(self.trace, action))
|
||||
|
||||
def add_observation(self, obs: Observation):
|
||||
def add_observation(self, obs: Observation) -> None:
|
||||
self.trace.extend(parse_observation(self.trace, obs))
|
||||
|
||||
def concatenate(self, other: 'InvariantState'):
|
||||
def concatenate(self, other: 'InvariantState') -> None:
|
||||
self.trace.extend(other.trace)
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from typing import Type
|
||||
|
||||
from openhands.security.analyzer import SecurityAnalyzer
|
||||
from openhands.security.invariant.analyzer import InvariantAnalyzer
|
||||
|
||||
SecurityAnalyzers = {
|
||||
SecurityAnalyzers: dict[str, Type[SecurityAnalyzer]] = {
|
||||
'invariant': InvariantAnalyzer,
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user