mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
pass llm to delegation action so that sub-agent shares the same llm for cost accum purpose
This commit is contained in:
@@ -206,7 +206,12 @@ class CodeActAgent(Agent):
|
||||
],
|
||||
temperature=0.0,
|
||||
)
|
||||
return self.action_parser.parse(response)
|
||||
|
||||
action = self.action_parser.parse(response)
|
||||
# post-processing for agent delegation to share the same llm instance
|
||||
if isinstance(action, AgentDelegateAction):
|
||||
action.llm = self.llm
|
||||
return action
|
||||
|
||||
def _get_messages(self, state: State) -> list[dict[str, str]]:
|
||||
messages = [
|
||||
|
||||
@@ -38,7 +38,9 @@ class DelegatorAgent(Agent):
|
||||
self.current_delegate = 'study'
|
||||
task = state.get_current_user_intent()
|
||||
return AgentDelegateAction(
|
||||
agent='StudyRepoForTaskAgent', inputs={'task': task}
|
||||
agent='StudyRepoForTaskAgent',
|
||||
inputs={'task': task},
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
# last observation in history should be from the delegate
|
||||
@@ -56,6 +58,7 @@ class DelegatorAgent(Agent):
|
||||
'task': goal,
|
||||
'summary': last_observation.outputs['summary'],
|
||||
},
|
||||
llm=self.llm,
|
||||
)
|
||||
elif self.current_delegate == 'coder':
|
||||
self.current_delegate = 'verifier'
|
||||
@@ -64,6 +67,7 @@ class DelegatorAgent(Agent):
|
||||
inputs={
|
||||
'task': goal,
|
||||
},
|
||||
llm=self.llm,
|
||||
)
|
||||
elif self.current_delegate == 'verifier':
|
||||
if (
|
||||
@@ -79,6 +83,7 @@ class DelegatorAgent(Agent):
|
||||
'task': goal,
|
||||
'summary': last_observation.outputs['summary'],
|
||||
},
|
||||
llm=self.llm,
|
||||
)
|
||||
else:
|
||||
raise Exception('Invalid delegate state')
|
||||
|
||||
@@ -248,8 +248,16 @@ class AgentController:
|
||||
|
||||
async def start_delegate(self, action: AgentDelegateAction):
|
||||
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
|
||||
llm_config = config.get_llm_config_from_agent(action.agent)
|
||||
llm = LLM(llm_config=llm_config)
|
||||
if action.llm is not None:
|
||||
# share the same llm instance as parent
|
||||
# so the cost is accumulated together
|
||||
llm = action.llm
|
||||
else:
|
||||
llm_config = config.get_llm_config_from_agent(action.agent)
|
||||
logger.warn(
|
||||
f'Using default LLM for agent {action.agent}. You should specify the LLM in the delegate action. Current config: {llm_config}'
|
||||
)
|
||||
llm = LLM(llm_config=llm_config)
|
||||
delegate_agent = agent_cls(llm=llm)
|
||||
state = State(
|
||||
inputs=action.inputs or {},
|
||||
@@ -285,6 +293,42 @@ class AgentController:
|
||||
await asyncio.sleep(1)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f'{self.agent.name} LEVEL {self.state.delegate_level} STEP {self.state.iteration}',
|
||||
extra={'msg_type': 'STEP'},
|
||||
)
|
||||
|
||||
if self.state.iteration >= self.state.max_iterations:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
await self.report_error(
|
||||
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
elif self.max_budget_per_task is not None:
|
||||
current_cost = self.state.metrics.accumulated_cost
|
||||
if current_cost > self.max_budget_per_task:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
await self.report_error(
|
||||
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
|
||||
self.update_state_before_step()
|
||||
|
||||
if self.delegate is not None:
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
|
||||
assert self.delegate != self
|
||||
@@ -328,41 +372,6 @@ class AgentController:
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f'{self.agent.name} LEVEL {self.state.delegate_level} STEP {self.state.iteration}',
|
||||
extra={'msg_type': 'STEP'},
|
||||
)
|
||||
|
||||
if self.state.iteration >= self.state.max_iterations:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
await self.report_error(
|
||||
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
elif self.max_budget_per_task is not None:
|
||||
current_cost = self.state.metrics.accumulated_cost
|
||||
if current_cost > self.max_budget_per_task:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
await self.report_error(
|
||||
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
|
||||
self.update_state_before_step()
|
||||
action: Action = NullAction()
|
||||
try:
|
||||
action = self.agent.step(self.state)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from opendevin.core.schema import ActionType
|
||||
from opendevin.llm.llm import LLM
|
||||
|
||||
from .action import Action
|
||||
|
||||
@@ -65,6 +66,7 @@ class AgentDelegateAction(Action):
|
||||
agent: str
|
||||
inputs: dict
|
||||
thought: str = ''
|
||||
llm: LLM | None = None
|
||||
action: str = ActionType.DELEGATE
|
||||
|
||||
@property
|
||||
|
||||
@@ -60,6 +60,9 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
d['source'] = d['source'].value
|
||||
props.pop(key, None)
|
||||
if 'action' in d:
|
||||
if 'llm' in props:
|
||||
# pop LLM used in AgentDelegateAction (not serializable and not needed)
|
||||
props.pop('llm')
|
||||
d['args'] = props
|
||||
elif 'observation' in d:
|
||||
d['content'] = props.pop('content', '')
|
||||
|
||||
Reference in New Issue
Block a user