mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Implement frontend visualization for RecallObservation & Stop issuing recall action for agent message (#7566)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
a828318494
commit
292217b8b4
@ -331,6 +331,7 @@ export enum I18nKey {
|
||||
OBSERVATION_MESSAGE$EDIT = "OBSERVATION_MESSAGE$EDIT",
|
||||
OBSERVATION_MESSAGE$WRITE = "OBSERVATION_MESSAGE$WRITE",
|
||||
OBSERVATION_MESSAGE$BROWSE = "OBSERVATION_MESSAGE$BROWSE",
|
||||
OBSERVATION_MESSAGE$RECALL = "OBSERVATION_MESSAGE$RECALL",
|
||||
EXPANDABLE_MESSAGE$SHOW_DETAILS = "EXPANDABLE_MESSAGE$SHOW_DETAILS",
|
||||
EXPANDABLE_MESSAGE$HIDE_DETAILS = "EXPANDABLE_MESSAGE$HIDE_DETAILS",
|
||||
AI_SETTINGS$TITLE = "AI_SETTINGS$TITLE",
|
||||
|
||||
@ -4945,6 +4945,21 @@
|
||||
"es": "Navegación completada",
|
||||
"tr": "Gezinme tamamlandı"
|
||||
},
|
||||
"OBSERVATION_MESSAGE$RECALL": {
|
||||
"en": "MicroAgent Activated",
|
||||
"ja": "マイクロエージェントが有効化されました",
|
||||
"zh-CN": "微代理已激活",
|
||||
"zh-TW": "微代理已啟動",
|
||||
"ko-KR": "마이크로에이전트 활성화됨",
|
||||
"no": "MikroAgent aktivert",
|
||||
"it": "MicroAgent attivato",
|
||||
"pt": "MicroAgent ativado",
|
||||
"es": "MicroAgent activado",
|
||||
"ar": "تم تنشيط الوكيل المصغر",
|
||||
"fr": "MicroAgent activé",
|
||||
"tr": "MikroAjan Etkinleştirildi",
|
||||
"de": "MicroAgent aktiviert"
|
||||
},
|
||||
"EXPANDABLE_MESSAGE$SHOW_DETAILS": {
|
||||
"en": "Show details",
|
||||
"zh-CN": "显示详情",
|
||||
|
||||
@ -51,6 +51,7 @@ export function handleObservationMessage(message: ObservationMessage) {
|
||||
case ObservationType.EDIT:
|
||||
case ObservationType.THINK:
|
||||
case ObservationType.NULL:
|
||||
case ObservationType.RECALL:
|
||||
break; // We don't display the default message for these observations
|
||||
default:
|
||||
store.dispatch(addAssistantMessage(message.message));
|
||||
@ -76,6 +77,21 @@ export function handleObservationMessage(message: ObservationMessage) {
|
||||
}),
|
||||
);
|
||||
break;
|
||||
case "recall":
|
||||
store.dispatch(
|
||||
addAssistantObservation({
|
||||
...baseObservation,
|
||||
observation: "recall" as const,
|
||||
extras: {
|
||||
...(message.extras || {}),
|
||||
recall_type:
|
||||
(message.extras?.recall_type as
|
||||
| "workspace_context"
|
||||
| "knowledge") || "knowledge",
|
||||
},
|
||||
}),
|
||||
);
|
||||
break;
|
||||
case "run":
|
||||
store.dispatch(
|
||||
addAssistantObservation({
|
||||
|
||||
@ -6,6 +6,7 @@ import {
|
||||
OpenHandsObservation,
|
||||
CommandObservation,
|
||||
IPythonObservation,
|
||||
RecallObservation,
|
||||
} from "#/types/core/observations";
|
||||
import { OpenHandsAction } from "#/types/core/actions";
|
||||
import { OpenHandsEventType } from "#/types/core/base";
|
||||
@ -22,6 +23,7 @@ const HANDLED_ACTIONS: OpenHandsEventType[] = [
|
||||
"browse",
|
||||
"browse_interactive",
|
||||
"edit",
|
||||
"recall",
|
||||
];
|
||||
|
||||
function getRiskText(risk: ActionSecurityRisk) {
|
||||
@ -112,6 +114,9 @@ export const chatSlice = createSlice({
|
||||
} else if (actionID === "browse_interactive") {
|
||||
// Include the browser_actions in the content
|
||||
text = `**Action:**\n\n\`\`\`python\n${action.payload.args.browser_actions}\n\`\`\``;
|
||||
} else if (actionID === "recall") {
|
||||
// skip recall actions
|
||||
return;
|
||||
}
|
||||
if (actionID === "run" || actionID === "run_ipython") {
|
||||
if (
|
||||
@ -143,6 +148,73 @@ export const chatSlice = createSlice({
|
||||
if (!HANDLED_ACTIONS.includes(observationID)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Special handling for RecallObservation - create a new message instead of updating an existing one
|
||||
if (observationID === "recall") {
|
||||
const recallObs = observation.payload as RecallObservation;
|
||||
let content = ``;
|
||||
|
||||
// Handle workspace context
|
||||
if (recallObs.extras.recall_type === "workspace_context") {
|
||||
if (recallObs.extras.repo_name) {
|
||||
content += `\n\n**Repository:** ${recallObs.extras.repo_name}`;
|
||||
}
|
||||
if (recallObs.extras.repo_directory) {
|
||||
content += `\n\n**Directory:** ${recallObs.extras.repo_directory}`;
|
||||
}
|
||||
if (recallObs.extras.date) {
|
||||
content += `\n\n**Date:** ${recallObs.extras.date}`;
|
||||
}
|
||||
if (
|
||||
recallObs.extras.runtime_hosts &&
|
||||
Object.keys(recallObs.extras.runtime_hosts).length > 0
|
||||
) {
|
||||
content += `\n\n**Available Hosts**`;
|
||||
for (const [host, port] of Object.entries(
|
||||
recallObs.extras.runtime_hosts,
|
||||
)) {
|
||||
content += `\n\n- ${host} (port ${port})`;
|
||||
}
|
||||
}
|
||||
if (recallObs.extras.repo_instructions) {
|
||||
content += `\n\n**Repository Instructions:**\n\n${recallObs.extras.repo_instructions}`;
|
||||
}
|
||||
if (recallObs.extras.additional_agent_instructions) {
|
||||
content += `\n\n**Additional Instructions:**\n\n${recallObs.extras.additional_agent_instructions}`;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new message for the observation
|
||||
// Use the correct translation ID format that matches what's in the i18n file
|
||||
const translationID = `OBSERVATION_MESSAGE$${observationID.toUpperCase()}`;
|
||||
|
||||
// Handle microagent knowledge
|
||||
if (
|
||||
recallObs.extras.microagent_knowledge &&
|
||||
recallObs.extras.microagent_knowledge.length > 0
|
||||
) {
|
||||
content += `\n\n**Triggered Microagent Knowledge:**`;
|
||||
for (const knowledge of recallObs.extras.microagent_knowledge) {
|
||||
content += `\n\n- **${knowledge.name}** (triggered by keyword: ${knowledge.trigger})\n\n\`\`\`\n${knowledge.content}\n\`\`\``;
|
||||
}
|
||||
}
|
||||
|
||||
const message: Message = {
|
||||
type: "action",
|
||||
sender: "assistant",
|
||||
translationID,
|
||||
eventID: observation.payload.id,
|
||||
content,
|
||||
imageUrls: [],
|
||||
timestamp: new Date().toISOString(),
|
||||
success: true,
|
||||
};
|
||||
|
||||
state.messages.push(message);
|
||||
return; // Skip the normal observation handling below
|
||||
}
|
||||
|
||||
// Normal handling for other observation types
|
||||
const translationID = `OBSERVATION_MESSAGE$${observationID.toUpperCase()}`;
|
||||
const causeID = observation.payload.cause;
|
||||
const causeMessage = state.messages.find(
|
||||
|
||||
@ -133,6 +133,15 @@ export interface RejectAction extends OpenHandsActionEvent<"reject"> {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Action event whose event type is `"recall"` and whose source is always
 * `"agent"`.
 *
 * `args` carries the kind of recall being requested (`"workspace_context"`
 * or `"knowledge"`), a query string, and a thought string.
 */
export interface RecallAction extends OpenHandsActionEvent<"recall"> {
|
||||
source: "agent";
|
||||
args: {
|
||||
recall_type: "workspace_context" | "knowledge";
|
||||
query: string;
|
||||
thought: string;
|
||||
};
|
||||
}
|
||||
|
||||
export type OpenHandsAction =
|
||||
| UserMessageAction
|
||||
| AssistantMessageAction
|
||||
@ -146,4 +155,5 @@ export type OpenHandsAction =
|
||||
| FileReadAction
|
||||
| FileEditAction
|
||||
| FileWriteAction
|
||||
| RejectAction;
|
||||
| RejectAction
|
||||
| RecallAction;
|
||||
|
||||
@ -12,7 +12,8 @@ export type OpenHandsEventType =
|
||||
| "reject"
|
||||
| "think"
|
||||
| "finish"
|
||||
| "error";
|
||||
| "error"
|
||||
| "recall";
|
||||
|
||||
interface OpenHandsBaseEvent {
|
||||
id: number;
|
||||
|
||||
@ -109,6 +109,26 @@ export interface AgentThinkObservation
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * One microagent knowledge entry carried in `RecallObservation.extras`.
 *
 * The chat slice renders each entry as a bold `name`, a
 * "triggered by keyword: <trigger>" note, and `content` inside a fenced
 * code block.
 */
export interface MicroagentKnowledge {
|
||||
name: string;
|
||||
trigger: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
/**
 * Observation event whose event type is `"recall"` and whose source is
 * `"agent"`.
 *
 * Every field of `extras` is optional. The chat slice only renders the
 * repository/date/hosts/instructions fields when `recall_type` is
 * `"workspace_context"`; `microagent_knowledge` is rendered whenever it is
 * a non-empty array.
 */
export interface RecallObservation extends OpenHandsObservationEvent<"recall"> {
|
||||
source: "agent";
|
||||
extras: {
|
||||
recall_type?: "workspace_context" | "knowledge";
|
||||
repo_name?: string;
|
||||
repo_directory?: string;
|
||||
repo_instructions?: string;
|
||||
// Keyed by host; the chat slice displays each value as that host's port.
|
||||
runtime_hosts?: Record<string, number>;
|
||||
additional_agent_instructions?: string;
|
||||
// Rendered verbatim as text; no particular date format is enforced here.
|
||||
date?: string;
|
||||
microagent_knowledge?: MicroagentKnowledge[];
|
||||
};
|
||||
}
|
||||
|
||||
export type OpenHandsObservation =
|
||||
| AgentStateChangeObservation
|
||||
| AgentThinkObservation
|
||||
@ -120,4 +140,5 @@ export type OpenHandsObservation =
|
||||
| WriteObservation
|
||||
| ReadObservation
|
||||
| EditObservation
|
||||
| ErrorObservation;
|
||||
| ErrorObservation
|
||||
| RecallObservation;
|
||||
|
||||
@ -29,6 +29,9 @@ enum ObservationType {
|
||||
// A response to the agent's thought (usually a static message)
|
||||
THINK = "think",
|
||||
|
||||
// An observation that shows agent's context extension
|
||||
RECALL = "recall",
|
||||
|
||||
// A no-op observation
|
||||
NULL = "null",
|
||||
}
|
||||
|
||||
@ -490,15 +490,8 @@ class AgentController:
|
||||
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif action.source == EventSource.AGENT:
|
||||
# Check if we need to trigger microagents based on agent message content
|
||||
recall_action = RecallAction(
|
||||
query=action.content, recall_type=RecallType.KNOWLEDGE
|
||||
)
|
||||
self._pending_action = recall_action
|
||||
# This is source=AGENT because the agent message is the trigger for the microagent retrieval
|
||||
self.event_stream.add_event(recall_action, EventSource.AGENT)
|
||||
|
||||
elif action.source == EventSource.AGENT:
|
||||
# If the agent is waiting for a response, set the appropriate state
|
||||
if action.wait_for_response:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
@ -1084,44 +1077,8 @@ class AgentController:
|
||||
# cut in half
|
||||
mid_point = max(1, len(events) // 2)
|
||||
kept_events = events[mid_point:]
|
||||
|
||||
# Handle first event in truncated history
|
||||
if kept_events:
|
||||
i = 0
|
||||
while i < len(kept_events):
|
||||
first_event = kept_events[i]
|
||||
if isinstance(first_event, Observation) and first_event.cause:
|
||||
# Find its action and include it
|
||||
matching_action = next(
|
||||
(
|
||||
e
|
||||
for e in reversed(events[:mid_point])
|
||||
if isinstance(e, Action) and e.id == first_event.cause
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching_action:
|
||||
kept_events = [matching_action] + kept_events
|
||||
else:
|
||||
self.log(
|
||||
'warning',
|
||||
f'Found Observation without matching Action at id={first_event.id}',
|
||||
)
|
||||
# drop this observation
|
||||
kept_events = kept_events[1:]
|
||||
break
|
||||
|
||||
elif isinstance(first_event, MessageAction) or (
|
||||
isinstance(first_event, Action)
|
||||
and first_event.source == EventSource.USER
|
||||
):
|
||||
# if it's a message action or a user action, keep it and continue to find the next event
|
||||
i += 1
|
||||
continue
|
||||
|
||||
else:
|
||||
# if it's an action with source == EventSource.AGENT, we're good
|
||||
break
|
||||
if len(kept_events) > 0 and isinstance(kept_events[0], Observation):
|
||||
kept_events = kept_events[1:]
|
||||
|
||||
# Ensure first user message is included
|
||||
if first_user_msg and first_user_msg not in kept_events:
|
||||
|
||||
@ -14,7 +14,6 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.observation.agent import (
|
||||
AgentStateChangedObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderToken
|
||||
@ -91,7 +90,7 @@ async def connect(connection_id: str, environ):
|
||||
logger.debug(f'oh_event: {event.__class__.__name__}')
|
||||
if isinstance(
|
||||
event,
|
||||
(NullAction, NullObservation, RecallAction, RecallObservation),
|
||||
(NullAction, NullObservation, RecallAction),
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
|
||||
@ -19,6 +19,7 @@ from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
@ -199,7 +200,7 @@ class Session:
|
||||
await self.send(event_to_dict(event))
|
||||
# NOTE: ipython observations are not sent here currently
|
||||
elif event.source == EventSource.ENVIRONMENT and isinstance(
|
||||
event, (CmdOutputObservation, AgentStateChangedObservation)
|
||||
event, (CmdOutputObservation, AgentStateChangedObservation, RecallObservation)
|
||||
):
|
||||
# feedback from the environment to agent actions is understood as agent events by the UI
|
||||
event_dict = event_to_dict(event)
|
||||
|
||||
@ -14,7 +14,7 @@ from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.action.agent import CondensationAction, RecallAction
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
@ -739,6 +739,7 @@ async def test_context_window_exceeded_error_handling(
|
||||
# called the agent's `step` function the right number of times.
|
||||
assert step_state.has_errored
|
||||
assert len(step_state.views) == max_iterations
|
||||
print('step_state.views: ', step_state.views)
|
||||
|
||||
# Look at pre/post-step views. Normally, these should always increase in
|
||||
# size (because we return a message action, which triggers a recall, which
|
||||
@ -755,14 +756,55 @@ async def test_context_window_exceeded_error_handling(
|
||||
|
||||
# The final state's history should contain:
|
||||
# - max_iterations number of message actions,
|
||||
# - max_iterations number of recall actions,
|
||||
# - max_iterations number of recall observations,
|
||||
# - and exactly one condensation action.
|
||||
assert len(final_state.history) == max_iterations * 3 + 1
|
||||
# - 1 recall action,
|
||||
# - 1 recall observation,
|
||||
# - 1 condensation action.
|
||||
assert (
|
||||
len(
|
||||
[event for event in final_state.history if isinstance(event, MessageAction)]
|
||||
)
|
||||
== max_iterations
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
event
|
||||
for event in final_state.history
|
||||
if isinstance(event, MessageAction)
|
||||
and event.source == EventSource.AGENT
|
||||
]
|
||||
)
|
||||
== max_iterations - 1
|
||||
)
|
||||
assert (
|
||||
len([event for event in final_state.history if isinstance(event, RecallAction)])
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
event
|
||||
for event in final_state.history
|
||||
if isinstance(event, RecallObservation)
|
||||
]
|
||||
)
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
event
|
||||
for event in final_state.history
|
||||
if isinstance(event, CondensationAction)
|
||||
]
|
||||
)
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
len(final_state.history) == max_iterations + 3
|
||||
) # 1 condensation action, 1 recall action, 1 recall observation
|
||||
|
||||
# ...but the final state's view should be identical to the last view (plus
|
||||
# the final message action and associated recall action/observation).
|
||||
assert len(final_state.view) == len(step_state.views[-1]) + 3
|
||||
assert len(final_state.view) == len(step_state.views[-1]) + 1
|
||||
|
||||
# And these two representations of the state are _not_ the same.
|
||||
assert len(final_state.history) != len(final_state.view)
|
||||
@ -781,7 +823,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
def step(self, state: State):
|
||||
# If the state has more than one message and we haven't errored yet,
|
||||
# throw the context window exceeded error
|
||||
if len(state.history) > 1 and not self.has_errored:
|
||||
if len(state.history) > 3 and not self.has_errored:
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
model='',
|
||||
@ -813,7 +855,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=AppConfig(max_iterations=3),
|
||||
config=AppConfig(max_iterations=5),
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
@ -833,11 +875,11 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
|
||||
# Hitting the iteration limit indicates the controller is failing for the
|
||||
# expected reason
|
||||
assert state.iteration == 3
|
||||
assert state.iteration == 5
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert (
|
||||
state.last_error
|
||||
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
|
||||
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 5, max iteration: 5'
|
||||
)
|
||||
|
||||
# Check that the context window exceeded error was raised during the run
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user