Plumb custom secrets to runtime (#8330)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra
2025-05-15 20:06:30 -04:00
committed by GitHub
parent 1f827170f4
commit feb04dc65f
17 changed files with 246 additions and 13 deletions

View File

@@ -212,6 +212,17 @@ export const chatSlice = createSlice({
content += `\n\n- ${host} (port ${port})`;
}
}
if (
recallObs.extras.custom_secrets_descriptions &&
Object.keys(recallObs.extras.custom_secrets_descriptions).length > 0
) {
content += `\n\n**Custom Secrets**`;
for (const [name, description] of Object.entries(
recallObs.extras.custom_secrets_descriptions,
)) {
content += `\n\n- $${name}: ${description}`;
}
}
if (recallObs.extras.repo_instructions) {
content += `\n\n**Repository Instructions:**\n\n${recallObs.extras.repo_instructions}`;
}

View File

@@ -123,6 +123,7 @@ export interface RecallObservation extends OpenHandsObservationEvent<"recall"> {
repo_directory?: string;
repo_instructions?: string;
runtime_hosts?: Record<string, number>;
custom_secrets_descriptions?: Record<string, string>;
additional_agent_instructions?: string;
date?: string;
microagent_knowledge?: MicroagentKnowledge[];

View File

@@ -8,7 +8,7 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone
{{ repository_instructions }}
</REPOSITORY_INSTRUCTIONS>
{% endif %}
{% if runtime_info and (runtime_info.available_hosts or runtime_info.additional_agent_instructions) -%}
{% if runtime_info -%}
<RUNTIME_INFORMATION>
{% if runtime_info.available_hosts %}
The user has access to the following hosts for accessing a web application,
@@ -24,6 +24,14 @@ For example, if you are using vite.config.js, you should set server.host and ser
{% if runtime_info.additional_agent_instructions %}
{{ runtime_info.additional_agent_instructions }}
{% endif %}
{% if runtime_info.custom_secrets_descriptions %}
<CUSTOM_SECRETS>
You are have access to the following environment variables
{% for secret_name, secret_description in runtime_info.custom_secrets_descriptions.items() %}
* $**{{ secret_name }}**: {{ secret_description }}
{% endfor %}
</CUSTOM_SECRETS>
{% endif %}
{% if runtime_info.date %}
Today's date is {{ runtime_info.date }} (UTC).
{% endif %}

View File

@@ -154,7 +154,7 @@ def create_memory(
if runtime:
# sets available hosts
memory.set_runtime_info(runtime)
memory.set_runtime_info(runtime, {})
# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroagent] = runtime.get_microagents_from_selected_repo(

View File

@@ -74,6 +74,7 @@ class RecallObservation(Observation):
runtime_hosts: dict[str, int] = field(default_factory=dict)
additional_agent_instructions: str = ''
date: str = ''
custom_secrets_descriptions: dict[str, str] = field(default_factory=dict)
# knowledge
microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)
@@ -114,7 +115,8 @@ class RecallObservation(Observation):
f'repo_instructions={self.repo_instructions[:20]}...',
f'runtime_hosts={self.runtime_hosts}',
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
f'date={self.date}',
f'date={self.date}'
f'custom_secrets_descriptions={self.custom_secrets_descriptions}',
]
)
else:

View File

@@ -451,9 +451,13 @@ class ConversationMemory:
available_hosts=obs.runtime_hosts,
additional_agent_instructions=obs.additional_agent_instructions,
date=date,
custom_secrets_descriptions=obs.custom_secrets_descriptions,
)
else:
runtime_info = RuntimeInfo(date=date)
runtime_info = RuntimeInfo(
date=date,
custom_secrets_descriptions=obs.custom_secrets_descriptions,
)
repo_instructions = (
obs.repo_instructions if obs.repo_instructions else ''

View File

@@ -176,6 +176,9 @@ class Memory:
microagent_knowledge=microagent_knowledge,
content='Added workspace context',
date=self.runtime_info.date if self.runtime_info is not None else '',
custom_secrets_descriptions=self.runtime_info.custom_secrets_descriptions
if self.runtime_info is not None
else {},
)
return obs
return None
@@ -266,7 +269,9 @@ class Memory:
else:
self.repository_info = None
def set_runtime_info(self, runtime: Runtime) -> None:
def set_runtime_info(
self, runtime: Runtime, custom_secrets_descriptions: dict[str, str]
) -> None:
"""Store runtime info (web hosts, ports, etc.)."""
# e.g. { '127.0.0.1': 8080 }
utc_now = datetime.now(timezone.utc)
@@ -277,9 +282,12 @@ class Memory:
available_hosts=runtime.web_hosts,
additional_agent_instructions=runtime.additional_agent_instructions,
date=date,
custom_secrets_descriptions=custom_secrets_descriptions,
)
else:
self.runtime_info = RuntimeInfo(date=date)
self.runtime_info = RuntimeInfo(
date=date, custom_secrets_descriptions=custom_secrets_descriptions
)
def send_error_message(self, message_id: str, message: str):
"""Sends an error message if the callback function was provided."""

View File

