mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Feat] Custom MicroAgents. (#4983)
Co-authored-by: diwu-sf <di.wu@shadowfaxdata.com>
This commit is contained in:
parent
cf157c86b3
commit
2b06e4e5d0
@ -38,6 +38,7 @@ interface WsClientProviderProps {
|
||||
enabled: boolean;
|
||||
token: string | null;
|
||||
ghToken: string | null;
|
||||
selectedRepository: string | null;
|
||||
settings: Settings | null;
|
||||
}
|
||||
|
||||
@ -45,12 +46,14 @@ export function WsClientProvider({
|
||||
enabled,
|
||||
token,
|
||||
ghToken,
|
||||
selectedRepository,
|
||||
settings,
|
||||
children,
|
||||
}: React.PropsWithChildren<WsClientProviderProps>) {
|
||||
const sioRef = React.useRef<Socket | null>(null);
|
||||
const tokenRef = React.useRef<string | null>(token);
|
||||
const ghTokenRef = React.useRef<string | null>(ghToken);
|
||||
const selectedRepositoryRef = React.useRef<string | null>(selectedRepository);
|
||||
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
|
||||
null,
|
||||
);
|
||||
@ -81,6 +84,9 @@ export function WsClientProvider({
|
||||
if (ghToken) {
|
||||
initEvent.github_token = ghToken;
|
||||
}
|
||||
if (selectedRepository) {
|
||||
initEvent.selected_repository = selectedRepository;
|
||||
}
|
||||
const lastEvent = lastEventRef.current;
|
||||
if (lastEvent) {
|
||||
initEvent.latest_event_id = lastEvent.id;
|
||||
@ -158,6 +164,7 @@ export function WsClientProvider({
|
||||
sioRef.current = sio;
|
||||
tokenRef.current = token;
|
||||
ghTokenRef.current = ghToken;
|
||||
selectedRepositoryRef.current = selectedRepository;
|
||||
|
||||
return () => {
|
||||
sio.off("connect", handleConnect);
|
||||
@ -166,7 +173,7 @@ export function WsClientProvider({
|
||||
sio.off("connect_failed", handleError);
|
||||
sio.off("disconnect", handleDisconnect);
|
||||
};
|
||||
}, [enabled, token, ghToken]);
|
||||
}, [enabled, token, ghToken, selectedRepository]);
|
||||
|
||||
// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
|
||||
// before actually disconnecting the socket and cancel the operation if the component gets remounted.
|
||||
|
||||
@ -6,7 +6,6 @@ import {
|
||||
WsClientProviderStatus,
|
||||
} from "#/context/ws-client-provider";
|
||||
import { createChatMessage } from "#/services/chat-service";
|
||||
import { getCloneRepoCommand } from "#/services/terminal-service";
|
||||
import { setCurrentAgentState } from "#/state/agent-slice";
|
||||
import { addUserMessage } from "#/state/chat-slice";
|
||||
import {
|
||||
@ -37,11 +36,6 @@ export const useWSStatusChange = () => {
|
||||
send(createChatMessage(query, base64Files, timestamp));
|
||||
};
|
||||
|
||||
/**
 * Sends the clone command for `repository` over the socket, then clears the
 * selected repository from the store.
 */
const dispatchCloneRepoCommand = (ghToken: string, repository: string) => {
  const cloneEvent = getCloneRepoCommand(ghToken, repository);
  send(cloneEvent);

  const clearAction = clearSelectedRepository();
  dispatch(clearAction);
};
|
||||
|
||||
const dispatchInitialQuery = (query: string, additionalInfo: string) => {
|
||||
if (additionalInfo) {
|
||||
sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
|
||||
@ -57,8 +51,7 @@ export const useWSStatusChange = () => {
|
||||
let additionalInfo = "";
|
||||
|
||||
if (gitHubToken && selectedRepository) {
|
||||
dispatchCloneRepoCommand(gitHubToken, selectedRepository);
|
||||
additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
|
||||
dispatch(clearSelectedRepository());
|
||||
} else if (importedProjectZip) {
|
||||
// if there's an uploaded project zip, add it to the chat
|
||||
additionalInfo =
|
||||
|
||||
@ -64,6 +64,7 @@ function App() {
|
||||
enabled
|
||||
token={token}
|
||||
ghToken={gitHubToken}
|
||||
selectedRepository={selectedRepository}
|
||||
settings={settings}
|
||||
>
|
||||
<EventHandler>
|
||||
|
||||
@ -10,11 +10,3 @@ export function getGitHubTokenCommand(gitHubToken: string) {
|
||||
const event = getTerminalCommand(command, true);
|
||||
return event;
|
||||
}
|
||||
|
||||
/**
 * Builds the terminal event that clones a GitHub repository, enters the
 * checkout directory and creates the `openhands-workspace` branch.
 *
 * The token is embedded in the HTTPS clone URL. The checkout directory is the
 * second `/`-separated segment of `repository` (the name in `owner/name`).
 *
 * @param gitHubToken - Token placed in the clone URL for authentication.
 * @param repository - Repository in `owner/name` form.
 * @returns Whatever `getTerminalCommand` returns for the command; `true` is
 *   passed as its second argument (presumably "hidden" — confirm against
 *   `getTerminalCommand`).
 */
export function getCloneRepoCommand(gitHubToken: string, repository: string) {
  const [, dirName] = repository.split("/");
  const cloneUrl = `https://${gitHubToken}@github.com/${repository}.git`;
  const steps = [
    `git clone ${cloneUrl} ${dirName}`,
    `cd ${dirName}`,
    "git checkout -b openhands-workspace",
  ];
  return getTerminalCommand(steps.join(" ; "), true);
}
|
||||
|
||||
@ -398,6 +398,9 @@ class CodeActAgent(Agent):
|
||||
- Messages from the same role are combined to prevent consecutive same-role messages
|
||||
- For Anthropic models, specific messages are cached according to their documentation
|
||||
"""
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
messages: list[Message] = [
|
||||
Message(
|
||||
role='system',
|
||||
|
||||
@ -11,6 +11,7 @@ from openhands.core.exceptions import (
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
class Agent(ABC):
|
||||
@ -33,6 +34,7 @@ class Agent(ABC):
|
||||
self.llm = llm
|
||||
self.config = config
|
||||
self._complete = False
|
||||
self.prompt_manager: PromptManager | None = None
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
|
||||
@ -213,6 +213,47 @@ class Runtime(FileEditRuntimeMixin):
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
def clone_repo(self, github_token: str | None, selected_repository: str | None):
|
||||
if not github_token or not selected_repository:
|
||||
return
|
||||
url = f'https://{github_token}@github.com/{selected_repository}.git'
|
||||
dir_name = selected_repository.split('/')[1]
|
||||
action = CmdRunAction(
|
||||
command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b openhands-workspace'
|
||||
)
|
||||
self.log('info', 'Cloning repo: {selected_repository}')
|
||||
self.run_action(action)
|
||||
|
||||
def get_custom_microagents(self, selected_repository: str | None) -> list[str]:
|
||||
custom_microagents_content = []
|
||||
custom_microagents_dir = Path('.openhands') / 'microagents'
|
||||
|
||||
dir_name = str(custom_microagents_dir)
|
||||
if selected_repository:
|
||||
dir_name = str(
|
||||
Path(selected_repository.split('/')[1]) / custom_microagents_dir
|
||||
)
|
||||
oh_instructions_header = '---\nname: openhands_instructions\nagent: CodeActAgent\ntriggers:\n- ""\n---\n'
|
||||
obs = self.read(FileReadAction(path='.openhands_instructions'))
|
||||
if isinstance(obs, ErrorObservation):
|
||||
self.log('error', 'Failed to read openhands_instructions')
|
||||
else:
|
||||
openhands_instructions = oh_instructions_header + obs.content
|
||||
self.log('info', f'openhands_instructions: {openhands_instructions}')
|
||||
custom_microagents_content.append(openhands_instructions)
|
||||
|
||||
files = self.list_files(dir_name)
|
||||
|
||||
self.log('info', f'Found {len(files)} custom microagents.')
|
||||
|
||||
for fname in files:
|
||||
content = self.read(
|
||||
FileReadAction(path=str(custom_microagents_dir / fname))
|
||||
).content
|
||||
custom_microagents_content.append(content)
|
||||
|
||||
return custom_microagents_content
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
"""Run an action and return the resulting observation.
|
||||
If the action is not runnable in any runtime, a NullObservation is returned.
|
||||
|
||||
@ -32,6 +32,8 @@ async def oh_action(connection_id: str, data: dict):
|
||||
latest_event_id = int(data.pop('latest_event_id', -1))
|
||||
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
|
||||
session_init_data = SessionInitData(**kwargs)
|
||||
session_init_data.github_token = github_token
|
||||
session_init_data.selected_repository = data.get('selected_repository', None)
|
||||
await init_connection(
|
||||
connection_id, token, github_token, session_init_data, latest_event_id
|
||||
)
|
||||
|
||||
@ -7,7 +7,7 @@ from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig, AppConfig, LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action.agent import ChangeAgentStateAction
|
||||
from openhands.events.action import ChangeAgentStateAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.runtime import get_runtime_cls
|
||||
@ -60,6 +60,8 @@ class AgentSession:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
github_token: str | None = None,
|
||||
selected_repository: str | None = None,
|
||||
):
|
||||
"""Starts the Agent session
|
||||
Parameters:
|
||||
@ -86,6 +88,8 @@ class AgentSession:
|
||||
max_budget_per_task,
|
||||
agent_to_llm_config,
|
||||
agent_configs,
|
||||
github_token,
|
||||
selected_repository,
|
||||
)
|
||||
|
||||
def _start_thread(self, *args):
|
||||
@ -104,13 +108,18 @@ class AgentSession:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
github_token: str | None = None,
|
||||
selected_repository: str | None = None,
|
||||
):
|
||||
self._create_security_analyzer(config.security.security_analyzer)
|
||||
await self._create_runtime(
|
||||
runtime_name=runtime_name,
|
||||
config=config,
|
||||
agent=agent,
|
||||
github_token=github_token,
|
||||
selected_repository=selected_repository,
|
||||
)
|
||||
|
||||
self._create_controller(
|
||||
agent,
|
||||
config.security.confirmation_mode,
|
||||
@ -165,6 +174,8 @@ class AgentSession:
|
||||
runtime_name: str,
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
github_token: str | None = None,
|
||||
selected_repository: str | None = None,
|
||||
):
|
||||
"""Creates a runtime instance
|
||||
|
||||
@ -199,6 +210,12 @@ class AgentSession:
|
||||
return
|
||||
|
||||
if self.runtime is not None:
|
||||
self.runtime.clone_repo(github_token, selected_repository)
|
||||
if agent.prompt_manager:
|
||||
agent.prompt_manager.load_microagent_files(
|
||||
self.runtime.get_custom_microagents(selected_repository)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||
)
|
||||
|
||||
@ -72,7 +72,6 @@ class Session:
|
||||
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
|
||||
max_iterations = session_init_data.max_iterations or self.config.max_iterations
|
||||
# override default LLM config
|
||||
|
||||
|
||||
default_llm_config = self.config.get_llm_config()
|
||||
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
|
||||
@ -94,6 +93,8 @@ class Session:
|
||||
max_budget_per_task=self.config.max_budget_per_task,
|
||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
github_token=session_init_data.github_token,
|
||||
selected_repository=session_init_data.selected_repository,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
|
||||
@ -16,3 +16,5 @@ class SessionInitData:
|
||||
llm_model: str | None = None
|
||||
llm_api_key: str | None = None
|
||||
llm_base_url: str | None = None
|
||||
github_token: str | None = None
|
||||
selected_repository: str | None = None
|
||||
|
||||
@ -11,14 +11,20 @@ class MicroAgentMetadata(pydantic.BaseModel):
|
||||
|
||||
|
||||
class MicroAgent:
|
||||
def __init__(self, path: str | None = None, content: str | None = None):
    """Load a microagent from a file on disk or from in-memory text.

    Exactly one of ``path`` and ``content`` must be given (empty strings
    count as not given). In both cases the frontmatter block is parsed into
    ``MicroAgentMetadata`` and the remaining text is kept as the body.

    Args:
        path: Location of a microagent file to read.
        content: Full text of a microagent, frontmatter included.

    Raises:
        FileNotFoundError: If ``path`` is given but does not exist.
        ValueError: If both or neither of ``path`` and ``content`` are given.
    """
    if path and not content:
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(f'Micro agent file {path} is not found')
        with open(path, 'r') as file:
            self._loaded = frontmatter.load(file)
            self._content = self._loaded.content
            self._metadata = MicroAgentMetadata(**self._loaded.metadata)
    elif content and not path:
        # Define the attribute on this branch too, so reading ``self.path``
        # on a content-backed microagent does not raise AttributeError.
        self.path = None
        raw_metadata, self._content = frontmatter.parse(content)
        # Only ever store the validated model in ``_metadata``.
        self._metadata = MicroAgentMetadata(**raw_metadata)
    else:
        # ValueError is still caught by ``except Exception`` handlers.
        raise ValueError('You must pass either path or file content, but not both.')
|
||||
|
||||
def get_trigger(self, message: str) -> str | None:
|
||||
message = message.lower()
|
||||
|
||||
@ -42,13 +42,18 @@ class PromptManager:
|
||||
if f.endswith('.md')
|
||||
]
|
||||
for microagent_file in microagent_files:
|
||||
microagent = MicroAgent(microagent_file)
|
||||
microagent = MicroAgent(path=microagent_file)
|
||||
if (
|
||||
disabled_microagents is None
|
||||
or microagent.name not in disabled_microagents
|
||||
):
|
||||
self.microagents[microagent.name] = microagent
|
||||
|
||||
def load_microagent_files(self, microagent_files: list[str]):
    """Register microagents built from in-memory file contents.

    Each item of ``microagent_files`` is the text of one microagent (not a
    path). It is parsed with ``MicroAgent(content=...)`` and stored in
    ``self.microagents`` under the parsed name; a later entry with the same
    name replaces an earlier one.
    """
    for raw_text in microagent_files:
        parsed_agent = MicroAgent(content=raw_text)
        agent_name = parsed_agent.name
        self.microagents[agent_name] = parsed_agent
|
||||
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
if self.prompt_dir is None:
|
||||
raise ValueError('Prompt directory is not set')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user