mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
feat(agent): Add configurable system_prompt_filename to AgentConfig (#9265)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
99fd3f7bb2
commit
a1479adfd3
@ -95,6 +95,7 @@ class CodeActAgent(Agent):
|
||||
if self._prompt_manager is None:
|
||||
self._prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
system_prompt_filename=self.config.system_prompt_filename,
|
||||
)
|
||||
|
||||
return self._prompt_manager
|
||||
|
||||
@ -12,6 +12,8 @@ class AgentConfig(BaseModel):
|
||||
"""The name of the llm config to use. If specified, this will override global llm config."""
|
||||
classpath: str | None = Field(default=None)
|
||||
"""The classpath of the agent to use. To be used for custom agents that are not defined in the openhands.agenthub package."""
|
||||
system_prompt_filename: str = Field(default='system_prompt.j2')
|
||||
"""Filename of the system prompt template file within the agent's prompt directory. Defaults to 'system_prompt.j2'."""
|
||||
enable_browsing: bool = Field(default=True)
|
||||
"""Whether to enable browsing tool.
|
||||
Note: If using CLIRuntime, browsing is not implemented and should be disabled."""
|
||||
|
||||
@ -52,13 +52,33 @@ class PromptManager:
|
||||
def __init__(
|
||||
self,
|
||||
prompt_dir: str,
|
||||
system_prompt_filename: str = 'system_prompt.j2',
|
||||
):
|
||||
self.prompt_dir: str = prompt_dir
|
||||
self.system_template: Template = self._load_template('system_prompt')
|
||||
self.system_template: Template = self._load_system_template(
|
||||
system_prompt_filename
|
||||
)
|
||||
self.user_template: Template = self._load_template('user_prompt')
|
||||
self.additional_info_template: Template = self._load_template('additional_info')
|
||||
self.microagent_info_template: Template = self._load_template('microagent_info')
|
||||
|
||||
def _load_system_template(self, system_prompt_filename: str) -> Template:
|
||||
"""Load the system prompt template using the specified filename."""
|
||||
# Remove .j2 extension if present to use with _load_template
|
||||
template_name = system_prompt_filename
|
||||
if template_name.endswith('.j2'):
|
||||
template_name = template_name[:-3]
|
||||
|
||||
try:
|
||||
return self._load_template(template_name)
|
||||
except FileNotFoundError:
|
||||
# Provide a more specific error message for system prompt files
|
||||
template_path = os.path.join(self.prompt_dir, f'{template_name}.j2')
|
||||
raise FileNotFoundError(
|
||||
f'System prompt file "{system_prompt_filename}" not found at {template_path}. '
|
||||
f'Please ensure the file exists in the prompt directory: {self.prompt_dir}'
|
||||
)
|
||||
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
if self.prompt_dir is None:
|
||||
raise ValueError('Prompt directory is not set')
|
||||
|
||||
@ -1211,3 +1211,39 @@ def test_agent_config_from_toml_section_with_invalid_base():
|
||||
assert 'CustomAgent' in result
|
||||
assert result['CustomAgent'].enable_browsing is False
|
||||
assert result['CustomAgent'].enable_jupyter is True
|
||||
|
||||
|
||||
def test_agent_config_system_prompt_filename_default():
|
||||
"""Test that AgentConfig defaults to 'system_prompt.j2' for system_prompt_filename."""
|
||||
config = AgentConfig()
|
||||
assert config.system_prompt_filename == 'system_prompt.j2'
|
||||
|
||||
|
||||
def test_agent_config_system_prompt_filename_toml_integration(
|
||||
default_config, temp_toml_file
|
||||
):
|
||||
"""Test that system_prompt_filename is correctly loaded from TOML configuration."""
|
||||
with open(temp_toml_file, 'w', encoding='utf-8') as toml_file:
|
||||
toml_file.write(
|
||||
"""
|
||||
[agent]
|
||||
enable_browsing = true
|
||||
system_prompt_filename = "custom_prompt.j2"
|
||||
|
||||
[agent.CodeReviewAgent]
|
||||
system_prompt_filename = "code_review_prompt.j2"
|
||||
enable_browsing = false
|
||||
"""
|
||||
)
|
||||
|
||||
load_from_toml(default_config, temp_toml_file)
|
||||
|
||||
# Check default agent config
|
||||
default_agent_config = default_config.get_agent_config()
|
||||
assert default_agent_config.system_prompt_filename == 'custom_prompt.j2'
|
||||
assert default_agent_config.enable_browsing is True
|
||||
|
||||
# Check custom agent config
|
||||
custom_agent_config = default_config.get_agent_config('CodeReviewAgent')
|
||||
assert custom_agent_config.system_prompt_filename == 'code_review_prompt.j2'
|
||||
assert custom_agent_config.enable_browsing is False
|
||||
|
||||
@ -269,3 +269,39 @@ def test_prompt_manager_initialization_error():
|
||||
"""Test that PromptManager raises an error if the prompt directory is not set."""
|
||||
with pytest.raises(ValueError, match='Prompt directory is not set'):
|
||||
PromptManager(None)
|
||||
|
||||
|
||||
def test_prompt_manager_custom_system_prompt_filename(prompt_dir):
|
||||
"""Test that PromptManager can use a custom system prompt filename."""
|
||||
# Create a custom system prompt file
|
||||
with open(os.path.join(prompt_dir, 'custom_system.j2'), 'w') as f:
|
||||
f.write('Custom system prompt: {{ custom_var }}')
|
||||
|
||||
# Create default system prompt
|
||||
with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f:
|
||||
f.write('Default system prompt')
|
||||
|
||||
# Test with custom system prompt filename
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir, system_prompt_filename='custom_system.j2'
|
||||
)
|
||||
system_msg = manager.get_system_message()
|
||||
assert 'Custom system prompt:' in system_msg
|
||||
|
||||
# Test without custom system prompt filename (should use default)
|
||||
manager_default = PromptManager(prompt_dir=prompt_dir)
|
||||
default_msg = manager_default.get_system_message()
|
||||
assert 'Default system prompt' in default_msg
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'custom_system.j2'))
|
||||
os.remove(os.path.join(prompt_dir, 'system_prompt.j2'))
|
||||
|
||||
|
||||
def test_prompt_manager_custom_system_prompt_filename_not_found(prompt_dir):
|
||||
"""Test that PromptManager raises an error if custom system prompt file is not found."""
|
||||
with pytest.raises(
|
||||
FileNotFoundError,
|
||||
match=r'System prompt file "non_existent\.j2" not found at .*/non_existent\.j2\. Please ensure the file exists in the prompt directory:',
|
||||
):
|
||||
PromptManager(prompt_dir=prompt_dir, system_prompt_filename='non_existent.j2')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user