mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
fix(backend): show name of created branch in conversation list. (#10208)
This commit is contained in:
parent
4849369ede
commit
18b5139237
@ -65,6 +65,24 @@ class BitBucketService(BaseGitService, GitService, InstallationsService):
|
||||
def provider(self) -> str:
|
||||
return ProviderType.BITBUCKET.value
|
||||
|
||||
def _extract_owner_and_repo(self, repository: str) -> tuple[str, str]:
|
||||
"""Extract owner and repo from repository string.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format 'workspace/repo_slug'
|
||||
|
||||
Returns:
|
||||
Tuple of (owner, repo)
|
||||
|
||||
Raises:
|
||||
ValueError: If repository format is invalid
|
||||
"""
|
||||
parts = repository.split('/')
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f'Invalid repository name: {repository}')
|
||||
|
||||
return parts[-2], parts[-1]
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
"""Get latest working token of the user."""
|
||||
return self.token
|
||||
@ -495,13 +513,7 @@ class BitBucketService(BaseGitService, GitService, InstallationsService):
|
||||
self, repository: str
|
||||
) -> Repository:
|
||||
"""Gets all repository details from repository name."""
|
||||
# Extract owner and repo from the repository string (e.g., "owner/repo")
|
||||
parts = repository.split('/')
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f'Invalid repository name: {repository}')
|
||||
|
||||
owner = parts[-2]
|
||||
repo = parts[-1]
|
||||
owner, repo = self._extract_owner_and_repo(repository)
|
||||
|
||||
url = f'{self.BASE_URL}/repositories/{owner}/{repo}'
|
||||
data, _ = await self._make_request(url)
|
||||
@ -510,13 +522,7 @@ class BitBucketService(BaseGitService, GitService, InstallationsService):
|
||||
|
||||
async def get_branches(self, repository: str) -> list[Branch]:
|
||||
"""Get branches for a repository."""
|
||||
# Extract owner and repo from the repository string (e.g., "owner/repo")
|
||||
parts = repository.split('/')
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f'Invalid repository name: {repository}')
|
||||
|
||||
owner = parts[-2]
|
||||
repo = parts[-1]
|
||||
owner, repo = self._extract_owner_and_repo(repository)
|
||||
|
||||
url = f'{self.BASE_URL}/repositories/{owner}/{repo}/refs/branches'
|
||||
|
||||
@ -567,13 +573,7 @@ class BitBucketService(BaseGitService, GitService, InstallationsService):
|
||||
Returns:
|
||||
The URL of the created pull request
|
||||
"""
|
||||
# Extract owner and repo from the repository string (e.g., "owner/repo")
|
||||
parts = repo_name.split('/')
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f'Invalid repository name: {repo_name}')
|
||||
|
||||
owner = parts[-2]
|
||||
repo = parts[-1]
|
||||
owner, repo = self._extract_owner_and_repo(repo_name)
|
||||
|
||||
url = f'{self.BASE_URL}/repositories/{owner}/{repo}/pullrequests'
|
||||
|
||||
|
||||
@ -1142,6 +1142,27 @@ fi
|
||||
self.git_handler.set_cwd(cwd)
|
||||
return self.git_handler.get_git_diff(file_path)
|
||||
|
||||
def get_workspace_branch(self, primary_repo_path: str | None = None) -> str | None:
|
||||
"""
|
||||
Get the current branch of the workspace.
|
||||
|
||||
Args:
|
||||
primary_repo_path: Path to the primary repository within the workspace.
|
||||
If None, uses the workspace root.
|
||||
|
||||
Returns:
|
||||
str | None: The current branch name, or None if not a git repository or error occurs.
|
||||
"""
|
||||
if primary_repo_path:
|
||||
# Use the primary repository path
|
||||
git_cwd = str(self.workspace_root / primary_repo_path)
|
||||
else:
|
||||
# Use the workspace root
|
||||
git_cwd = str(self.workspace_root)
|
||||
|
||||
self.git_handler.set_cwd(git_cwd)
|
||||
return self.git_handler.get_current_branch()
|
||||
|
||||
@property
|
||||
def additional_agent_instructions(self) -> str:
|
||||
return ''
|
||||
|
||||
@ -10,6 +10,7 @@ GIT_CHANGES_CMD = 'python3 /openhands/code/openhands/runtime/utils/git_changes.p
|
||||
GIT_DIFF_CMD = (
|
||||
'python3 /openhands/code/openhands/runtime/utils/git_diff.py "{file_path}"'
|
||||
)
|
||||
GIT_BRANCH_CMD = 'git branch --show-current'
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -38,6 +39,7 @@ class GitHandler:
|
||||
self.cwd: str | None = None
|
||||
self.git_changes_cmd = GIT_CHANGES_CMD
|
||||
self.git_diff_cmd = GIT_DIFF_CMD
|
||||
self.git_branch_cmd = GIT_BRANCH_CMD
|
||||
|
||||
def set_cwd(self, cwd: str) -> None:
|
||||
"""Sets the current working directory for Git operations.
|
||||
@ -55,6 +57,28 @@ class GitHandler:
|
||||
result = self.execute(f'chmod +x "{script_file}"', self.cwd)
|
||||
return script_file
|
||||
|
||||
def get_current_branch(self) -> str | None:
|
||||
"""
|
||||
Retrieves the current branch name of the git repository.
|
||||
|
||||
Returns:
|
||||
str | None: The current branch name, or None if not a git repository or error occurs.
|
||||
"""
|
||||
# If cwd is not set, return None
|
||||
if not self.cwd:
|
||||
return None
|
||||
|
||||
result = self.execute(self.git_branch_cmd, self.cwd)
|
||||
if result.exit_code == 0:
|
||||
branch = result.content.strip()
|
||||
# git branch --show-current returns empty string if not on any branch (detached HEAD)
|
||||
if branch:
|
||||
return branch
|
||||
return None
|
||||
|
||||
# If not a git repository or other error, return None
|
||||
return None
|
||||
|
||||
def get_git_changes(self) -> list[dict[str, str]] | None:
|
||||
"""Retrieves the list of changed files in Git repositories.
|
||||
Examines each direct subdirectory of the workspace directory looking for git repositories
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Iterable
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import socketio
|
||||
|
||||
@ -11,7 +11,9 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.core.schema.observation import ObservationType
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.observation.commands import CmdOutputObservation
|
||||
from openhands.events.stream import EventStreamSubscriber, session_exists
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime import get_runtime_cls
|
||||
@ -516,6 +518,18 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation.total_tokens = (
|
||||
token_usage.prompt_tokens + token_usage.completion_tokens
|
||||
)
|
||||
|
||||
# Check for branch changes if this is a git-related event
|
||||
if event and self._is_git_related_event(event):
|
||||
logger.info(
|
||||
f'Git-related event detected, updating conversation branch for {conversation_id}',
|
||||
extra={
|
||||
'session_id': conversation_id,
|
||||
'command': getattr(event, 'command', 'unknown'),
|
||||
},
|
||||
)
|
||||
await self._update_conversation_branch(conversation)
|
||||
|
||||
default_title = get_default_conversation_title(conversation_id)
|
||||
if (
|
||||
conversation.title == default_title
|
||||
@ -548,6 +562,154 @@ class StandaloneConversationManager(ConversationManager):
|
||||
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
def _is_git_related_event(self, event) -> bool:
|
||||
"""
|
||||
Determine if an event is related to git operations that could change the branch.
|
||||
|
||||
Args:
|
||||
event: The event to check
|
||||
|
||||
Returns:
|
||||
True if the event is git-related and could change the branch, False otherwise
|
||||
"""
|
||||
# Early return if event is None or not the correct type
|
||||
if not event or not isinstance(event, CmdOutputObservation):
|
||||
return False
|
||||
|
||||
# Check CmdOutputObservation for git commands that change branches
|
||||
# We check the observation result, not the action request, to ensure the command actually succeeded
|
||||
if (
|
||||
event.observation == ObservationType.RUN
|
||||
and event.metadata.exit_code == 0 # Only consider successful commands
|
||||
):
|
||||
command = event.command.lower()
|
||||
|
||||
# Check if any git command that changes branches is present anywhere in the command
|
||||
# This handles compound commands like "cd workspace && git checkout feature-branch"
|
||||
git_commands = [
|
||||
'git checkout',
|
||||
'git switch',
|
||||
'git merge',
|
||||
'git rebase',
|
||||
'git reset',
|
||||
'git branch',
|
||||
]
|
||||
|
||||
is_git_related = any(git_cmd in command for git_cmd in git_commands)
|
||||
|
||||
if is_git_related:
|
||||
logger.debug(
|
||||
f'Detected git-related command: {command} with exit code {event.metadata.exit_code}',
|
||||
extra={'command': command, 'exit_code': event.metadata.exit_code},
|
||||
)
|
||||
|
||||
return is_git_related
|
||||
|
||||
return False
|
||||
|
||||
async def _update_conversation_branch(self, conversation: ConversationMetadata):
|
||||
"""
|
||||
Update the conversation's current branch if it has changed.
|
||||
|
||||
Args:
|
||||
conversation: The conversation metadata to update
|
||||
"""
|
||||
try:
|
||||
# Get the session and runtime for this conversation
|
||||
session, runtime = self._get_session_and_runtime(
|
||||
conversation.conversation_id
|
||||
)
|
||||
if not session or not runtime:
|
||||
return
|
||||
|
||||
# Get the current branch from the workspace
|
||||
current_branch = self._get_current_workspace_branch(
|
||||
runtime, conversation.selected_repository
|
||||
)
|
||||
|
||||
# Update branch if it has changed
|
||||
if self._should_update_branch(conversation.selected_branch, current_branch):
|
||||
self._update_branch_in_conversation(conversation, current_branch)
|
||||
|
||||
except Exception as e:
|
||||
# Log an error that occurred during branch update
|
||||
logger.warning(
|
||||
f'Failed to update conversation branch: {e}',
|
||||
extra={'session_id': conversation.conversation_id},
|
||||
)
|
||||
|
||||
def _get_session_and_runtime(
|
||||
self, conversation_id: str
|
||||
) -> tuple[Session | None, Any | None]:
|
||||
"""
|
||||
Get the session and runtime for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID
|
||||
|
||||
Returns:
|
||||
Tuple of (session, runtime) or (None, None) if not found
|
||||
"""
|
||||
session = self._local_agent_loops_by_sid.get(conversation_id)
|
||||
if not session or not session.agent_session.runtime:
|
||||
return None, None
|
||||
return session, session.agent_session.runtime
|
||||
|
||||
def _get_current_workspace_branch(
|
||||
self, runtime: Any, selected_repository: str | None
|
||||
) -> str | None:
|
||||
"""
|
||||
Get the current branch from the workspace.
|
||||
|
||||
Args:
|
||||
runtime: The runtime instance
|
||||
selected_repository: The selected repository path or None
|
||||
|
||||
Returns:
|
||||
The current branch name or None if not found
|
||||
"""
|
||||
# Extract the repository name from the full repository path
|
||||
if not selected_repository:
|
||||
primary_repo_path = None
|
||||
else:
|
||||
# Extract the repository name from the full path (e.g., "org/repo" -> "repo")
|
||||
primary_repo_path = selected_repository.split('/')[-1]
|
||||
|
||||
return runtime.get_workspace_branch(primary_repo_path)
|
||||
|
||||
def _should_update_branch(
|
||||
self, current_branch: str | None, new_branch: str | None
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if the branch should be updated.
|
||||
|
||||
Args:
|
||||
current_branch: The current branch in conversation metadata
|
||||
new_branch: The new branch from the workspace
|
||||
|
||||
Returns:
|
||||
True if the branch should be updated, False otherwise
|
||||
"""
|
||||
return new_branch is not None and new_branch != current_branch
|
||||
|
||||
def _update_branch_in_conversation(
|
||||
self, conversation: ConversationMetadata, new_branch: str | None
|
||||
):
|
||||
"""
|
||||
Update the branch in the conversation metadata.
|
||||
|
||||
Args:
|
||||
conversation: The conversation metadata to update
|
||||
new_branch: The new branch name
|
||||
"""
|
||||
old_branch = conversation.selected_branch
|
||||
conversation.selected_branch = new_branch
|
||||
|
||||
logger.info(
|
||||
f'Branch changed from {old_branch} to {new_branch}',
|
||||
extra={'session_id': conversation.conversation_id},
|
||||
)
|
||||
|
||||
async def get_agent_loop_info(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user