Fix mypy errors in security/invariant directory (#6908)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-02-24 10:00:43 -05:00 committed by GitHub
parent ecd573febc
commit 753e3c4205
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 9 deletions

View File

@ -307,11 +307,17 @@ class InvariantAnalyzer(SecurityAnalyzer):
new_elements = parse_element(self.trace, event)
input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload]
self.trace.extend(new_elements)
result, err = self.monitor.check(self.input, input)
check_result = self.monitor.check(self.input, input)
self.input.extend(input)
risk = ActionSecurityRisk.UNKNOWN
if err:
logger.warning(f'Error checking policy: {err}')
if isinstance(check_result, tuple):
result, err = check_result
if err:
logger.warning(f'Error checking policy: {err}')
return risk
else:
logger.warning(f'Error checking policy: {check_result}')
return risk
risk = self.get_risk(result)

View File

@ -50,7 +50,7 @@ class InvariantClient:
return None
class _Policy:
def __init__(self, invariant):
def __init__(self, invariant: 'InvariantClient') -> None:
self.server = invariant.server
self.session_id = invariant.session_id
@ -77,7 +77,7 @@ class InvariantClient:
except (ConnectionError, Timeout, HTTPError) as err:
return None, err
def from_string(self, rule: str):
def from_string(self, rule: str) -> 'InvariantClient._Policy':
policy_id, err = self._create_policy(rule)
if err:
raise err
@ -97,7 +97,7 @@ class InvariantClient:
return None, err
class _Monitor:
def __init__(self, invariant):
def __init__(self, invariant: 'InvariantClient') -> None:
self.server = invariant.server
self.session_id = invariant.session_id
self.policy = ''
@ -114,7 +114,7 @@ class InvariantClient:
except (ConnectionError, Timeout, HTTPError) as err:
return None, err
def from_string(self, rule: str):
def from_string(self, rule: str) -> 'InvariantClient._Monitor':
monitor_id, err = self._create_monitor(rule)
if err:
raise err

View File

@ -1,3 +1,4 @@
from typing import Any, Iterable, Tuple
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
@ -10,7 +11,7 @@ class LLM:
class Event(BaseModel):
metadata: dict | None = Field(
default_factory=dict, description='Metadata associated with the event'
default_factory=lambda: dict(), description='Metadata associated with the event'
)
@ -30,7 +31,7 @@ class Message(Event):
content: str | None
tool_calls: list[ToolCall] | None = None
def __rich_repr__(self):
def __rich_repr__(self) -> Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]:
# Print on separate line
yield 'role', self.role
yield 'content', self.content