mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Feat] Custom MicroAgents. (#4983)
Co-authored-by: diwu-sf <di.wu@shadowfaxdata.com>
This commit is contained in:
parent
cf157c86b3
commit
2b06e4e5d0
@ -38,6 +38,7 @@ interface WsClientProviderProps {
|
|||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
token: string | null;
|
token: string | null;
|
||||||
ghToken: string | null;
|
ghToken: string | null;
|
||||||
|
selectedRepository: string | null;
|
||||||
settings: Settings | null;
|
settings: Settings | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,12 +46,14 @@ export function WsClientProvider({
|
|||||||
enabled,
|
enabled,
|
||||||
token,
|
token,
|
||||||
ghToken,
|
ghToken,
|
||||||
|
selectedRepository,
|
||||||
settings,
|
settings,
|
||||||
children,
|
children,
|
||||||
}: React.PropsWithChildren<WsClientProviderProps>) {
|
}: React.PropsWithChildren<WsClientProviderProps>) {
|
||||||
const sioRef = React.useRef<Socket | null>(null);
|
const sioRef = React.useRef<Socket | null>(null);
|
||||||
const tokenRef = React.useRef<string | null>(token);
|
const tokenRef = React.useRef<string | null>(token);
|
||||||
const ghTokenRef = React.useRef<string | null>(ghToken);
|
const ghTokenRef = React.useRef<string | null>(ghToken);
|
||||||
|
const selectedRepositoryRef = React.useRef<string | null>(selectedRepository);
|
||||||
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
|
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
|
||||||
null,
|
null,
|
||||||
);
|
);
|
||||||
@ -81,6 +84,9 @@ export function WsClientProvider({
|
|||||||
if (ghToken) {
|
if (ghToken) {
|
||||||
initEvent.github_token = ghToken;
|
initEvent.github_token = ghToken;
|
||||||
}
|
}
|
||||||
|
if (selectedRepository) {
|
||||||
|
initEvent.selected_repository = selectedRepository;
|
||||||
|
}
|
||||||
const lastEvent = lastEventRef.current;
|
const lastEvent = lastEventRef.current;
|
||||||
if (lastEvent) {
|
if (lastEvent) {
|
||||||
initEvent.latest_event_id = lastEvent.id;
|
initEvent.latest_event_id = lastEvent.id;
|
||||||
@ -158,6 +164,7 @@ export function WsClientProvider({
|
|||||||
sioRef.current = sio;
|
sioRef.current = sio;
|
||||||
tokenRef.current = token;
|
tokenRef.current = token;
|
||||||
ghTokenRef.current = ghToken;
|
ghTokenRef.current = ghToken;
|
||||||
|
selectedRepositoryRef.current = selectedRepository;
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
sio.off("connect", handleConnect);
|
sio.off("connect", handleConnect);
|
||||||
@ -166,7 +173,7 @@ export function WsClientProvider({
|
|||||||
sio.off("connect_failed", handleError);
|
sio.off("connect_failed", handleError);
|
||||||
sio.off("disconnect", handleDisconnect);
|
sio.off("disconnect", handleDisconnect);
|
||||||
};
|
};
|
||||||
}, [enabled, token, ghToken]);
|
}, [enabled, token, ghToken, selectedRepository]);
|
||||||
|
|
||||||
// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
|
// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
|
||||||
// before actually disconnecting the socket and cancel the operation if the component gets remounted.
|
// before actually disconnecting the socket and cancel the operation if the component gets remounted.
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import {
|
|||||||
WsClientProviderStatus,
|
WsClientProviderStatus,
|
||||||
} from "#/context/ws-client-provider";
|
} from "#/context/ws-client-provider";
|
||||||
import { createChatMessage } from "#/services/chat-service";
|
import { createChatMessage } from "#/services/chat-service";
|
||||||
import { getCloneRepoCommand } from "#/services/terminal-service";
|
|
||||||
import { setCurrentAgentState } from "#/state/agent-slice";
|
import { setCurrentAgentState } from "#/state/agent-slice";
|
||||||
import { addUserMessage } from "#/state/chat-slice";
|
import { addUserMessage } from "#/state/chat-slice";
|
||||||
import {
|
import {
|
||||||
@ -37,11 +36,6 @@ export const useWSStatusChange = () => {
|
|||||||
send(createChatMessage(query, base64Files, timestamp));
|
send(createChatMessage(query, base64Files, timestamp));
|
||||||
};
|
};
|
||||||
|
|
||||||
const dispatchCloneRepoCommand = (ghToken: string, repository: string) => {
|
|
||||||
send(getCloneRepoCommand(ghToken, repository));
|
|
||||||
dispatch(clearSelectedRepository());
|
|
||||||
};
|
|
||||||
|
|
||||||
const dispatchInitialQuery = (query: string, additionalInfo: string) => {
|
const dispatchInitialQuery = (query: string, additionalInfo: string) => {
|
||||||
if (additionalInfo) {
|
if (additionalInfo) {
|
||||||
sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
|
sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
|
||||||
@ -57,8 +51,7 @@ export const useWSStatusChange = () => {
|
|||||||
let additionalInfo = "";
|
let additionalInfo = "";
|
||||||
|
|
||||||
if (gitHubToken && selectedRepository) {
|
if (gitHubToken && selectedRepository) {
|
||||||
dispatchCloneRepoCommand(gitHubToken, selectedRepository);
|
dispatch(clearSelectedRepository());
|
||||||
additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
|
|
||||||
} else if (importedProjectZip) {
|
} else if (importedProjectZip) {
|
||||||
// if there's an uploaded project zip, add it to the chat
|
// if there's an uploaded project zip, add it to the chat
|
||||||
additionalInfo =
|
additionalInfo =
|
||||||
|
|||||||
@ -64,6 +64,7 @@ function App() {
|
|||||||
enabled
|
enabled
|
||||||
token={token}
|
token={token}
|
||||||
ghToken={gitHubToken}
|
ghToken={gitHubToken}
|
||||||
|
selectedRepository={selectedRepository}
|
||||||
settings={settings}
|
settings={settings}
|
||||||
>
|
>
|
||||||
<EventHandler>
|
<EventHandler>
|
||||||
|
|||||||
@ -10,11 +10,3 @@ export function getGitHubTokenCommand(gitHubToken: string) {
|
|||||||
const event = getTerminalCommand(command, true);
|
const event = getTerminalCommand(command, true);
|
||||||
return event;
|
return event;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getCloneRepoCommand(gitHubToken: string, repository: string) {
|
|
||||||
const url = `https://${gitHubToken}@github.com/${repository}.git`;
|
|
||||||
const dirName = repository.split("/")[1];
|
|
||||||
const command = `git clone ${url} ${dirName} ; cd ${dirName} ; git checkout -b openhands-workspace`;
|
|
||||||
const event = getTerminalCommand(command, true);
|
|
||||||
return event;
|
|
||||||
}
|
|
||||||
|
|||||||
@ -398,6 +398,9 @@ class CodeActAgent(Agent):
|
|||||||
- Messages from the same role are combined to prevent consecutive same-role messages
|
- Messages from the same role are combined to prevent consecutive same-role messages
|
||||||
- For Anthropic models, specific messages are cached according to their documentation
|
- For Anthropic models, specific messages are cached according to their documentation
|
||||||
"""
|
"""
|
||||||
|
if not self.prompt_manager:
|
||||||
|
raise Exception('Prompt Manager not instantiated.')
|
||||||
|
|
||||||
messages: list[Message] = [
|
messages: list[Message] = [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from openhands.core.exceptions import (
|
|||||||
)
|
)
|
||||||
from openhands.llm.llm import LLM
|
from openhands.llm.llm import LLM
|
||||||
from openhands.runtime.plugins import PluginRequirement
|
from openhands.runtime.plugins import PluginRequirement
|
||||||
|
from openhands.utils.prompt import PromptManager
|
||||||
|
|
||||||
|
|
||||||
class Agent(ABC):
|
class Agent(ABC):
|
||||||
@ -33,6 +34,7 @@ class Agent(ABC):
|
|||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.config = config
|
self.config = config
|
||||||
self._complete = False
|
self._complete = False
|
||||||
|
self.prompt_manager: PromptManager | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def complete(self) -> bool:
|
def complete(self) -> bool:
|
||||||
|
|||||||
@ -213,6 +213,47 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
source = event.source if event.source else EventSource.AGENT
|
source = event.source if event.source else EventSource.AGENT
|
||||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def clone_repo(self, github_token: str | None, selected_repository: str | None):
|
||||||
|
if not github_token or not selected_repository:
|
||||||
|
return
|
||||||
|
url = f'https://{github_token}@github.com/{selected_repository}.git'
|
||||||
|
dir_name = selected_repository.split('/')[1]
|
||||||
|
action = CmdRunAction(
|
||||||
|
command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b openhands-workspace'
|
||||||
|
)
|
||||||
|
self.log('info', 'Cloning repo: {selected_repository}')
|
||||||
|
self.run_action(action)
|
||||||
|
|
||||||
|
def get_custom_microagents(self, selected_repository: str | None) -> list[str]:
|
||||||
|
custom_microagents_content = []
|
||||||
|
custom_microagents_dir = Path('.openhands') / 'microagents'
|
||||||
|
|
||||||
|
dir_name = str(custom_microagents_dir)
|
||||||
|
if selected_repository:
|
||||||
|
dir_name = str(
|
||||||
|
Path(selected_repository.split('/')[1]) / custom_microagents_dir
|
||||||
|
)
|
||||||
|
oh_instructions_header = '---\nname: openhands_instructions\nagent: CodeActAgent\ntriggers:\n- ""\n---\n'
|
||||||
|
obs = self.read(FileReadAction(path='.openhands_instructions'))
|
||||||
|
if isinstance(obs, ErrorObservation):
|
||||||
|
self.log('error', 'Failed to read openhands_instructions')
|
||||||
|
else:
|
||||||
|
openhands_instructions = oh_instructions_header + obs.content
|
||||||
|
self.log('info', f'openhands_instructions: {openhands_instructions}')
|
||||||
|
custom_microagents_content.append(openhands_instructions)
|
||||||
|
|
||||||
|
files = self.list_files(dir_name)
|
||||||
|
|
||||||
|
self.log('info', f'Found {len(files)} custom microagents.')
|
||||||
|
|
||||||
|
for fname in files:
|
||||||
|
content = self.read(
|
||||||
|
FileReadAction(path=str(custom_microagents_dir / fname))
|
||||||
|
).content
|
||||||
|
custom_microagents_content.append(content)
|
||||||
|
|
||||||
|
return custom_microagents_content
|
||||||
|
|
||||||
def run_action(self, action: Action) -> Observation:
|
def run_action(self, action: Action) -> Observation:
|
||||||
"""Run an action and return the resulting observation.
|
"""Run an action and return the resulting observation.
|
||||||
If the action is not runnable in any runtime, a NullObservation is returned.
|
If the action is not runnable in any runtime, a NullObservation is returned.
|
||||||
|
|||||||
@ -32,6 +32,8 @@ async def oh_action(connection_id: str, data: dict):
|
|||||||
latest_event_id = int(data.pop('latest_event_id', -1))
|
latest_event_id = int(data.pop('latest_event_id', -1))
|
||||||
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
|
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
|
||||||
session_init_data = SessionInitData(**kwargs)
|
session_init_data = SessionInitData(**kwargs)
|
||||||
|
session_init_data.github_token = github_token
|
||||||
|
session_init_data.selected_repository = data.get('selected_repository', None)
|
||||||
await init_connection(
|
await init_connection(
|
||||||
connection_id, token, github_token, session_init_data, latest_event_id
|
connection_id, token, github_token, session_init_data, latest_event_id
|
||||||
)
|
)
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from openhands.controller.state.state import State
|
|||||||
from openhands.core.config import AgentConfig, AppConfig, LLMConfig
|
from openhands.core.config import AgentConfig, AppConfig, LLMConfig
|
||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.core.schema.agent import AgentState
|
from openhands.core.schema.agent import AgentState
|
||||||
from openhands.events.action.agent import ChangeAgentStateAction
|
from openhands.events.action import ChangeAgentStateAction
|
||||||
from openhands.events.event import EventSource
|
from openhands.events.event import EventSource
|
||||||
from openhands.events.stream import EventStream
|
from openhands.events.stream import EventStream
|
||||||
from openhands.runtime import get_runtime_cls
|
from openhands.runtime import get_runtime_cls
|
||||||
@ -60,6 +60,8 @@ class AgentSession:
|
|||||||
max_budget_per_task: float | None = None,
|
max_budget_per_task: float | None = None,
|
||||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||||
agent_configs: dict[str, AgentConfig] | None = None,
|
agent_configs: dict[str, AgentConfig] | None = None,
|
||||||
|
github_token: str | None = None,
|
||||||
|
selected_repository: str | None = None,
|
||||||
):
|
):
|
||||||
"""Starts the Agent session
|
"""Starts the Agent session
|
||||||
Parameters:
|
Parameters:
|
||||||
@ -86,6 +88,8 @@ class AgentSession:
|
|||||||
max_budget_per_task,
|
max_budget_per_task,
|
||||||
agent_to_llm_config,
|
agent_to_llm_config,
|
||||||
agent_configs,
|
agent_configs,
|
||||||
|
github_token,
|
||||||
|
selected_repository,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _start_thread(self, *args):
|
def _start_thread(self, *args):
|
||||||
@ -104,13 +108,18 @@ class AgentSession:
|
|||||||
max_budget_per_task: float | None = None,
|
max_budget_per_task: float | None = None,
|
||||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||||
agent_configs: dict[str, AgentConfig] | None = None,
|
agent_configs: dict[str, AgentConfig] | None = None,
|
||||||
|
github_token: str | None = None,
|
||||||
|
selected_repository: str | None = None,
|
||||||
):
|
):
|
||||||
self._create_security_analyzer(config.security.security_analyzer)
|
self._create_security_analyzer(config.security.security_analyzer)
|
||||||
await self._create_runtime(
|
await self._create_runtime(
|
||||||
runtime_name=runtime_name,
|
runtime_name=runtime_name,
|
||||||
config=config,
|
config=config,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
github_token=github_token,
|
||||||
|
selected_repository=selected_repository,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._create_controller(
|
self._create_controller(
|
||||||
agent,
|
agent,
|
||||||
config.security.confirmation_mode,
|
config.security.confirmation_mode,
|
||||||
@ -165,6 +174,8 @@ class AgentSession:
|
|||||||
runtime_name: str,
|
runtime_name: str,
|
||||||
config: AppConfig,
|
config: AppConfig,
|
||||||
agent: Agent,
|
agent: Agent,
|
||||||
|
github_token: str | None = None,
|
||||||
|
selected_repository: str | None = None,
|
||||||
):
|
):
|
||||||
"""Creates a runtime instance
|
"""Creates a runtime instance
|
||||||
|
|
||||||
@ -199,6 +210,12 @@ class AgentSession:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.runtime is not None:
|
if self.runtime is not None:
|
||||||
|
self.runtime.clone_repo(github_token, selected_repository)
|
||||||
|
if agent.prompt_manager:
|
||||||
|
agent.prompt_manager.load_microagent_files(
|
||||||
|
self.runtime.get_custom_microagents(selected_repository)
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||||
)
|
)
|
||||||
|
|||||||
@ -72,7 +72,6 @@ class Session:
|
|||||||
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
|
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
|
||||||
max_iterations = session_init_data.max_iterations or self.config.max_iterations
|
max_iterations = session_init_data.max_iterations or self.config.max_iterations
|
||||||
# override default LLM config
|
# override default LLM config
|
||||||
|
|
||||||
|
|
||||||
default_llm_config = self.config.get_llm_config()
|
default_llm_config = self.config.get_llm_config()
|
||||||
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
|
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
|
||||||
@ -94,6 +93,8 @@ class Session:
|
|||||||
max_budget_per_task=self.config.max_budget_per_task,
|
max_budget_per_task=self.config.max_budget_per_task,
|
||||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||||
agent_configs=self.config.get_agent_configs(),
|
agent_configs=self.config.get_agent_configs(),
|
||||||
|
github_token=session_init_data.github_token,
|
||||||
|
selected_repository=session_init_data.selected_repository,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f'Error creating controller: {e}')
|
logger.exception(f'Error creating controller: {e}')
|
||||||
|
|||||||
@ -16,3 +16,5 @@ class SessionInitData:
|
|||||||
llm_model: str | None = None
|
llm_model: str | None = None
|
||||||
llm_api_key: str | None = None
|
llm_api_key: str | None = None
|
||||||
llm_base_url: str | None = None
|
llm_base_url: str | None = None
|
||||||
|
github_token: str | None = None
|
||||||
|
selected_repository: str | None = None
|
||||||
|
|||||||
@ -11,14 +11,20 @@ class MicroAgentMetadata(pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class MicroAgent:
|
class MicroAgent:
|
||||||
def __init__(self, path: str):
|
def __init__(self, path: str | None = None, content: str | None = None):
|
||||||
self.path = path
|
if path and not content:
|
||||||
if not os.path.exists(path):
|
self.path = path
|
||||||
raise FileNotFoundError(f'Micro agent file {path} is not found')
|
if not os.path.exists(path):
|
||||||
with open(path, 'r') as file:
|
raise FileNotFoundError(f'Micro agent file {path} is not found')
|
||||||
self._loaded = frontmatter.load(file)
|
with open(path, 'r') as file:
|
||||||
self._content = self._loaded.content
|
self._loaded = frontmatter.load(file)
|
||||||
self._metadata = MicroAgentMetadata(**self._loaded.metadata)
|
self._content = self._loaded.content
|
||||||
|
self._metadata = MicroAgentMetadata(**self._loaded.metadata)
|
||||||
|
elif content and not path:
|
||||||
|
self._metadata, self._content = frontmatter.parse(content)
|
||||||
|
self._metadata = MicroAgentMetadata(**self._metadata)
|
||||||
|
else:
|
||||||
|
raise Exception('You must pass either path or file content, but not both.')
|
||||||
|
|
||||||
def get_trigger(self, message: str) -> str | None:
|
def get_trigger(self, message: str) -> str | None:
|
||||||
message = message.lower()
|
message = message.lower()
|
||||||
|
|||||||
@ -42,13 +42,18 @@ class PromptManager:
|
|||||||
if f.endswith('.md')
|
if f.endswith('.md')
|
||||||
]
|
]
|
||||||
for microagent_file in microagent_files:
|
for microagent_file in microagent_files:
|
||||||
microagent = MicroAgent(microagent_file)
|
microagent = MicroAgent(path=microagent_file)
|
||||||
if (
|
if (
|
||||||
disabled_microagents is None
|
disabled_microagents is None
|
||||||
or microagent.name not in disabled_microagents
|
or microagent.name not in disabled_microagents
|
||||||
):
|
):
|
||||||
self.microagents[microagent.name] = microagent
|
self.microagents[microagent.name] = microagent
|
||||||
|
|
||||||
|
def load_microagent_files(self, microagent_files: list[str]):
|
||||||
|
for microagent_file in microagent_files:
|
||||||
|
microagent = MicroAgent(content=microagent_file)
|
||||||
|
self.microagents[microagent.name] = microagent
|
||||||
|
|
||||||
def _load_template(self, template_name: str) -> Template:
|
def _load_template(self, template_name: str) -> Template:
|
||||||
if self.prompt_dir is None:
|
if self.prompt_dir is None:
|
||||||
raise ValueError('Prompt directory is not set')
|
raise ValueError('Prompt directory is not set')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user