feat(backend): enable sub-conversation creation using a different agent (#11715)

This commit is contained in:
Hiep Le 2025-11-13 23:06:44 +07:00 committed by GitHub
parent d5b2d2ebc5
commit bc86796a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 241 additions and 3 deletions

View File

@ -0,0 +1,41 @@
"""add parent_conversation_id to conversation_metadata
Revision ID: 081
Revises: 080
Create Date: 2025-11-06 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '081'
down_revision: Union[str, None] = '080'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('parent_conversation_id', sa.String(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_parent_conversation_id'),
'conversation_metadata',
['parent_conversation_id'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_parent_conversation_id'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'parent_conversation_id')

View File

@ -60,6 +60,7 @@ class SaasConversationStore(ConversationStore):
kwargs.pop('reasoning_tokens', None)
kwargs.pop('context_window', None)
kwargs.pop('per_turn_token', None)
kwargs.pop('parent_conversation_id', None)
return ConversationMetadata(**kwargs)

View File

@ -16,6 +16,13 @@ from openhands.sdk.llm import MetricsSnapshot
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
class AgentType(Enum):
"""Agent type for conversation."""
DEFAULT = 'default'
PLAN = 'plan'
class AppConversationInfo(BaseModel):
"""Conversation info which does not contain status."""
@ -34,6 +41,8 @@ class AppConversationInfo(BaseModel):
metrics: MetricsSnapshot | None = None
parent_conversation_id: OpenHandsUUID | None = None
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
@ -98,6 +107,8 @@ class AppConversationStartRequest(BaseModel):
title: str | None = None
trigger: ConversationTrigger | None = None
pr_number: list[int] = Field(default_factory=list)
parent_conversation_id: OpenHandsUUID | None = None
agent_type: AgentType = Field(default=AgentType.DEFAULT)
class AppConversationStartTaskStatus(Enum):

View File

@ -21,6 +21,7 @@ from openhands.app_server.app_conversation.app_conversation_info_service import
AppConversationInfoService,
)
from openhands.app_server.app_conversation.app_conversation_models import (
AgentType,
AppConversation,
AppConversationInfo,
AppConversationPage,
@ -70,6 +71,7 @@ from openhands.sdk.llm import LLM
from openhands.sdk.security.confirmation_policy import AlwaysConfirm
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
from openhands.tools.preset.default import get_default_agent
from openhands.tools.preset.planning import get_planning_agent
_conversation_info_type_adapter = TypeAdapter(list[ConversationInfo | None])
_logger = logging.getLogger(__name__)
@ -168,6 +170,20 @@ class LiveStatusAppConversationService(GitAppConversationService):
) -> AsyncGenerator[AppConversationStartTask, None]:
# Create and yield the start task
user_id = await self.user_context.get_user_id()
# Validate and inherit from parent conversation if provided
if request.parent_conversation_id:
parent_info = (
await self.app_conversation_info_service.get_app_conversation_info(
request.parent_conversation_id
)
)
if parent_info is None:
raise ValueError(
f'Parent conversation not found: {request.parent_conversation_id}'
)
self._inherit_configuration_from_parent(request, parent_info)
task = AppConversationStartTask(
created_by_user_id=user_id,
request=request,
@ -206,6 +222,8 @@ class LiveStatusAppConversationService(GitAppConversationService):
request.initial_message,
request.git_provider,
sandbox_spec.working_dir,
request.agent_type,
request.llm_model,
)
)
@ -224,6 +242,7 @@ class LiveStatusAppConversationService(GitAppConversationService):
headers={'X-Session-API-Key': sandbox.session_api_key},
timeout=self.sandbox_startup_timeout,
)
response.raise_for_status()
info = ConversationInfo.model_validate(response.json())
@ -241,6 +260,7 @@ class LiveStatusAppConversationService(GitAppConversationService):
git_provider=request.git_provider,
trigger=request.trigger,
pr_number=request.pr_number,
parent_conversation_id=request.parent_conversation_id,
)
await self.app_conversation_info_service.save_app_conversation_info(
app_conversation_info
@ -452,11 +472,43 @@ class LiveStatusAppConversationService(GitAppConversationService):
)
return agent_server_url
def _inherit_configuration_from_parent(
self, request: AppConversationStartRequest, parent_info: AppConversationInfo
) -> None:
"""Inherit configuration from parent conversation if not explicitly provided.
This ensures sub-conversations automatically inherit:
- Sandbox ID (to share the same workspace/environment)
- Git parameters (repository, branch, provider)
- LLM model
Args:
request: The conversation start request to modify
parent_info: The parent conversation info to inherit from
"""
# Inherit sandbox_id from parent to share the same workspace/environment
if not request.sandbox_id:
request.sandbox_id = parent_info.sandbox_id
# Inherit git parameters from parent if not provided
if not request.selected_repository:
request.selected_repository = parent_info.selected_repository
if not request.selected_branch:
request.selected_branch = parent_info.selected_branch
if not request.git_provider:
request.git_provider = parent_info.git_provider
# Inherit LLM model from parent if not provided
if not request.llm_model and parent_info.llm_model:
request.llm_model = parent_info.llm_model
async def _build_start_conversation_request_for_user(
self,
initial_message: SendMessageRequest | None,
git_provider: ProviderType | None,
working_dir: str,
agent_type: AgentType = AgentType.DEFAULT,
llm_model: str | None = None,
) -> StartConversationRequest:
user = await self.user_context.get_user_info()
@ -488,13 +540,19 @@ class LiveStatusAppConversationService(GitAppConversationService):
workspace = LocalWorkspace(working_dir=working_dir)
# Use provided llm_model if available, otherwise fall back to user's default
model = llm_model or user.llm_model
llm = LLM(
model=user.llm_model,
model=model,
base_url=user.llm_base_url,
api_key=user.llm_api_key,
usage_id='agent',
)
agent = get_default_agent(llm=llm)
# Select agent based on agent_type
if agent_type == AgentType.PLAN:
agent = get_planning_agent(llm=llm)
else:
agent = get_default_agent(llm=llm)
conversation_id = uuid4()
agent = ExperimentManagerImpl.run_agent_variant_tests__v1(

View File

@ -88,6 +88,7 @@ class StoredConversationMetadata(Base): # type: ignore
conversation_version = Column(String, nullable=False, default='V0', index=True)
sandbox_id = Column(String, nullable=True, index=True)
parent_conversation_id = Column(String, nullable=True, index=True)
@dataclass
@ -307,6 +308,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
llm_model=info.llm_model,
conversation_version='V1',
sandbox_id=info.sandbox_id,
parent_conversation_id=(
str(info.parent_conversation_id)
if info.parent_conversation_id
else None
),
)
await self.db_session.merge(stored)
@ -364,6 +370,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
pr_number=stored.pr_number,
llm_model=stored.llm_model,
metrics=metrics,
parent_conversation_id=(
UUID(stored.parent_conversation_id)
if stored.parent_conversation_id
else None
),
created_at=created_at,
updated_at=updated_at,
)

View File

@ -0,0 +1,41 @@
"""add parent_conversation_id to conversation_metadata
Revision ID: 003
Revises: 002
Create Date: 2025-11-06 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '003'
down_revision: Union[str, None] = '002'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('parent_conversation_id', sa.String(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_parent_conversation_id'),
'conversation_metadata',
['parent_conversation_id'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_parent_conversation_id'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'parent_conversation_id')

View File

@ -11,7 +11,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
# The version of the agent server to use for deployments.
# Typically this will be the same as the values from the pyproject.toml
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:f3c0c19-python'
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:4e2ecd8-python'
class SandboxSpecService(ABC):

View File

@ -11,7 +11,14 @@ from openhands.app_server.app_conversation.app_conversation_info_service import
AppConversationInfoService,
)
from openhands.app_server.app_conversation.app_conversation_models import (
AgentType,
AppConversationInfo,
AppConversationStartRequest,
AppConversationStartTask,
AppConversationStartTaskStatus,
)
from openhands.app_server.app_conversation.app_conversation_service import (
AppConversationService,
)
from openhands.microagent.microagent import KnowledgeMicroagent, RepoMicroagent
from openhands.microagent.types import MicroagentMetadata, MicroagentType
@ -1125,3 +1132,71 @@ async def test_add_message_empty_message():
call_args = mock_manager.send_event_to_conversation.call_args
message_data = call_args[0][1]
assert message_data['args']['content'] == ''
@pytest.mark.sub_conversation
@pytest.mark.asyncio
async def test_create_sub_conversation_with_planning_agent():
"""Test creating a sub-conversation from a parent conversation with planning agent."""
from uuid import uuid4
parent_conversation_id = uuid4()
user_id = 'test_user_456'
sandbox_id = 'test_sandbox_123'
# Create mock parent conversation info
parent_info = AppConversationInfo(
id=parent_conversation_id,
created_by_user_id=user_id,
sandbox_id=sandbox_id,
selected_repository='test/repo',
selected_branch='main',
git_provider=None,
title='Parent Conversation',
llm_model='anthropic/claude-3-5-sonnet-20241022',
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
# Create sub-conversation request with planning agent
sub_conversation_request = AppConversationStartRequest(
parent_conversation_id=parent_conversation_id,
agent_type=AgentType.PLAN,
initial_message=None,
)
# Create mock app conversation service
mock_app_conversation_service = MagicMock(spec=AppConversationService)
mock_app_conversation_info_service = MagicMock(spec=AppConversationInfoService)
# Mock the service to return parent info
mock_app_conversation_info_service.get_app_conversation_info = AsyncMock(
return_value=parent_info
)
# Mock the start_app_conversation method to return a task
async def mock_start_generator(request):
task = AppConversationStartTask(
id=uuid4(),
created_by_user_id=user_id,
status=AppConversationStartTaskStatus.READY,
app_conversation_id=uuid4(),
sandbox_id=sandbox_id,
agent_server_url='http://agent-server:8000',
request=request,
)
yield task
mock_app_conversation_service.start_app_conversation = mock_start_generator
# Test the service method directly
async for task in mock_app_conversation_service.start_app_conversation(
sub_conversation_request
):
# Verify the task was created with planning agent
assert task is not None
assert task.status == AppConversationStartTaskStatus.READY
assert task.request.agent_type == AgentType.PLAN
assert task.request.parent_conversation_id == parent_conversation_id
assert task.sandbox_id == sandbox_id
break