Add branch information to repository context to prevent unwanted branch switching (#9833)

This commit is contained in:
Rohit Malhotra 2025-08-01 00:25:36 -04:00 committed by GitHub
parent 1cdc38eafb
commit 287c34b3f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 44 additions and 7 deletions

View File

@ -38,7 +38,7 @@ export function ActionSuggestions({
pr,
prShort,
pushToBranch: `Please push the changes to a remote branch on ${getProviderName()}, but do NOT create a ${pr}. Check your current branch name first - if it's main, master, deploy, or another common default branch name, create a new branch with a descriptive name related to your changes. Otherwise, use the exact SAME branch name as the one you are currently on.`,
createPR: `Please push the changes to ${getProviderName()} and open a ${pr}. Please create a meaningful branch name that describes the changes. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`,
createPR: `Please push the changes to ${getProviderName()} and open a ${pr}. If you're on a default branch (e.g., main, master, deploy), create a new branch with a descriptive name otherwise use the current branch. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`,
pushToPR: `Please push the latest changes to the existing ${pr}.`,
};

View File

@ -1,6 +1,12 @@
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to {{ repository_info.repo_directory }} in the current working directory.
{% if repository_info.branch_name %}The repository has been checked out to branch "{{ repository_info.branch_name }}".
IMPORTANT: You should work within the current branch "{{ repository_info.branch_name }}" unless
1. the user explicitly instructs otherwise
2. if the current branch is "main", "master", or another default branch where direct pushes may be unsafe
{% endif %}
</REPOSITORY_INFO>
{% endif %}
{% if repository_instructions -%}

View File

@ -1,6 +1,12 @@
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}.
{% if repository_info.branch_name %}The repository has been checked out to branch "{{ repository_info.branch_name }}".
IMPORTANT: You should work within the current branch "{{ repository_info.branch_name }}" unless
1. the user explicitly instructs otherwise
2. if the current branch is "main", "master", or another default branch where direct pushes may be unsafe
{% endif %}
</REPOSITORY_INFO>
{% endif %}
{% if repository_instructions -%}

View File

@ -70,6 +70,7 @@ class RecallObservation(Observation):
# workspace context
repo_name: str = ''
repo_directory: str = ''
repo_branch: str = ''
repo_instructions: str = ''
runtime_hosts: dict[str, int] = field(default_factory=dict)
additional_agent_instructions: str = ''

View File

@ -512,6 +512,7 @@ class ConversationMemory:
repo_info = RepositoryInfo(
repo_name=obs.repo_name or '',
repo_directory=obs.repo_directory or '',
branch_name=obs.repo_branch or None,
)
else:
repo_info = None

View File

@ -181,6 +181,9 @@ class Memory:
if self.repository_info
and self.repository_info.repo_directory is not None
else '',
repo_branch=self.repository_info.branch_name
if self.repository_info and self.repository_info.branch_name is not None
else '',
repo_instructions=repo_instructions if repo_instructions else '',
runtime_hosts=self.runtime_info.available_hosts
if self.runtime_info and self.runtime_info.available_hosts is not None
@ -322,10 +325,14 @@ class Memory:
return mcp_configs
def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
def set_repository_info(
self, repo_name: str, repo_directory: str, branch_name: str | None = None
) -> None:
"""Store repository info so we can reference it in an observation."""
if repo_name or repo_directory:
self.repository_info = RepositoryInfo(repo_name, repo_directory)
self.repository_info = RepositoryInfo(
repo_name, repo_directory, branch_name
)
else:
self.repository_info = None

View File

@ -152,6 +152,7 @@ class AgentSession:
self.memory = await self._create_memory(
selected_repository=selected_repository,
repo_directory=repo_directory,
selected_branch=selected_branch,
conversation_instructions=conversation_instructions,
custom_secrets_descriptions=custom_secrets_handler.get_custom_secrets_descriptions(),
working_dir=config.workspace_mount_path_in_sandbox,
@ -463,6 +464,7 @@ class AgentSession:
self,
selected_repository: str | None,
repo_directory: str | None,
selected_branch: str | None,
conversation_instructions: str | None,
custom_secrets_descriptions: dict[str, str],
working_dir: str,
@ -488,7 +490,9 @@ class AgentSession:
memory.load_user_workspace_microagents(microagents)
if selected_repository and repo_directory:
memory.set_repository_info(selected_repository, repo_directory)
memory.set_repository_info(
selected_repository, repo_directory, selected_branch
)
return memory
def _maybe_restore_state(self) -> State | None:

View File

@ -24,6 +24,7 @@ class RepositoryInfo:
repo_name: str | None = None
repo_directory: str | None = None
branch_name: str | None = None
@dataclass

View File

@ -453,7 +453,9 @@ def test_custom_secrets_descriptions_serialization(prompt_dir):
# Create a RepositoryInfo
repository_info = RepositoryInfo(
repo_name='test-owner/test-repo', repo_directory='/workspace/test-repo'
repo_name='test-owner/test-repo',
repo_directory='/workspace/test-repo',
branch_name='main',
)
conversation_instructions = ConversationInstructions(

View File

@ -296,6 +296,7 @@ def test_microagent_observation_serialization():
'recall_type': 'workspace_context',
'repo_name': 'some_repo_name',
'repo_directory': 'some_repo_directory',
'repo_branch': '',
'working_dir': '',
'runtime_hosts': {'host1': 8080, 'host2': 8081},
'repo_instructions': 'complex_repo_instructions',
@ -318,6 +319,7 @@ def test_microagent_observation_microagent_knowledge_serialization():
'recall_type': 'knowledge',
'repo_name': '',
'repo_directory': '',
'repo_branch': '',
'repo_instructions': '',
'runtime_hosts': {},
'working_dir': '',
@ -348,6 +350,7 @@ def test_microagent_observation_knowledge_microagent_serialization():
original = RecallObservation(
content='Knowledge microagent information',
recall_type=RecallType.KNOWLEDGE,
repo_branch='',
microagent_knowledge=[
MicroagentKnowledge(
name='python_best_practices',
@ -395,6 +398,7 @@ def test_microagent_observation_environment_serialization():
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='OpenHands',
repo_directory='/workspace/openhands',
repo_branch='main',
repo_instructions="Follow the project's coding style guide.",
runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000},
additional_agent_instructions='You know it all about this runtime',
@ -444,6 +448,7 @@ def test_microagent_observation_combined_serialization():
# Environment info
repo_name='OpenHands',
repo_directory='/workspace/openhands',
repo_branch='main',
repo_instructions="Follow the project's coding style guide.",
runtime_hosts={'127.0.0.1': 8080},
additional_agent_instructions='You know it all about this runtime',

View File

@ -50,7 +50,9 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone
# Test with GitHub repo
manager = PromptManager(prompt_dir=prompt_dir)
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
repo_info = RepositoryInfo(
repo_name='owner/repo', repo_directory='/workspace/repo', branch_name='main'
)
# verify its parts are rendered
system_msg = manager.get_system_message()
@ -231,7 +233,9 @@ Today's date is {{ runtime_info.date }}
manager = PromptManager(prompt_dir=prompt_dir)
# Create repository and runtime information
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
repo_info = RepositoryInfo(
repo_name='owner/repo', repo_directory='/workspace/repo', branch_name='main'
)
runtime_info = RuntimeInfo(
date='02/12/1232',
available_hosts={'example.com': 8080},