mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
fix: asyncio issues with security analyzer + enable security analyzer in cli (#5356)
This commit is contained in:
parent
92b38dcea1
commit
871c544b74
@ -284,6 +284,8 @@ class AgentController:
|
||||
self.agent.llm.metrics.merge(observation.llm_metrics)
|
||||
|
||||
if self._pending_action and self._pending_action.id == observation.cause:
|
||||
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
return
|
||||
self._pending_action = None
|
||||
if self.state.agent_state == AgentState.USER_CONFIRMED:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
@ -369,6 +371,7 @@ class AgentController:
|
||||
else:
|
||||
confirmation_state = ActionConfirmationStatus.REJECTED
|
||||
self._pending_action.confirmation_state = confirmation_state # type: ignore[attr-defined]
|
||||
self._pending_action._id = None # type: ignore[attr-defined]
|
||||
self.event_stream.add_event(self._pending_action, EventSource.AGENT)
|
||||
|
||||
self.state.agent_state = new_state
|
||||
|
||||
@ -11,6 +11,7 @@ from openhands import __version__
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import (
|
||||
AppConfig,
|
||||
get_parser,
|
||||
load_app_config,
|
||||
)
|
||||
@ -20,6 +21,7 @@ from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
@ -30,10 +32,12 @@ from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@ -45,6 +49,15 @@ def display_command(command: str):
|
||||
print('❯ ' + colored(command + '\n', 'green'))
|
||||
|
||||
|
||||
def display_confirmation(confirmation_state: ActionConfirmationStatus):
|
||||
if confirmation_state == ActionConfirmationStatus.CONFIRMED:
|
||||
print(colored('✅ ' + confirmation_state + '\n', 'green'))
|
||||
elif confirmation_state == ActionConfirmationStatus.REJECTED:
|
||||
print(colored('❌ ' + confirmation_state + '\n', 'red'))
|
||||
else:
|
||||
print(colored('⏳ ' + confirmation_state + '\n', 'yellow'))
|
||||
|
||||
|
||||
def display_command_output(output: str):
|
||||
lines = output.split('\n')
|
||||
for line in lines:
|
||||
@ -59,7 +72,7 @@ def display_file_edit(event: FileEditAction | FileEditObservation):
|
||||
print(colored(str(event), 'green'))
|
||||
|
||||
|
||||
def display_event(event: Event):
|
||||
def display_event(event: Event, config: AppConfig):
|
||||
if isinstance(event, Action):
|
||||
if hasattr(event, 'thought'):
|
||||
display_message(event.thought)
|
||||
@ -74,6 +87,8 @@ def display_event(event: Event):
|
||||
display_file_edit(event)
|
||||
if isinstance(event, FileEditObservation):
|
||||
display_file_edit(event)
|
||||
if hasattr(event, 'confirmation_state') and config.security.confirmation_mode:
|
||||
display_confirmation(event.confirmation_state)
|
||||
|
||||
|
||||
async def main():
|
||||
@ -119,12 +134,18 @@ async def main():
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
if config.security.security_analyzer:
|
||||
options.SecurityAnalyzers.get(
|
||||
config.security.security_analyzer, SecurityAnalyzer
|
||||
)(event_stream)
|
||||
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
max_iterations=config.max_iterations,
|
||||
max_budget_per_task=config.max_budget_per_task,
|
||||
agent_to_llm_config=config.get_agent_to_llm_config_map(),
|
||||
event_stream=event_stream,
|
||||
confirmation_mode=config.security.confirmation_mode,
|
||||
)
|
||||
|
||||
async def prompt_for_next_task():
|
||||
@ -143,14 +164,34 @@ async def main():
|
||||
action = MessageAction(content=next_message)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
async def prompt_for_user_confirmation():
|
||||
loop = asyncio.get_event_loop()
|
||||
user_confirmation = await loop.run_in_executor(
|
||||
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
|
||||
)
|
||||
return user_confirmation.lower() == 'y'
|
||||
|
||||
async def on_event(event: Event):
|
||||
display_event(event)
|
||||
display_event(event, config)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
]:
|
||||
await prompt_for_next_task()
|
||||
if (
|
||||
isinstance(event, NullObservation)
|
||||
and controller.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION
|
||||
):
|
||||
user_confirmed = await prompt_for_user_confirmation()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED), EventSource.USER
|
||||
)
|
||||
else:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
|
||||
)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||
|
||||
|
||||
@ -32,5 +32,9 @@ class SecurityConfig:
|
||||
|
||||
return f"SecurityConfig({', '.join(attr_str)})"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, security_config_dict: dict) -> 'SecurityConfig':
|
||||
return cls(**security_config_dict)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@ -18,6 +18,7 @@ from openhands.core.config.config_utils import (
|
||||
)
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@ -144,6 +145,12 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
|
||||
)
|
||||
llm_config = LLMConfig.from_dict(nested_value)
|
||||
cfg.set_llm_config(llm_config, nested_key)
|
||||
elif key is not None and key.lower() == 'security':
|
||||
logger.openhands_logger.debug(
|
||||
'Attempt to load security config from config toml'
|
||||
)
|
||||
security_config = SecurityConfig.from_dict(value)
|
||||
cfg.security = security_config
|
||||
elif not key.startswith('sandbox') and key.lower() != 'core':
|
||||
logger.openhands_logger.warning(
|
||||
f'Unknown key in {toml_file}: "{key}"'
|
||||
|
||||
@ -300,7 +300,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
)
|
||||
# we should confirm only on agent actions
|
||||
event_source = event.source if event.source else EventSource.AGENT
|
||||
await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
|
||||
self.event_stream.add_event(new_event, event_source)
|
||||
|
||||
async def security_risk(self, event: Action) -> ActionSecurityRisk:
|
||||
logger.debug('Calling security_risk on InvariantAnalyzer')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user