[Feat] Custom MicroAgents. (#4983)

Co-authored-by: diwu-sf <di.wu@shadowfaxdata.com>
This commit is contained in:
Raj Maheshwari 2024-12-07 03:41:06 +05:30 committed by GitHub
parent cf157c86b3
commit 2b06e4e5d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 100 additions and 28 deletions

View File

@ -38,6 +38,7 @@ interface WsClientProviderProps {
enabled: boolean; enabled: boolean;
token: string | null; token: string | null;
ghToken: string | null; ghToken: string | null;
selectedRepository: string | null;
settings: Settings | null; settings: Settings | null;
} }
@ -45,12 +46,14 @@ export function WsClientProvider({
enabled, enabled,
token, token,
ghToken, ghToken,
selectedRepository,
settings, settings,
children, children,
}: React.PropsWithChildren<WsClientProviderProps>) { }: React.PropsWithChildren<WsClientProviderProps>) {
const sioRef = React.useRef<Socket | null>(null); const sioRef = React.useRef<Socket | null>(null);
const tokenRef = React.useRef<string | null>(token); const tokenRef = React.useRef<string | null>(token);
const ghTokenRef = React.useRef<string | null>(ghToken); const ghTokenRef = React.useRef<string | null>(ghToken);
const selectedRepositoryRef = React.useRef<string | null>(selectedRepository);
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>( const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
null, null,
); );
@ -81,6 +84,9 @@ export function WsClientProvider({
if (ghToken) { if (ghToken) {
initEvent.github_token = ghToken; initEvent.github_token = ghToken;
} }
if (selectedRepository) {
initEvent.selected_repository = selectedRepository;
}
const lastEvent = lastEventRef.current; const lastEvent = lastEventRef.current;
if (lastEvent) { if (lastEvent) {
initEvent.latest_event_id = lastEvent.id; initEvent.latest_event_id = lastEvent.id;
@ -158,6 +164,7 @@ export function WsClientProvider({
sioRef.current = sio; sioRef.current = sio;
tokenRef.current = token; tokenRef.current = token;
ghTokenRef.current = ghToken; ghTokenRef.current = ghToken;
selectedRepositoryRef.current = selectedRepository;
return () => { return () => {
sio.off("connect", handleConnect); sio.off("connect", handleConnect);
@ -166,7 +173,7 @@ export function WsClientProvider({
sio.off("connect_failed", handleError); sio.off("connect_failed", handleError);
sio.off("disconnect", handleDisconnect); sio.off("disconnect", handleDisconnect);
}; };
}, [enabled, token, ghToken]); }, [enabled, token, ghToken, selectedRepository]);
// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor // Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
// before actually disconnecting the socket and cancel the operation if the component gets remounted. // before actually disconnecting the socket and cancel the operation if the component gets remounted.

View File

@ -6,7 +6,6 @@ import {
WsClientProviderStatus, WsClientProviderStatus,
} from "#/context/ws-client-provider"; } from "#/context/ws-client-provider";
import { createChatMessage } from "#/services/chat-service"; import { createChatMessage } from "#/services/chat-service";
import { getCloneRepoCommand } from "#/services/terminal-service";
import { setCurrentAgentState } from "#/state/agent-slice"; import { setCurrentAgentState } from "#/state/agent-slice";
import { addUserMessage } from "#/state/chat-slice"; import { addUserMessage } from "#/state/chat-slice";
import { import {
@ -37,11 +36,6 @@ export const useWSStatusChange = () => {
send(createChatMessage(query, base64Files, timestamp)); send(createChatMessage(query, base64Files, timestamp));
}; };
const dispatchCloneRepoCommand = (ghToken: string, repository: string) => {
send(getCloneRepoCommand(ghToken, repository));
dispatch(clearSelectedRepository());
};
const dispatchInitialQuery = (query: string, additionalInfo: string) => { const dispatchInitialQuery = (query: string, additionalInfo: string) => {
if (additionalInfo) { if (additionalInfo) {
sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files); sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
@ -57,8 +51,7 @@ export const useWSStatusChange = () => {
let additionalInfo = ""; let additionalInfo = "";
if (gitHubToken && selectedRepository) { if (gitHubToken && selectedRepository) {
dispatchCloneRepoCommand(gitHubToken, selectedRepository); dispatch(clearSelectedRepository());
additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
} else if (importedProjectZip) { } else if (importedProjectZip) {
// if there's an uploaded project zip, add it to the chat // if there's an uploaded project zip, add it to the chat
additionalInfo = additionalInfo =

View File

@ -64,6 +64,7 @@ function App() {
enabled enabled
token={token} token={token}
ghToken={gitHubToken} ghToken={gitHubToken}
selectedRepository={selectedRepository}
settings={settings} settings={settings}
> >
<EventHandler> <EventHandler>

View File

@ -10,11 +10,3 @@ export function getGitHubTokenCommand(gitHubToken: string) {
const event = getTerminalCommand(command, true); const event = getTerminalCommand(command, true);
return event; return event;
} }
export function getCloneRepoCommand(gitHubToken: string, repository: string) {
const url = `https://${gitHubToken}@github.com/${repository}.git`;
const dirName = repository.split("/")[1];
const command = `git clone ${url} ${dirName} ; cd ${dirName} ; git checkout -b openhands-workspace`;
const event = getTerminalCommand(command, true);
return event;
}

View File

@ -398,6 +398,9 @@ class CodeActAgent(Agent):
- Messages from the same role are combined to prevent consecutive same-role messages - Messages from the same role are combined to prevent consecutive same-role messages
- For Anthropic models, specific messages are cached according to their documentation - For Anthropic models, specific messages are cached according to their documentation
""" """
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')
messages: list[Message] = [ messages: list[Message] = [
Message( Message(
role='system', role='system',

View File

@ -11,6 +11,7 @@ from openhands.core.exceptions import (
) )
from openhands.llm.llm import LLM from openhands.llm.llm import LLM
from openhands.runtime.plugins import PluginRequirement from openhands.runtime.plugins import PluginRequirement
from openhands.utils.prompt import PromptManager
class Agent(ABC): class Agent(ABC):
@ -33,6 +34,7 @@ class Agent(ABC):
self.llm = llm self.llm = llm
self.config = config self.config = config
self._complete = False self._complete = False
self.prompt_manager: PromptManager | None = None
@property @property
def complete(self) -> bool: def complete(self) -> bool:

View File

@ -213,6 +213,47 @@ class Runtime(FileEditRuntimeMixin):
source = event.source if event.source else EventSource.AGENT source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type] self.event_stream.add_event(observation, source) # type: ignore[arg-type]
def clone_repo(self, github_token: str | None, selected_repository: str | None):
if not github_token or not selected_repository:
return
url = f'https://{github_token}@github.com/{selected_repository}.git'
dir_name = selected_repository.split('/')[1]
action = CmdRunAction(
command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b openhands-workspace'
)
self.log('info', 'Cloning repo: {selected_repository}')
self.run_action(action)
def get_custom_microagents(self, selected_repository: str | None) -> list[str]:
custom_microagents_content = []
custom_microagents_dir = Path('.openhands') / 'microagents'
dir_name = str(custom_microagents_dir)
if selected_repository:
dir_name = str(
Path(selected_repository.split('/')[1]) / custom_microagents_dir
)
oh_instructions_header = '---\nname: openhands_instructions\nagent: CodeActAgent\ntriggers:\n- ""\n---\n'
obs = self.read(FileReadAction(path='.openhands_instructions'))
if isinstance(obs, ErrorObservation):
self.log('error', 'Failed to read openhands_instructions')
else:
openhands_instructions = oh_instructions_header + obs.content
self.log('info', f'openhands_instructions: {openhands_instructions}')
custom_microagents_content.append(openhands_instructions)
files = self.list_files(dir_name)
self.log('info', f'Found {len(files)} custom microagents.')
for fname in files:
content = self.read(
FileReadAction(path=str(custom_microagents_dir / fname))
).content
custom_microagents_content.append(content)
return custom_microagents_content
def run_action(self, action: Action) -> Observation: def run_action(self, action: Action) -> Observation:
"""Run an action and return the resulting observation. """Run an action and return the resulting observation.
If the action is not runnable in any runtime, a NullObservation is returned. If the action is not runnable in any runtime, a NullObservation is returned.

View File

@ -32,6 +32,8 @@ async def oh_action(connection_id: str, data: dict):
latest_event_id = int(data.pop('latest_event_id', -1)) latest_event_id = int(data.pop('latest_event_id', -1))
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()} kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
session_init_data = SessionInitData(**kwargs) session_init_data = SessionInitData(**kwargs)
session_init_data.github_token = github_token
session_init_data.selected_repository = data.get('selected_repository', None)
await init_connection( await init_connection(
connection_id, token, github_token, session_init_data, latest_event_id connection_id, token, github_token, session_init_data, latest_event_id
) )

View File

@ -7,7 +7,7 @@ from openhands.controller.state.state import State
from openhands.core.config import AgentConfig, AppConfig, LLMConfig from openhands.core.config import AgentConfig, AppConfig, LLMConfig
from openhands.core.logger import openhands_logger as logger from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState from openhands.core.schema.agent import AgentState
from openhands.events.action.agent import ChangeAgentStateAction from openhands.events.action import ChangeAgentStateAction
from openhands.events.event import EventSource from openhands.events.event import EventSource
from openhands.events.stream import EventStream from openhands.events.stream import EventStream
from openhands.runtime import get_runtime_cls from openhands.runtime import get_runtime_cls
@ -60,6 +60,8 @@ class AgentSession:
max_budget_per_task: float | None = None, max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None, agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
selected_repository: str | None = None,
): ):
"""Starts the Agent session """Starts the Agent session
Parameters: Parameters:
@ -86,6 +88,8 @@ class AgentSession:
max_budget_per_task, max_budget_per_task,
agent_to_llm_config, agent_to_llm_config,
agent_configs, agent_configs,
github_token,
selected_repository,
) )
def _start_thread(self, *args): def _start_thread(self, *args):
@ -104,13 +108,18 @@ class AgentSession:
max_budget_per_task: float | None = None, max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None, agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
selected_repository: str | None = None,
): ):
self._create_security_analyzer(config.security.security_analyzer) self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime( await self._create_runtime(
runtime_name=runtime_name, runtime_name=runtime_name,
config=config, config=config,
agent=agent, agent=agent,
github_token=github_token,
selected_repository=selected_repository,
) )
self._create_controller( self._create_controller(
agent, agent,
config.security.confirmation_mode, config.security.confirmation_mode,
@ -165,6 +174,8 @@ class AgentSession:
runtime_name: str, runtime_name: str,
config: AppConfig, config: AppConfig,
agent: Agent, agent: Agent,
github_token: str | None = None,
selected_repository: str | None = None,
): ):
"""Creates a runtime instance """Creates a runtime instance
@ -199,6 +210,12 @@ class AgentSession:
return return
if self.runtime is not None: if self.runtime is not None:
self.runtime.clone_repo(github_token, selected_repository)
if agent.prompt_manager:
agent.prompt_manager.load_microagent_files(
self.runtime.get_custom_microagents(selected_repository)
)
logger.debug( logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}' f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
) )

View File

@ -72,7 +72,6 @@ class Session:
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
max_iterations = session_init_data.max_iterations or self.config.max_iterations max_iterations = session_init_data.max_iterations or self.config.max_iterations
# override default LLM config # override default LLM config
default_llm_config = self.config.get_llm_config() default_llm_config = self.config.get_llm_config()
default_llm_config.model = session_init_data.llm_model or default_llm_config.model default_llm_config.model = session_init_data.llm_model or default_llm_config.model
@ -94,6 +93,8 @@ class Session:
max_budget_per_task=self.config.max_budget_per_task, max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(), agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(), agent_configs=self.config.get_agent_configs(),
github_token=session_init_data.github_token,
selected_repository=session_init_data.selected_repository,
) )
except Exception as e: except Exception as e:
logger.exception(f'Error creating controller: {e}') logger.exception(f'Error creating controller: {e}')

View File

@ -16,3 +16,5 @@ class SessionInitData:
llm_model: str | None = None llm_model: str | None = None
llm_api_key: str | None = None llm_api_key: str | None = None
llm_base_url: str | None = None llm_base_url: str | None = None
github_token: str | None = None
selected_repository: str | None = None

View File

@ -11,14 +11,20 @@ class MicroAgentMetadata(pydantic.BaseModel):
class MicroAgent: class MicroAgent:
def __init__(self, path: str): def __init__(self, path: str | None = None, content: str | None = None):
self.path = path if path and not content:
if not os.path.exists(path): self.path = path
raise FileNotFoundError(f'Micro agent file {path} is not found') if not os.path.exists(path):
with open(path, 'r') as file: raise FileNotFoundError(f'Micro agent file {path} is not found')
self._loaded = frontmatter.load(file) with open(path, 'r') as file:
self._content = self._loaded.content self._loaded = frontmatter.load(file)
self._metadata = MicroAgentMetadata(**self._loaded.metadata) self._content = self._loaded.content
self._metadata = MicroAgentMetadata(**self._loaded.metadata)
elif content and not path:
self._metadata, self._content = frontmatter.parse(content)
self._metadata = MicroAgentMetadata(**self._metadata)
else:
raise Exception('You must pass either path or file content, but not both.')
def get_trigger(self, message: str) -> str | None: def get_trigger(self, message: str) -> str | None:
message = message.lower() message = message.lower()

View File

@ -42,13 +42,18 @@ class PromptManager:
if f.endswith('.md') if f.endswith('.md')
] ]
for microagent_file in microagent_files: for microagent_file in microagent_files:
microagent = MicroAgent(microagent_file) microagent = MicroAgent(path=microagent_file)
if ( if (
disabled_microagents is None disabled_microagents is None
or microagent.name not in disabled_microagents or microagent.name not in disabled_microagents
): ):
self.microagents[microagent.name] = microagent self.microagents[microagent.name] = microagent
def load_microagent_files(self, microagent_files: list[str]):
for microagent_file in microagent_files:
microagent = MicroAgent(content=microagent_file)
self.microagents[microagent.name] = microagent
def _load_template(self, template_name: str) -> Template: def _load_template(self, template_name: str) -> Template:
if self.prompt_dir is None: if self.prompt_dir is None:
raise ValueError('Prompt directory is not set') raise ValueError('Prompt directory is not set')