@@ -100,6 +100,8 @@ async def connect(connection_id: str, environ: dict) -> None:
git_provider_tokens = user_secrets.provider_tokens
session_init_args['git_provider_tokens'] = git_provider_tokens
if user_secrets:
session_init_args['custom_secrets'] = user_secrets.custom_secrets
conversation_init_data = ConversationInitData(**session_init_args)

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.integrations.provider import (
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
PROVIDER_TOKEN_TYPE,
ProviderHandler,
)
@@ -35,6 +36,7 @@ from openhands.server.user_auth import (
get_auth_type,
get_provider_tokens,
get_user_id,
get_user_secrets,
)
from openhands.server.user_auth.user_auth import AuthType
from openhands.server.utils import get_conversation_store
@@ -44,6 +46,7 @@ from openhands.storage.data_models.conversation_metadata import (
ConversationTrigger,
)
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.utils.async_utils import wait_all
from openhands.utils.conversation_summary import get_default_conversation_title
@@ -73,6 +76,7 @@ class InitSessionResponse(BaseModel):
async def _create_new_conversation(
user_id: str | None,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
selected_repository: str | None,
selected_branch: str | None,
initial_user_msg: str | None,
@@ -114,6 +118,7 @@ async def _create_new_conversation(
session_init_args['git_provider_tokens'] = git_provider_tokens
session_init_args['selected_repository'] = selected_repository
session_init_args['custom_secrets'] = custom_secrets
session_init_args['selected_branch'] = selected_branch
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
@@ -174,6 +179,7 @@ async def new_conversation(
data: InitSessionRequest,
user_id: str = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
user_secrets: UserSecrets = Depends(get_user_secrets),
auth_type: AuthType | None = Depends(get_auth_type),
) -> InitSessionResponse:
"""Initialize a new session or join an existing one.
@@ -209,6 +215,7 @@ async def new_conversation(
agent_loop_info = await _create_new_conversation(
user_id=user_id,
git_provider_tokens=provider_tokens,
custom_secrets=user_secrets.custom_secrets,
selected_repository=repository,
selected_branch=selected_branch,
initial_user_msg=initial_user_msg,

View File

@@ -16,7 +16,7 @@ from openhands.core.schema.agent import AgentState
from openhands.events.action import ChangeAgentStateAction, MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.stream import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.integrations.provider import CUSTOM_SECRETS_TYPE, PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
@@ -24,6 +24,7 @@ from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
from openhands.security import SecurityAnalyzer, options
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.storage.files import FileStore
from openhands.utils.async_utils import EXECUTOR, call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
@@ -82,6 +83,7 @@ class AgentSession:
agent: Agent,
max_iterations: int,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
custom_secrets: CUSTOM_SECRETS_TYPE | None = None,
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
@@ -113,6 +115,9 @@ class AgentSession:
self._started_at = started_at
finished = False # For monitoring
runtime_connected = False
custom_secrets_handler = UserSecrets(custom_secrets=custom_secrets if custom_secrets else {})
try:
self._create_security_analyzer(config.security.security_analyzer)
runtime_connected = await self._create_runtime(
@@ -120,6 +125,7 @@ class AgentSession:
config=config,
agent=agent,
git_provider_tokens=git_provider_tokens,
custom_secrets=custom_secrets,
selected_repository=selected_repository,
selected_branch=selected_branch,
)
@@ -157,12 +163,16 @@ class AgentSession:
self.memory = await self._create_memory(
selected_repository=selected_repository,
repo_directory=repo_directory,
custom_secrets_descriptions=custom_secrets_handler.get_custom_secrets_descriptions()
)
if git_provider_tokens:
provider_handler = ProviderHandler(provider_tokens=git_provider_tokens)
await provider_handler.set_event_stream_secrets(self.event_stream)
if custom_secrets:
custom_secrets_handler.set_event_stream_secrets(self.event_stream)
if not self._closed:
if initial_message:
self.event_stream.add_event(initial_message, EventSource.USER)
@@ -264,6 +274,7 @@ class AgentSession:
config: AppConfig,
agent: Agent,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
custom_secrets: CUSTOM_SECRETS_TYPE | None = None,
selected_repository: str | None = None,
selected_branch: str | None = None,
) -> bool:
@@ -281,9 +292,11 @@ class AgentSession:
if self.runtime is not None:
raise RuntimeError('Runtime already created')
custom_secrets_handler = UserSecrets(custom_secrets=custom_secrets or {})
env_vars = custom_secrets_handler.get_env_vars()
self.logger.debug(f'Initializing runtime `{runtime_name}` now...')
runtime_cls = get_runtime_cls(runtime_name)
if runtime_cls == RemoteRuntime:
self.runtime = runtime_cls(
config=config,
@@ -294,6 +307,7 @@ class AgentSession:
headless_mode=False,
attach_to_existing=False,
git_provider_tokens=git_provider_tokens,
env_vars=env_vars,
user_id=self.user_id,
)
else:
@@ -301,8 +315,9 @@ class AgentSession:
provider_tokens=git_provider_tokens
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({}))
)
env_vars = await provider_handler.get_env_vars(expose_secrets=True)
# Merge git provider tokens with custom secrets before passing over to runtime
env_vars.update(await provider_handler.get_env_vars(expose_secrets=True))
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
@@ -400,7 +415,7 @@ class AgentSession:
return controller
async def _create_memory(
self, selected_repository: str | None, repo_directory: str | None
self, selected_repository: str | None, repo_directory: str | None, custom_secrets_descriptions: dict[str, str]
) -> Memory:
memory = Memory(
event_stream=self.event_stream,
@@ -410,7 +425,7 @@ class AgentSession:
if self.runtime:
# sets available hosts and other runtime info
memory.set_runtime_info(self.runtime)
memory.set_runtime_info(self.runtime, custom_secrets_descriptions)
# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroagent] = await call_sync_from_async(

View File

@@ -1,6 +1,6 @@
from pydantic import Field
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.integrations.provider import CUSTOM_SECRETS_TYPE, PROVIDER_TOKEN_TYPE
from openhands.storage.data_models.settings import Settings
@@ -10,6 +10,7 @@ class ConversationInitData(Settings):
"""
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = Field(default=None, frozen=True)
custom_secrets: CUSTOM_SECRETS_TYPE | None = Field(default=None, frozen=True)
selected_repository: str | None = Field(default=None)
replay_json: str | None = Field(default=None)
selected_branch: str | None = Field(default=None)

View File

@@ -153,10 +153,12 @@ class Session:
git_provider_tokens = None
selected_repository = None
selected_branch = None
custom_secrets = None
if isinstance(settings, ConversationInitData):
git_provider_tokens = settings.git_provider_tokens
selected_repository = settings.selected_repository
selected_branch = settings.selected_branch
custom_secrets = settings.custom_secrets
try:
await self.agent_session.start(
@@ -168,6 +170,7 @@ class Session:
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
git_provider_tokens=git_provider_tokens,
custom_secrets=custom_secrets,
selected_repository=selected_repository,
selected_branch=selected_branch,
initial_message=initial_message,

View File

@@ -10,6 +10,7 @@ from pydantic import (
)
from pydantic.json import pydantic_encoder
from openhands.events.stream import EventStream
from openhands.integrations.provider import (
CUSTOM_SECRETS_TYPE,
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
@@ -136,3 +137,31 @@ class UserSecrets(BaseModel):
new_data['custom_secrets'] = secrets
return new_data
def set_event_stream_secrets(self, event_stream: EventStream) -> None:
"""
This ensures that provider tokens and custom secrets masked from the event stream
Args:
event_stream: Agent session's event stream
"""
secrets = self.get_env_vars()
event_stream.set_secrets(secrets)
def get_env_vars(self) -> dict[str, str]:
secret_store = self.model_dump(context={'expose_secrets': True})
custom_secrets = secret_store.get('custom_secrets', {})
secrets = {}
for secret_name, value in custom_secrets.items():
secrets[secret_name] = value['secret']
return secrets
def get_custom_secrets_descriptions(self) -> dict[str, str]:
secrets = {}
for secret_name, secret in self.custom_secrets.items():
secrets[secret_name] = secret.description
return secrets

View File

@@ -14,6 +14,7 @@ class RuntimeInfo:
date: str
available_hosts: dict[str, int] = field(default_factory=dict)
additional_agent_instructions: str = ''
custom_secrets_descriptions: dict[str, str] = field(default_factory=dict)
@dataclass

View File

@@ -77,10 +77,15 @@ def test_client():
def create_new_test_conversation(
test_request: InitSessionRequest, auth_type: AuthType | None = None
):
# Create a mock UserSecrets object with the required custom_secrets attribute
mock_user_secrets = MagicMock()
mock_user_secrets.custom_secrets = MappingProxyType({})
return new_conversation(
data=test_request,
user_id='test_user',
provider_tokens=MappingProxyType({'github': 'token123'}),
user_secrets=mock_user_secrets,
auth_type=auth_type,
)

View File

@@ -17,6 +17,7 @@ from openhands.events.observation.agent import (
RecallObservation,
RecallType,
)
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.stream import EventStream
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
@@ -25,6 +26,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.storage.memory import InMemoryFileStore
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
@pytest.fixture
@@ -326,6 +328,138 @@ async def test_memory_with_agent_microagents():
assert 'magic word' in observation.microagent_knowledge[0].content
@pytest.mark.asyncio
async def test_custom_secrets_descriptions():
"""Test that custom_secrets_descriptions are properly stored in memory and included in RecallObservation."""
# Create a mock event stream
event_stream = MagicMock(spec=EventStream)
# Initialize Memory
memory = Memory(
event_stream=event_stream,
sid='test-session',
)
# Create a mock runtime with custom secrets descriptions
mock_runtime = MagicMock()
mock_runtime.web_hosts = {'test-host.example.com': 8080}
mock_runtime.additional_agent_instructions = 'Test instructions'
# Define custom secrets descriptions
custom_secrets = {
'API_KEY': 'API key for external service',
'DATABASE_URL': 'Connection string for the database',
'SECRET_TOKEN': 'Authentication token for secure operations',
}
# Set runtime info with custom secrets
memory.set_runtime_info(mock_runtime, custom_secrets)
# Set repository info
memory.set_repository_info('test-owner/test-repo', '/workspace/test-repo')
# Create a workspace context recall action
recall_action = RecallAction(
query='Initial message', recall_type=RecallType.WORKSPACE_CONTEXT
)
recall_action._source = EventSource.USER # type: ignore[attr-defined]
# Mock the event_stream.add_event method
added_events = []
def mock_add_event(event, source):
added_events.append((event, source))
event_stream.add_event = mock_add_event
# Process the recall action
await memory._on_event(recall_action)
# Verify a RecallObservation was added to the event stream
assert len(added_events) == 1
observation, source = added_events[0]
# Verify the observation is a RecallObservation
assert isinstance(observation, RecallObservation)
assert source == EventSource.ENVIRONMENT
assert observation.recall_type == RecallType.WORKSPACE_CONTEXT
# Verify custom_secrets_descriptions are included in the observation
assert observation.custom_secrets_descriptions == custom_secrets
# Verify repository info is included
assert observation.repo_name == 'test-owner/test-repo'
assert observation.repo_directory == '/workspace/test-repo'
# Verify runtime info is included
assert observation.runtime_hosts == {'test-host.example.com': 8080}
assert observation.additional_agent_instructions == 'Test instructions'
def test_custom_secrets_descriptions_serialization(prompt_dir):
"""Test that custom_secrets_descriptions are properly serialized in the message for the LLM."""
# Create a PromptManager with the test prompt directory
prompt_manager = PromptManager(prompt_dir)
# Create a RuntimeInfo with custom_secrets_descriptions
custom_secrets = {
'API_KEY': 'API key for external service',
'DATABASE_URL': 'Connection string for the database',
'SECRET_TOKEN': 'Authentication token for secure operations',
}
runtime_info = RuntimeInfo(
date='2025-05-15',
available_hosts={'test-host.example.com': 8080},
additional_agent_instructions='Test instructions',
custom_secrets_descriptions=custom_secrets,
)
# Create a RepositoryInfo
repository_info = RepositoryInfo(
repo_name='test-owner/test-repo', repo_directory='/workspace/test-repo'
)
# Build the workspace context message
workspace_context = prompt_manager.build_workspace_context(
repository_info=repository_info,
runtime_info=runtime_info,
repo_instructions='Test repository instructions',
)
# Verify that the workspace context includes the custom_secrets_descriptions
assert '<CUSTOM_SECRETS>' in workspace_context
for secret_name, secret_description in custom_secrets.items():
assert f'$**{secret_name}**' in workspace_context
assert secret_description in workspace_context
def test_serialization_deserialization_with_custom_secrets():
"""Test that RecallObservation can be serialized and deserialized with custom_secrets_descriptions."""
# This simulates an older version of the RecallObservation
legacy_observation = {
'message': 'Added workspace context',
'observation': 'recall',
'content': 'Test content',
'extras': {
'recall_type': 'workspace_context',
'repo_name': 'test-owner/test-repo',
'repo_directory': '/workspace/test-repo',
'repo_instructions': 'Test repository instructions',
'runtime_hosts': {'test-host.example.com': 8080},
'additional_agent_instructions': 'Test instructions',
'date': '2025-05-15',
'microagent_knowledge': [], # Intentionally omitting custom_secrets_descriptions
},
}
legacy_observation = observation_from_dict(legacy_observation)
# Verify that the observation was created successfully
assert legacy_observation.recall_type == RecallType.WORKSPACE_CONTEXT
assert legacy_observation.repo_name == 'test-owner/test-repo'
assert legacy_observation.repo_directory == '/workspace/test-repo'
def test_memory_multiple_repo_microagents(prompt_dir, file_store):
"""Test that Memory loads and concatenates multiple repo microagents correctly."""
# Create real event stream

View File

@@ -245,6 +245,7 @@ def test_microagent_observation_serialization():
'runtime_hosts': {'host1': 8080, 'host2': 8081},
'repo_instructions': 'complex_repo_instructions',
'additional_agent_instructions': 'You know it all about this runtime',
'custom_secrets_descriptions': {'SECRET': 'CUSTOM'},
'date': '04/12/1023',
'microagent_knowledge': [],
},
@@ -264,6 +265,7 @@ def test_microagent_observation_microagent_knowledge_serialization():
'repo_instructions': '',
'runtime_hosts': {},
'additional_agent_instructions': '',
'custom_secrets_descriptions': {},
'date': '',
'microagent_knowledge': [
{