Feat: Add selected branch param to backend (#6508)

This commit is contained in:
Rohit Malhotra 2025-02-12 15:39:10 -05:00 committed by GitHub
parent ba599c7dd6
commit 312b9fbfb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 38 additions and 5 deletions

View File

@ -229,6 +229,7 @@ class OpenHands {
): Promise<Conversation> {
const body = {
selected_repository: selectedRepository,
selected_branch: undefined,
initial_user_msg: initialUserMsg,
image_urls: imageUrls,
};

View File

@ -249,20 +249,37 @@ class Runtime(FileEditRuntimeMixin):
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
def clone_repo(self, github_token: SecretStr, selected_repository: str) -> str:
def clone_repo(
self,
github_token: SecretStr,
selected_repository: str,
selected_branch: str | None,
) -> str:
if not github_token or not selected_repository:
raise ValueError(
'github_token and selected_repository must be provided to clone a repository'
)
url = f'https://{github_token.get_secret_value()}@github.com/{selected_repository}.git'
dir_name = selected_repository.split('/')[1]
# add random branch name to avoid conflicts
# Generate a random branch name to avoid conflicts
random_str = ''.join(
random.choices(string.ascii_lowercase + string.digits, k=8)
)
branch_name = f'openhands-workspace-{random_str}'
openhands_workspace_branch = f'openhands-workspace-{random_str}'
# Clone repository command
clone_command = f'git clone {url} {dir_name}'
# Checkout to appropriate branch
checkout_command = (
f'git checkout {selected_branch}'
if selected_branch
else f'git checkout -b {openhands_workspace_branch}'
)
action = CmdRunAction(
command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b {branch_name}',
command=f'{clone_command} ; cd {dir_name} ; {checkout_command}',
)
self.log('info', f'Cloning repo: {selected_repository}')
self.run_action(action)

View File

@ -38,6 +38,7 @@ UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id'
class InitSessionRequest(BaseModel):
selected_repository: str | None = None
selected_branch: str | None = None
initial_user_msg: str | None = None
image_urls: list[str] | None = None
@ -46,6 +47,7 @@ async def _create_new_conversation(
user_id: str | None,
token: SecretStr | None,
selected_repository: str | None,
selected_branch: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
):
@ -74,6 +76,7 @@ async def _create_new_conversation(
session_init_args['github_token'] = token or SecretStr('')
session_init_args['selected_repository'] = selected_repository
session_init_args['selected_branch'] = selected_branch
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
@ -135,6 +138,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
github_token = await gh_client.get_latest_token()
selected_repository = data.selected_repository
selected_branch = data.selected_branch
initial_user_msg = data.initial_user_msg
image_urls = data.image_urls or []
@ -144,6 +148,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
user_id,
github_token,
selected_repository,
selected_branch,
initial_user_msg,
image_urls,
)

View File

@ -76,6 +76,7 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
github_token: SecretStr | None = None,
selected_repository: str | None = None,
selected_branch: str | None = None,
initial_message: MessageAction | None = None,
):
"""Starts the Agent session
@ -105,6 +106,7 @@ class AgentSession:
agent=agent,
github_token=github_token,
selected_repository=selected_repository,
selected_branch=selected_branch,
)
self.controller = self._create_controller(
@ -184,6 +186,7 @@ class AgentSession:
agent: Agent,
github_token: SecretStr | None = None,
selected_repository: str | None = None,
selected_branch: str | None = None,
):
"""Creates a runtime instance
@ -239,7 +242,10 @@ class AgentSession:
repo_directory = None
if selected_repository:
repo_directory = await call_sync_from_async(
self.runtime.clone_repo, github_token, selected_repository
self.runtime.clone_repo,
github_token,
selected_repository,
selected_branch,
)
if agent.prompt_manager:

View File

@ -10,3 +10,4 @@ class ConversationInitData(Settings):
github_token: SecretStr | None = Field(default=None)
selected_repository: str | None = Field(default=None)
selected_branch: str | None = Field(default=None)

View File

@ -123,9 +123,11 @@ class Session:
github_token = None
selected_repository = None
selected_branch = None
if isinstance(settings, ConversationInitData):
github_token = settings.github_token
selected_repository = settings.selected_repository
selected_branch = settings.selected_branch
try:
await self.agent_session.start(
@ -138,6 +140,7 @@ class Session:
agent_configs=self.config.get_agent_configs(),
github_token=github_token,
selected_repository=selected_repository,
selected_branch=selected_branch,
initial_message=initial_message,
)
except Exception as e: