mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Feat]: Always autogen title (#8292)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
3677c52d2b
commit
aa9a48135e
@ -7,7 +7,6 @@ import React from "react";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { ConversationPanel } from "#/components/features/conversation-panel/conversation-panel";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { clickOnEditButton } from "./utils";
|
||||
|
||||
describe("ConversationPanel", () => {
|
||||
const onCloseMock = vi.fn();
|
||||
@ -204,64 +203,6 @@ describe("ConversationPanel", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should rename a conversation", async () => {
|
||||
const updateUserConversationSpy = vi.spyOn(
|
||||
OpenHands,
|
||||
"updateUserConversation",
|
||||
);
|
||||
|
||||
const user = userEvent.setup();
|
||||
renderConversationPanel();
|
||||
const cards = await screen.findAllByTestId("conversation-card");
|
||||
|
||||
const card = cards[0];
|
||||
await clickOnEditButton(user, card);
|
||||
const title = within(card).getByTestId("conversation-card-title");
|
||||
|
||||
await user.clear(title);
|
||||
await user.type(title, "Conversation 1 Renamed");
|
||||
await user.tab();
|
||||
|
||||
// Ensure the conversation is renamed
|
||||
expect(updateUserConversationSpy).toHaveBeenCalledWith("1", {
|
||||
title: "Conversation 1 Renamed",
|
||||
});
|
||||
});
|
||||
|
||||
it("should not rename a conversation when the name is unchanged", async () => {
|
||||
const updateUserConversationSpy = vi.spyOn(
|
||||
OpenHands,
|
||||
"updateUserConversation",
|
||||
);
|
||||
|
||||
const user = userEvent.setup();
|
||||
renderConversationPanel();
|
||||
const cards = await screen.findAllByTestId("conversation-card");
|
||||
|
||||
const card = cards[0];
|
||||
await clickOnEditButton(user, card);
|
||||
const title = within(card).getByTestId("conversation-card-title");
|
||||
|
||||
await user.click(title);
|
||||
await user.tab();
|
||||
|
||||
// Ensure the conversation is not renamed
|
||||
expect(updateUserConversationSpy).not.toHaveBeenCalled();
|
||||
|
||||
await clickOnEditButton(user, card);
|
||||
|
||||
await user.type(title, "Conversation 1");
|
||||
await user.click(title);
|
||||
await user.tab();
|
||||
|
||||
expect(updateUserConversationSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
await user.click(title);
|
||||
await user.tab();
|
||||
|
||||
expect(updateUserConversationSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should call onClose after clicking a card", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderConversationPanel();
|
||||
|
||||
@ -143,13 +143,6 @@ class OpenHands {
|
||||
await openHands.delete(`/api/conversations/${conversationId}`);
|
||||
}
|
||||
|
||||
static async updateUserConversation(
|
||||
conversationId: string,
|
||||
conversation: Partial<Omit<Conversation, "conversation_id">>,
|
||||
): Promise<void> {
|
||||
await openHands.patch(`/api/conversations/${conversationId}`, conversation);
|
||||
}
|
||||
|
||||
static async createConversation(
|
||||
conversation_trigger: ConversationTrigger = "gui",
|
||||
selectedRepository?: string,
|
||||
|
||||
@ -5,7 +5,6 @@ import { AgentStatusBar } from "./agent-status-bar";
|
||||
import { SecurityLock } from "./security-lock";
|
||||
import { useUserConversation } from "#/hooks/query/use-user-conversation";
|
||||
import { ConversationCard } from "../conversation-panel/conversation-card";
|
||||
import { useAutoTitle } from "#/hooks/use-auto-title";
|
||||
|
||||
interface ControlsProps {
|
||||
setSecurityOpen: (isOpen: boolean) => void;
|
||||
@ -17,7 +16,6 @@ export function Controls({ setSecurityOpen, showSecurityLock }: ControlsProps) {
|
||||
const { data: conversation } = useUserConversation(
|
||||
params.conversationId ?? null,
|
||||
);
|
||||
useAutoTitle();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2 md:items-center md:justify-between md:flex-row">
|
||||
|
||||
@ -7,7 +7,6 @@ import { useUserConversations } from "#/hooks/query/use-user-conversations";
|
||||
import { useDeleteConversation } from "#/hooks/mutation/use-delete-conversation";
|
||||
import { ConfirmDeleteModal } from "./confirm-delete-modal";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { useUpdateConversation } from "#/hooks/mutation/use-update-conversation";
|
||||
import { ExitConversationModal } from "./exit-conversation-modal";
|
||||
import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
|
||||
|
||||
@ -34,7 +33,6 @@ export function ConversationPanel({ onClose }: ConversationPanelProps) {
|
||||
const { data: conversations, isFetching, error } = useUserConversations();
|
||||
|
||||
const { mutate: deleteConversation } = useDeleteConversation();
|
||||
const { mutate: updateConversation } = useUpdateConversation();
|
||||
|
||||
const handleDeleteProject = (conversationId: string) => {
|
||||
setConfirmDeleteModalVisible(true);
|
||||
@ -56,18 +54,6 @@ export function ConversationPanel({ onClose }: ConversationPanelProps) {
|
||||
}
|
||||
};
|
||||
|
||||
const handleChangeTitle = (
|
||||
conversationId: string,
|
||||
oldTitle: string,
|
||||
newTitle: string,
|
||||
) => {
|
||||
if (oldTitle !== newTitle)
|
||||
updateConversation({
|
||||
id: conversationId,
|
||||
conversation: { title: newTitle },
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
@ -101,9 +87,6 @@ export function ConversationPanel({ onClose }: ConversationPanelProps) {
|
||||
<ConversationCard
|
||||
isActive={isActive}
|
||||
onDelete={() => handleDeleteProject(project.conversation_id)}
|
||||
onChangeTitle={(title) =>
|
||||
handleChangeTitle(project.conversation_id, project.title, title)
|
||||
}
|
||||
title={project.title}
|
||||
selectedRepository={project.selected_repository}
|
||||
lastUpdatedAt={project.last_updated_at}
|
||||
|
||||
@ -1,18 +0,0 @@
|
||||
import { useQueryClient, useMutation } from "@tanstack/react-query";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { Conversation } from "#/api/open-hands.types";
|
||||
|
||||
export const useUpdateConversation = () => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: (variables: {
|
||||
id: string;
|
||||
conversation: Partial<Omit<Conversation, "id">>;
|
||||
}) =>
|
||||
OpenHands.updateUserConversation(variables.id, variables.conversation),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ["user", "conversations"] });
|
||||
},
|
||||
});
|
||||
};
|
||||
@ -1,82 +0,0 @@
|
||||
import { useEffect } from "react";
|
||||
import { useParams } from "react-router";
|
||||
import { useSelector, useDispatch } from "react-redux";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useUpdateConversation } from "./mutation/use-update-conversation";
|
||||
import { RootState } from "#/store";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { useUserConversation } from "#/hooks/query/use-user-conversation";
|
||||
|
||||
const defaultTitlePattern = /^Conversation [a-f0-9]+$/;
|
||||
|
||||
/**
|
||||
* Hook that monitors for the first agent message and triggers title generation.
|
||||
* This approach is more robust as it ensures the user message has been processed
|
||||
* by the backend and the agent has responded before generating the title.
|
||||
*/
|
||||
export function useAutoTitle() {
|
||||
const { conversationId } = useParams<{ conversationId: string }>();
|
||||
const { data: conversation } = useUserConversation(conversationId ?? null);
|
||||
const queryClient = useQueryClient();
|
||||
const dispatch = useDispatch();
|
||||
const { mutate: updateConversation } = useUpdateConversation();
|
||||
|
||||
const messages = useSelector((state: RootState) => state.chat.messages);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
!conversation ||
|
||||
!conversationId ||
|
||||
!messages ||
|
||||
messages.length === 0
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const hasAgentMessage = messages.some(
|
||||
(message) => message.sender === "assistant",
|
||||
);
|
||||
const hasUserMessage = messages.some(
|
||||
(message) => message.sender === "user",
|
||||
);
|
||||
|
||||
if (!hasAgentMessage || !hasUserMessage) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (conversation.title && !defaultTitlePattern.test(conversation.title)) {
|
||||
return;
|
||||
}
|
||||
|
||||
updateConversation(
|
||||
{
|
||||
id: conversationId,
|
||||
conversation: { title: "" },
|
||||
},
|
||||
{
|
||||
onSuccess: async () => {
|
||||
try {
|
||||
const updatedConversation =
|
||||
await OpenHands.getConversation(conversationId);
|
||||
|
||||
queryClient.setQueryData(
|
||||
["user", "conversation", conversationId],
|
||||
updatedConversation,
|
||||
);
|
||||
} catch (error) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["user", "conversation", conversationId],
|
||||
});
|
||||
}
|
||||
},
|
||||
},
|
||||
);
|
||||
}, [
|
||||
messages,
|
||||
conversationId,
|
||||
conversation,
|
||||
updateConversation,
|
||||
queryClient,
|
||||
dispatch,
|
||||
]);
|
||||
}
|
||||
110
frontend/src/services/__tests__/actions.test.ts
Normal file
110
frontend/src/services/__tests__/actions.test.ts
Normal file
@ -0,0 +1,110 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { handleStatusMessage } from "../actions";
|
||||
import { StatusMessage } from "#/types/message";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
import store from "#/store";
|
||||
import { setCurStatusMessage } from "#/state/status-slice";
|
||||
import { addErrorMessage } from "#/state/chat-slice";
|
||||
import { trackError } from "#/utils/error-handler";
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock("#/query-client-config", () => ({
|
||||
queryClient: {
|
||||
invalidateQueries: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/store", () => ({
|
||||
default: {
|
||||
dispatch: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/state/status-slice", () => ({
|
||||
setCurStatusMessage: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/state/chat-slice", () => ({
|
||||
addErrorMessage: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/error-handler", () => ({
|
||||
trackError: vi.fn(),
|
||||
}));
|
||||
|
||||
describe("handleStatusMessage", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
it("should invalidate queries when receiving a conversation title update", () => {
|
||||
// Create a status message with a conversation title
|
||||
const statusMessage: StatusMessage = {
|
||||
status_update: true,
|
||||
type: "info",
|
||||
message: "conversation-123",
|
||||
conversation_title: "New Conversation Title",
|
||||
};
|
||||
|
||||
// Call the function
|
||||
handleStatusMessage(statusMessage);
|
||||
|
||||
// Verify that queryClient.invalidateQueries was called with the correct parameters
|
||||
expect(queryClient.invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: ["user", "conversation", "conversation-123"],
|
||||
});
|
||||
|
||||
// Verify that store.dispatch was not called
|
||||
expect(store.dispatch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should dispatch setCurStatusMessage for info messages without conversation_title", () => {
|
||||
// Create a status message without a conversation title
|
||||
const statusMessage: StatusMessage = {
|
||||
status_update: true,
|
||||
type: "info",
|
||||
message: "Some info message",
|
||||
};
|
||||
|
||||
// Call the function
|
||||
handleStatusMessage(statusMessage);
|
||||
|
||||
// Verify that store.dispatch was called with setCurStatusMessage
|
||||
expect(store.dispatch).toHaveBeenCalledWith(
|
||||
setCurStatusMessage(statusMessage),
|
||||
);
|
||||
|
||||
// Verify that queryClient.invalidateQueries was not called
|
||||
expect(queryClient.invalidateQueries).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should dispatch addErrorMessage for error messages", () => {
|
||||
// Create an error status message
|
||||
const statusMessage: StatusMessage = {
|
||||
status_update: true,
|
||||
type: "error",
|
||||
id: "ERROR_ID",
|
||||
message: "Some error message",
|
||||
};
|
||||
|
||||
// Call the function
|
||||
handleStatusMessage(statusMessage);
|
||||
|
||||
// Verify that trackError was called with the correct parameters
|
||||
expect(trackError).toHaveBeenCalledWith({
|
||||
message: "Some error message",
|
||||
source: "chat",
|
||||
metadata: { msgId: "ERROR_ID" },
|
||||
});
|
||||
|
||||
// Verify that store.dispatch was called with addErrorMessage
|
||||
expect(store.dispatch).toHaveBeenCalledWith(addErrorMessage(statusMessage));
|
||||
|
||||
// Verify that queryClient.invalidateQueries was not called
|
||||
expect(queryClient.invalidateQueries).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@ -19,6 +19,7 @@ import {
|
||||
} from "#/types/message";
|
||||
import { handleObservationMessage } from "./observations";
|
||||
import { appendInput } from "#/state/command-slice";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
|
||||
const messageActions = {
|
||||
[ActionType.BROWSE]: (message: ActionMessage) => {
|
||||
@ -125,7 +126,15 @@ export function handleActionMessage(message: ActionMessage) {
|
||||
}
|
||||
|
||||
export function handleStatusMessage(message: StatusMessage) {
|
||||
if (message.type === "info") {
|
||||
// Info message with conversation_title indicates new title for conversation
|
||||
if (message.type === "info" && message.conversation_title) {
|
||||
const conversationId = message.message;
|
||||
|
||||
// Invalidate the conversation query to trigger a refetch with the new title
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["user", "conversation", conversationId],
|
||||
});
|
||||
} else if (message.type === "info") {
|
||||
store.dispatch(
|
||||
setCurStatusMessage({
|
||||
...message,
|
||||
|
||||
@ -67,4 +67,5 @@ export interface StatusMessage {
|
||||
type: string;
|
||||
id?: string;
|
||||
message: string;
|
||||
conversation_title?: string;
|
||||
}
|
||||
|
||||
@ -7,12 +7,14 @@ from typing import Callable, Iterable
|
||||
import socketio
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.events.stream import EventStreamSubscriber, session_exists
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber, session_exists
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
|
||||
@ -23,6 +25,7 @@ from openhands.storage.data_models.conversation_metadata import ConversationMeta
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title, auto_generate_title
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
@ -204,6 +207,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
store = await conversation_store_class.get_instance(self.config, user_id)
|
||||
return store
|
||||
|
||||
|
||||
async def get_running_agent_loops(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
@ -328,7 +332,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(user_id, github_user_id, sid),
|
||||
self._create_conversation_update_callback(user_id, github_user_id, sid, settings),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@ -425,7 +429,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
self, user_id: str | None, github_user_id: str | None, conversation_id: str
|
||||
self, user_id: str | None, github_user_id: str | None, conversation_id: str, settings: Settings
|
||||
) -> Callable:
|
||||
def callback(event, *args, **kwargs):
|
||||
call_async_from_sync(
|
||||
@ -434,13 +438,15 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id,
|
||||
github_user_id,
|
||||
conversation_id,
|
||||
settings,
|
||||
event,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
async def _update_conversation_for_event(
|
||||
self, user_id: str, github_user_id: str, conversation_id: str, event=None
|
||||
self, user_id: str, github_user_id: str, conversation_id: str, settings: Settings, event=None
|
||||
):
|
||||
conversation_store = await self._get_conversation_store(user_id, github_user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
@ -462,6 +468,28 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation.total_tokens = (
|
||||
token_usage.prompt_tokens + token_usage.completion_tokens
|
||||
)
|
||||
default_title = get_default_conversation_title(conversation_id)
|
||||
if conversation.title == default_title: # attempt to autogenerate if default title is in use
|
||||
title = await auto_generate_title(conversation_id, user_id, self.file_store, settings)
|
||||
if title and not title.isspace():
|
||||
conversation.title = title
|
||||
try:
|
||||
# Emit a status update to the client with the new title
|
||||
status_update_dict = {
|
||||
'status_update': True,
|
||||
'type': 'info',
|
||||
'message': conversation_id,
|
||||
'conversation_title': conversation.title,
|
||||
}
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
status_update_dict,
|
||||
to=ROOM_KEY.format(sid=conversation_id),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error emitting title update event: {e}')
|
||||
else:
|
||||
conversation.title = default_title
|
||||
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
|
||||
@ -1,16 +1,14 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, status
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
@ -31,7 +29,6 @@ from openhands.server.shared import (
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
file_store,
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth import (
|
||||
@ -48,7 +45,7 @@ from openhands.storage.data_models.conversation_metadata import (
|
||||
)
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.conversation_summary import generate_conversation_title
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
@ -99,7 +96,7 @@ async def _create_new_conversation(
|
||||
not settings.llm_api_key
|
||||
or settings.llm_api_key.get_secret_value().isspace()
|
||||
):
|
||||
logger.warn(f'Missing api key for model {settings.llm_model}')
|
||||
logger.warning(f'Missing api key for model {settings.llm_model}')
|
||||
raise LLMAuthenticationError(
|
||||
'Error authenticating with the LLM provider. Please check your API key'
|
||||
)
|
||||
@ -163,7 +160,6 @@ async def _create_new_conversation(
|
||||
replay_json=replay_json,
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
|
||||
return conversation_id
|
||||
|
||||
|
||||
@ -299,110 +295,7 @@ async def get_conversation(
|
||||
return conversation_info
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def get_default_conversation_title(conversation_id: str) -> str:
|
||||
"""
|
||||
Generate a default title for a conversation based on its ID.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
|
||||
Returns:
|
||||
A default title string
|
||||
"""
|
||||
return f'Conversation {conversation_id[:5]}'
|
||||
|
||||
|
||||
async def auto_generate_title(conversation_id: str, user_id: str | None) -> str:
|
||||
"""
|
||||
Auto-generate a title for a conversation based on the first user message.
|
||||
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
user_id: The ID of the user
|
||||
|
||||
Returns:
|
||||
A generated title string
|
||||
"""
|
||||
logger.info(f'Auto-generating title for conversation {conversation_id}')
|
||||
|
||||
try:
|
||||
# Create an event stream for the conversation
|
||||
event_stream = EventStream(conversation_id, file_store, user_id)
|
||||
|
||||
# Find the first user message
|
||||
first_user_message = None
|
||||
for event in event_stream.get_events():
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and isinstance(event, MessageAction)
|
||||
and event.content
|
||||
and event.content.strip()
|
||||
):
|
||||
first_user_message = event.content
|
||||
break
|
||||
|
||||
if first_user_message:
|
||||
# Get LLM config from user settings
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
|
||||
if settings and settings.llm_model:
|
||||
# Create LLM config from settings
|
||||
llm_config = LLMConfig(
|
||||
model=settings.llm_model,
|
||||
api_key=settings.llm_api_key,
|
||||
base_url=settings.llm_base_url,
|
||||
)
|
||||
|
||||
# Try to generate title using LLM
|
||||
llm_title = await generate_conversation_title(
|
||||
first_user_message, llm_config
|
||||
)
|
||||
if llm_title:
|
||||
logger.info(f'Generated title using LLM: {llm_title}')
|
||||
return llm_title
|
||||
except Exception as e:
|
||||
logger.error(f'Error using LLM for title generation: {e}')
|
||||
|
||||
# Fall back to simple truncation if LLM generation fails or is unavailable
|
||||
first_user_message = first_user_message.strip()
|
||||
title = first_user_message[:30]
|
||||
if len(first_user_message) > 30:
|
||||
title += '...'
|
||||
logger.info(f'Generated title using truncation: {title}')
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating title: {str(e)}')
|
||||
return ''
|
||||
|
||||
|
||||
@app.patch('/conversations/{conversation_id}')
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
title: str = Body(embed=True),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
return False
|
||||
|
||||
# If title is empty or unspecified, auto-generate it
|
||||
if not title or title.isspace():
|
||||
title = await auto_generate_title(conversation_id, user_id)
|
||||
|
||||
# If we still don't have a title, use the default
|
||||
if not title or title.isspace():
|
||||
title = get_default_conversation_title(conversation_id)
|
||||
|
||||
metadata.title = title
|
||||
await conversation_store.save_metadata(metadata)
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@app.delete('/conversations/{conversation_id}')
|
||||
async def delete_conversation(
|
||||
|
||||
@ -4,7 +4,12 @@ from typing import Optional
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
async def generate_conversation_title(
|
||||
@ -55,3 +60,81 @@ async def generate_conversation_title(
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating conversation title: {e}')
|
||||
return None
|
||||
|
||||
|
||||
def get_default_conversation_title(conversation_id: str) -> str:
|
||||
"""
|
||||
Generate a default title for a conversation based on its ID.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
|
||||
Returns:
|
||||
A default title string
|
||||
"""
|
||||
return f'Conversation {conversation_id[:5]}'
|
||||
|
||||
|
||||
async def auto_generate_title(
|
||||
conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings
|
||||
) -> str:
|
||||
"""
|
||||
Auto-generate a title for a conversation based on the first user message.
|
||||
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
user_id: The ID of the user
|
||||
|
||||
Returns:
|
||||
A generated title string
|
||||
"""
|
||||
logger.info(f'Auto-generating title for conversation {conversation_id}')
|
||||
|
||||
try:
|
||||
# Create an event stream for the conversation
|
||||
event_stream = EventStream(conversation_id, file_store, user_id)
|
||||
|
||||
# Find the first user message
|
||||
first_user_message = None
|
||||
for event in event_stream.get_events():
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and isinstance(event, MessageAction)
|
||||
and event.content
|
||||
and event.content.strip()
|
||||
):
|
||||
first_user_message = event.content
|
||||
break
|
||||
|
||||
if first_user_message:
|
||||
# Get LLM config from user settings
|
||||
try:
|
||||
if settings and settings.llm_model:
|
||||
# Create LLM config from settings
|
||||
llm_config = LLMConfig(
|
||||
model=settings.llm_model,
|
||||
api_key=settings.llm_api_key,
|
||||
base_url=settings.llm_base_url,
|
||||
)
|
||||
|
||||
# Try to generate title using LLM
|
||||
llm_title = await generate_conversation_title(
|
||||
first_user_message, llm_config
|
||||
)
|
||||
if llm_title:
|
||||
logger.info(f'Generated title using LLM: {llm_title}')
|
||||
return llm_title
|
||||
except Exception as e:
|
||||
logger.error(f'Error using LLM for title generation: {e}')
|
||||
|
||||
# Fall back to simple truncation if LLM generation fails or is unavailable
|
||||
first_user_message = first_user_message.strip()
|
||||
title = first_user_message[:30]
|
||||
if len(first_user_message) > 30:
|
||||
title += '...'
|
||||
logger.info(f'Generated title using truncation: {title}')
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating title: {str(e)}')
|
||||
return ''
|
||||
|
||||
240
tests/unit/test_auto_generate_title.py
Normal file
240
tests/unit/test_auto_generate_title.py
Normal file
@ -0,0 +1,240 @@
|
||||
"""Tests for the auto-generate title functionality."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.server.conversation_manager.standalone_conversation_manager import (
|
||||
StandaloneConversationManager,
|
||||
)
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
from openhands.utils.conversation_summary import auto_generate_title
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_generate_title_with_llm():
|
||||
"""Test auto-generating a title using LLM."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
|
||||
# Create test conversation with a user message
|
||||
conversation_id = 'test-conversation'
|
||||
user_id = 'test-user'
|
||||
|
||||
# Create a mock event
|
||||
user_message = MessageAction(
|
||||
content='Help me write a Python script to analyze data'
|
||||
)
|
||||
user_message._source = EventSource.USER
|
||||
user_message._id = 1
|
||||
user_message._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Mock the EventStream class
|
||||
with patch(
|
||||
'openhands.utils.conversation_summary.EventStream'
|
||||
) as mock_event_stream_cls:
|
||||
# Configure the mock event stream to return our test message
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.get_events.return_value = [user_message]
|
||||
mock_event_stream_cls.return_value = mock_event_stream
|
||||
|
||||
# Mock the LLM response
|
||||
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
|
||||
mock_llm = mock_llm_cls.return_value
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Python Data Analysis Script'
|
||||
mock_llm.completion.return_value = mock_response
|
||||
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert title == 'Python Data Analysis Script'
|
||||
|
||||
# Verify EventStream was created with the correct parameters
|
||||
mock_event_stream_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
|
||||
# Verify LLM was called with appropriate parameters
|
||||
mock_llm_cls.assert_called_once_with(
|
||||
LLMConfig(
|
||||
model='test-model',
|
||||
api_key='test-key',
|
||||
base_url='test-url',
|
||||
)
|
||||
)
|
||||
mock_llm.completion.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_generate_title_fallback():
|
||||
"""Test auto-generating a title with fallback to truncation when LLM fails."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
|
||||
# Create test conversation with a user message
|
||||
conversation_id = 'test-conversation'
|
||||
user_id = 'test-user'
|
||||
|
||||
# Create a mock event with a long message
|
||||
long_message = 'This is a very long message that should be truncated when used as a title because it exceeds the maximum length allowed for titles'
|
||||
user_message = MessageAction(content=long_message)
|
||||
user_message._source = EventSource.USER
|
||||
user_message._id = 1
|
||||
user_message._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Mock the EventStream class
|
||||
with patch(
|
||||
'openhands.utils.conversation_summary.EventStream'
|
||||
) as mock_event_stream_cls:
|
||||
# Configure the mock event stream to return our test message
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.get_events.return_value = [user_message]
|
||||
mock_event_stream_cls.return_value = mock_event_stream
|
||||
|
||||
# Mock the LLM to raise an exception
|
||||
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
|
||||
mock_llm = mock_llm_cls.return_value
|
||||
mock_llm.completion.side_effect = Exception('Test error')
|
||||
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
)
|
||||
|
||||
# Verify the result is a truncated version of the message
|
||||
assert title == 'This is a very long message th...'
|
||||
assert len(title) <= 35
|
||||
|
||||
# Verify EventStream was created with the correct parameters
|
||||
mock_event_stream_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_generate_title_no_messages():
|
||||
"""Test auto-generating a title when there are no user messages."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
|
||||
# Create test conversation with no messages
|
||||
conversation_id = 'test-conversation'
|
||||
user_id = 'test-user'
|
||||
|
||||
# Mock the EventStream class
|
||||
with patch(
|
||||
'openhands.utils.conversation_summary.EventStream'
|
||||
) as mock_event_stream_cls:
|
||||
# Configure the mock event stream to return no events
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.get_events.return_value = []
|
||||
mock_event_stream_cls.return_value = mock_event_stream
|
||||
|
||||
# Create test settings
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
)
|
||||
|
||||
# Verify the result is empty
|
||||
assert title == ''
|
||||
|
||||
# Verify EventStream was created with the correct parameters
|
||||
mock_event_stream_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation_with_title():
|
||||
"""Test that _update_conversation_for_event updates the title when needed."""
|
||||
# Mock dependencies
|
||||
sio = MagicMock()
|
||||
sio.emit = AsyncMock()
|
||||
file_store = InMemoryFileStore()
|
||||
server_config = MagicMock()
|
||||
|
||||
# Create test conversation
|
||||
conversation_id = 'test-conversation'
|
||||
user_id = 'test-user'
|
||||
github_user_id = 'test-github-user'
|
||||
|
||||
# Create test settings
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Mock the conversation store and metadata
|
||||
mock_conversation_store = AsyncMock()
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.title = f'Conversation {conversation_id[:5]}' # Default title
|
||||
mock_conversation_store.get_metadata.return_value = mock_metadata
|
||||
|
||||
# Create the conversation manager
|
||||
manager = StandaloneConversationManager(
|
||||
sio=sio,
|
||||
config=AppConfig(),
|
||||
file_store=file_store,
|
||||
server_config=server_config,
|
||||
monitoring_listener=MonitoringListener(),
|
||||
)
|
||||
|
||||
# Mock the _get_conversation_store method
|
||||
manager._get_conversation_store = AsyncMock(return_value=mock_conversation_store)
|
||||
|
||||
# Mock the auto_generate_title function
|
||||
with patch(
|
||||
'openhands.server.conversation_manager.standalone_conversation_manager.auto_generate_title',
|
||||
AsyncMock(return_value='Generated Title'),
|
||||
):
|
||||
# Call the method
|
||||
await manager._update_conversation_for_event(
|
||||
user_id, github_user_id, conversation_id, settings
|
||||
)
|
||||
|
||||
# Verify the title was updated
|
||||
assert mock_metadata.title == 'Generated Title'
|
||||
|
||||
# Verify the socket.io emit was called with the correct parameters
|
||||
sio.emit.assert_called_once()
|
||||
call_args = sio.emit.call_args[0]
|
||||
assert call_args[0] == 'oh_event'
|
||||
assert call_args[1]['status_update'] is True
|
||||
assert call_args[1]['type'] == 'info'
|
||||
assert call_args[1]['message'] == conversation_id
|
||||
assert call_args[1]['conversation_title'] == 'Generated Title'
|
||||
@ -25,7 +25,6 @@ from openhands.server.routes.manage_conversations import (
|
||||
get_conversation,
|
||||
new_conversation,
|
||||
search_conversations,
|
||||
update_conversation,
|
||||
)
|
||||
from openhands.server.routes.manage_conversations import app as conversation_app
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
@ -222,50 +221,6 @@ async def test_get_missing_conversation():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation():
|
||||
with _patch_store():
|
||||
# Mock the ConversationStoreImpl.get_instance
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.ConversationStoreImpl.get_instance'
|
||||
) as mock_get_instance:
|
||||
# Create a mock conversation store
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Mock metadata
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
|
||||
selected_repository='foobar',
|
||||
github_user_id='12345',
|
||||
user_id='12345',
|
||||
)
|
||||
|
||||
# Set up the mock to return metadata and then save it
|
||||
mock_store.get_metadata = AsyncMock(return_value=metadata)
|
||||
mock_store.save_metadata = AsyncMock()
|
||||
|
||||
# Return the mock store from get_instance
|
||||
mock_get_instance.return_value = mock_store
|
||||
|
||||
# Call update_conversation
|
||||
result = await update_conversation(
|
||||
'some_conversation_id',
|
||||
'New Title',
|
||||
user_id='12345',
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result is True
|
||||
|
||||
# Verify that save_metadata was called with updated metadata
|
||||
mock_store.save_metadata.assert_called_once()
|
||||
saved_metadata = mock_store.save_metadata.call_args[0][0]
|
||||
assert saved_metadata.title == 'New Title'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_success(provider_handler_mock):
|
||||
"""Test successful creation of a new conversation."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user