diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index f843e93043..540a9341b8 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -307,11 +307,17 @@ class InvariantAnalyzer(SecurityAnalyzer): new_elements = parse_element(self.trace, event) input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload] self.trace.extend(new_elements) - result, err = self.monitor.check(self.input, input) + check_result = self.monitor.check(self.input, input) self.input.extend(input) risk = ActionSecurityRisk.UNKNOWN - if err: - logger.warning(f'Error checking policy: {err}') + + if isinstance(check_result, tuple): + result, err = check_result + if err: + logger.warning(f'Error checking policy: {err}') + return risk + else: + logger.warning(f'Error checking policy: {check_result}') return risk risk = self.get_risk(result) diff --git a/openhands/security/invariant/client.py b/openhands/security/invariant/client.py index c418287456..f2ccc78bd6 100644 --- a/openhands/security/invariant/client.py +++ b/openhands/security/invariant/client.py @@ -50,7 +50,7 @@ class InvariantClient: return None class _Policy: - def __init__(self, invariant): + def __init__(self, invariant: 'InvariantClient') -> None: self.server = invariant.server self.session_id = invariant.session_id @@ -77,7 +77,7 @@ class InvariantClient: except (ConnectionError, Timeout, HTTPError) as err: return None, err - def from_string(self, rule: str): + def from_string(self, rule: str) -> 'InvariantClient._Policy': policy_id, err = self._create_policy(rule) if err: raise err @@ -97,7 +97,7 @@ class InvariantClient: return None, err class _Monitor: - def __init__(self, invariant): + def __init__(self, invariant: 'InvariantClient') -> None: self.server = invariant.server self.session_id = invariant.session_id self.policy = '' @@ -114,7 +114,7 @@ class InvariantClient: except (ConnectionError, Timeout, HTTPError) as err: return None, err - def from_string(self, rule: str): + def from_string(self, rule: str) -> 'InvariantClient._Monitor': monitor_id, err = self._create_monitor(rule) if err: raise err diff --git a/openhands/security/invariant/nodes.py b/openhands/security/invariant/nodes.py index 4741026474..c3d7b9713b 100644 --- a/openhands/security/invariant/nodes.py +++ b/openhands/security/invariant/nodes.py @@ -1,3 +1,4 @@ +from typing import Any, Iterable, Tuple from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass @@ -10,7 +11,7 @@ class LLM: class Event(BaseModel): metadata: dict | None = Field( - default_factory=dict, description='Metadata associated with the event' + default_factory=lambda: dict(), description='Metadata associated with the event' ) @@ -30,7 +31,7 @@ class Message(Event): content: str | None tool_calls: list[ToolCall] | None = None - def __rich_repr__(self): + def __rich_repr__(self) -> Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]: # Print on separate line yield 'role', self.role yield 'content', self.content