Pass the LLM to the delegation action so that the sub-agent shares the parent's LLM instance and costs are accumulated together

This commit is contained in:
Xingyao Wang
2024-07-15 13:34:21 -05:00
parent 47e26cd753
commit 81034c486e
5 changed files with 63 additions and 39 deletions

View File

@@ -206,7 +206,12 @@ class CodeActAgent(Agent):
],
temperature=0.0,
)
return self.action_parser.parse(response)
action = self.action_parser.parse(response)
# post-processing for agent delegation to share the same llm instance
if isinstance(action, AgentDelegateAction):
action.llm = self.llm
return action
def _get_messages(self, state: State) -> list[dict[str, str]]:
messages = [

View File

@@ -38,7 +38,9 @@ class DelegatorAgent(Agent):
self.current_delegate = 'study'
task = state.get_current_user_intent()
return AgentDelegateAction(
agent='StudyRepoForTaskAgent', inputs={'task': task}
agent='StudyRepoForTaskAgent',
inputs={'task': task},
llm=self.llm,
)
# last observation in history should be from the delegate
@@ -56,6 +58,7 @@ class DelegatorAgent(Agent):
'task': goal,
'summary': last_observation.outputs['summary'],
},
llm=self.llm,
)
elif self.current_delegate == 'coder':
self.current_delegate = 'verifier'
@@ -64,6 +67,7 @@ class DelegatorAgent(Agent):
inputs={
'task': goal,
},
llm=self.llm,
)
elif self.current_delegate == 'verifier':
if (
@@ -79,6 +83,7 @@ class DelegatorAgent(Agent):
'task': goal,
'summary': last_observation.outputs['summary'],
},
llm=self.llm,
)
else:
raise Exception('Invalid delegate state')

View File

@@ -248,8 +248,16 @@ class AgentController:
async def start_delegate(self, action: AgentDelegateAction):
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
llm_config = config.get_llm_config_from_agent(action.agent)
llm = LLM(llm_config=llm_config)
if action.llm is not None:
# share the same llm instance as parent
# so the cost is accumulated together
llm = action.llm
else:
llm_config = config.get_llm_config_from_agent(action.agent)
logger.warn(
f'Using default LLM for agent {action.agent}. You should specify the LLM in the delegate action. Current config: {llm_config}'
)
llm = LLM(llm_config=llm_config)
delegate_agent = agent_cls(llm=llm)
state = State(
inputs=action.inputs or {},
@@ -285,6 +293,42 @@ class AgentController:
await asyncio.sleep(1)
return
logger.info(
f'{self.agent.name} LEVEL {self.state.delegate_level} STEP {self.state.iteration}',
extra={'msg_type': 'STEP'},
)
if self.state.iteration >= self.state.max_iterations:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
await self.report_error(
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
elif self.max_budget_per_task is not None:
current_cost = self.state.metrics.accumulated_cost
if current_cost > self.max_budget_per_task:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
self.update_state_before_step()
if self.delegate is not None:
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
assert self.delegate != self
@@ -328,41 +372,6 @@ class AgentController:
self.event_stream.add_event(obs, EventSource.AGENT)
return
logger.info(
f'{self.agent.name} LEVEL {self.state.delegate_level} STEP {self.state.iteration}',
extra={'msg_type': 'STEP'},
)
if self.state.iteration >= self.state.max_iterations:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
await self.report_error(
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
elif self.max_budget_per_task is not None:
current_cost = self.state.metrics.accumulated_cost
if current_cost > self.max_budget_per_task:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
self.update_state_before_step()
action: Action = NullAction()
try:
action = self.agent.step(self.state)

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from opendevin.core.schema import ActionType
from opendevin.llm.llm import LLM
from .action import Action
@@ -65,6 +66,7 @@ class AgentDelegateAction(Action):
agent: str
inputs: dict
thought: str = ''
llm: LLM | None = None
action: str = ActionType.DELEGATE
@property

View File

@@ -60,6 +60,9 @@ def event_to_dict(event: 'Event') -> dict:
d['source'] = d['source'].value
props.pop(key, None)
if 'action' in d:
if 'llm' in props:
# pop LLM used in AgentDelegateAction (not serializable and not needed)
props.pop('llm')
d['args'] = props
elif 'observation' in d:
d['content'] = props.pop('content', '')