mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add branch information to repository context to prevent unwanted branch switching (#9833)
This commit is contained in:
parent
1cdc38eafb
commit
287c34b3f3
@ -38,7 +38,7 @@ export function ActionSuggestions({
|
||||
pr,
|
||||
prShort,
|
||||
pushToBranch: `Please push the changes to a remote branch on ${getProviderName()}, but do NOT create a ${pr}. Check your current branch name first - if it's main, master, deploy, or another common default branch name, create a new branch with a descriptive name related to your changes. Otherwise, use the exact SAME branch name as the one you are currently on.`,
|
||||
createPR: `Please push the changes to ${getProviderName()} and open a ${pr}. Please create a meaningful branch name that describes the changes. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`,
|
||||
createPR: `Please push the changes to ${getProviderName()} and open a ${pr}. If you're on a default branch (e.g., main, master, deploy), create a new branch with a descriptive name otherwise use the current branch. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`,
|
||||
pushToPR: `Please push the latest changes to the existing ${pr}.`,
|
||||
};
|
||||
|
||||
|
||||
@ -1,6 +1,12 @@
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to {{ repository_info.repo_directory }} in the current working directory.
|
||||
{% if repository_info.branch_name %}The repository has been checked out to branch "{{ repository_info.branch_name }}".
|
||||
|
||||
IMPORTANT: You should work within the current branch "{{ repository_info.branch_name }}" unless
|
||||
1. the user explicitly instructs otherwise
|
||||
2. if the current branch is "main", "master", or another default branch where direct pushes may be unsafe
|
||||
{% endif %}
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
{% if repository_instructions -%}
|
||||
|
||||
@ -1,6 +1,12 @@
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}.
|
||||
{% if repository_info.branch_name %}The repository has been checked out to branch "{{ repository_info.branch_name }}".
|
||||
|
||||
IMPORTANT: You should work within the current branch "{{ repository_info.branch_name }}" unless
|
||||
1. the user explicitly instructs otherwise
|
||||
2. if the current branch is "main", "master", or another default branch where direct pushes may be unsafe
|
||||
{% endif %}
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
{% if repository_instructions -%}
|
||||
|
||||
@ -70,6 +70,7 @@ class RecallObservation(Observation):
|
||||
# workspace context
|
||||
repo_name: str = ''
|
||||
repo_directory: str = ''
|
||||
repo_branch: str = ''
|
||||
repo_instructions: str = ''
|
||||
runtime_hosts: dict[str, int] = field(default_factory=dict)
|
||||
additional_agent_instructions: str = ''
|
||||
|
||||
@ -512,6 +512,7 @@ class ConversationMemory:
|
||||
repo_info = RepositoryInfo(
|
||||
repo_name=obs.repo_name or '',
|
||||
repo_directory=obs.repo_directory or '',
|
||||
branch_name=obs.repo_branch or None,
|
||||
)
|
||||
else:
|
||||
repo_info = None
|
||||
|
||||
@ -181,6 +181,9 @@ class Memory:
|
||||
if self.repository_info
|
||||
and self.repository_info.repo_directory is not None
|
||||
else '',
|
||||
repo_branch=self.repository_info.branch_name
|
||||
if self.repository_info and self.repository_info.branch_name is not None
|
||||
else '',
|
||||
repo_instructions=repo_instructions if repo_instructions else '',
|
||||
runtime_hosts=self.runtime_info.available_hosts
|
||||
if self.runtime_info and self.runtime_info.available_hosts is not None
|
||||
@ -322,10 +325,14 @@ class Memory:
|
||||
|
||||
return mcp_configs
|
||||
|
||||
def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
|
||||
def set_repository_info(
|
||||
self, repo_name: str, repo_directory: str, branch_name: str | None = None
|
||||
) -> None:
|
||||
"""Store repository info so we can reference it in an observation."""
|
||||
if repo_name or repo_directory:
|
||||
self.repository_info = RepositoryInfo(repo_name, repo_directory)
|
||||
self.repository_info = RepositoryInfo(
|
||||
repo_name, repo_directory, branch_name
|
||||
)
|
||||
else:
|
||||
self.repository_info = None
|
||||
|
||||
|
||||
@ -152,6 +152,7 @@ class AgentSession:
|
||||
self.memory = await self._create_memory(
|
||||
selected_repository=selected_repository,
|
||||
repo_directory=repo_directory,
|
||||
selected_branch=selected_branch,
|
||||
conversation_instructions=conversation_instructions,
|
||||
custom_secrets_descriptions=custom_secrets_handler.get_custom_secrets_descriptions(),
|
||||
working_dir=config.workspace_mount_path_in_sandbox,
|
||||
@ -463,6 +464,7 @@ class AgentSession:
|
||||
self,
|
||||
selected_repository: str | None,
|
||||
repo_directory: str | None,
|
||||
selected_branch: str | None,
|
||||
conversation_instructions: str | None,
|
||||
custom_secrets_descriptions: dict[str, str],
|
||||
working_dir: str,
|
||||
@ -488,7 +490,9 @@ class AgentSession:
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
|
||||
if selected_repository and repo_directory:
|
||||
memory.set_repository_info(selected_repository, repo_directory)
|
||||
memory.set_repository_info(
|
||||
selected_repository, repo_directory, selected_branch
|
||||
)
|
||||
return memory
|
||||
|
||||
def _maybe_restore_state(self) -> State | None:
|
||||
|
||||
@ -24,6 +24,7 @@ class RepositoryInfo:
|
||||
|
||||
repo_name: str | None = None
|
||||
repo_directory: str | None = None
|
||||
branch_name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -453,7 +453,9 @@ def test_custom_secrets_descriptions_serialization(prompt_dir):
|
||||
|
||||
# Create a RepositoryInfo
|
||||
repository_info = RepositoryInfo(
|
||||
repo_name='test-owner/test-repo', repo_directory='/workspace/test-repo'
|
||||
repo_name='test-owner/test-repo',
|
||||
repo_directory='/workspace/test-repo',
|
||||
branch_name='main',
|
||||
)
|
||||
|
||||
conversation_instructions = ConversationInstructions(
|
||||
|
||||
@ -296,6 +296,7 @@ def test_microagent_observation_serialization():
|
||||
'recall_type': 'workspace_context',
|
||||
'repo_name': 'some_repo_name',
|
||||
'repo_directory': 'some_repo_directory',
|
||||
'repo_branch': '',
|
||||
'working_dir': '',
|
||||
'runtime_hosts': {'host1': 8080, 'host2': 8081},
|
||||
'repo_instructions': 'complex_repo_instructions',
|
||||
@ -318,6 +319,7 @@ def test_microagent_observation_microagent_knowledge_serialization():
|
||||
'recall_type': 'knowledge',
|
||||
'repo_name': '',
|
||||
'repo_directory': '',
|
||||
'repo_branch': '',
|
||||
'repo_instructions': '',
|
||||
'runtime_hosts': {},
|
||||
'working_dir': '',
|
||||
@ -348,6 +350,7 @@ def test_microagent_observation_knowledge_microagent_serialization():
|
||||
original = RecallObservation(
|
||||
content='Knowledge microagent information',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
repo_branch='',
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='python_best_practices',
|
||||
@ -395,6 +398,7 @@ def test_microagent_observation_environment_serialization():
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='OpenHands',
|
||||
repo_directory='/workspace/openhands',
|
||||
repo_branch='main',
|
||||
repo_instructions="Follow the project's coding style guide.",
|
||||
runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000},
|
||||
additional_agent_instructions='You know it all about this runtime',
|
||||
@ -444,6 +448,7 @@ def test_microagent_observation_combined_serialization():
|
||||
# Environment info
|
||||
repo_name='OpenHands',
|
||||
repo_directory='/workspace/openhands',
|
||||
repo_branch='main',
|
||||
repo_instructions="Follow the project's coding style guide.",
|
||||
runtime_hosts={'127.0.0.1': 8080},
|
||||
additional_agent_instructions='You know it all about this runtime',
|
||||
|
||||
@ -50,7 +50,9 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone
|
||||
|
||||
# Test with GitHub repo
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
|
||||
repo_info = RepositoryInfo(
|
||||
repo_name='owner/repo', repo_directory='/workspace/repo', branch_name='main'
|
||||
)
|
||||
|
||||
# verify its parts are rendered
|
||||
system_msg = manager.get_system_message()
|
||||
@ -231,7 +233,9 @@ Today's date is {{ runtime_info.date }}
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Create repository and runtime information
|
||||
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
|
||||
repo_info = RepositoryInfo(
|
||||
repo_name='owner/repo', repo_directory='/workspace/repo', branch_name='main'
|
||||
)
|
||||
runtime_info = RuntimeInfo(
|
||||
date='02/12/1232',
|
||||
available_hosts={'example.com': 8080},
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user