mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Refactor LLM config (#2953)
* Add max_message_chars to LLM * Refactor LLM config * Fix tests * Made some functions class functions * Fix regression * Fixed comments
This commit is contained in:
@@ -8,7 +8,6 @@ from agenthub.codeact_agent.prompt import (
|
||||
)
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import config
|
||||
from opendevin.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
@@ -22,6 +21,7 @@ from opendevin.events.observation import (
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from opendevin.events.observation.observation import Observation
|
||||
from opendevin.events.serialization.event import truncate_content
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.runtime.plugins import (
|
||||
@@ -34,62 +34,6 @@ from opendevin.runtime.tools import RuntimeTool
|
||||
ENABLE_GITHUB = True
|
||||
|
||||
|
||||
def action_to_str(action: Action) -> str:
|
||||
if isinstance(action, CmdRunAction):
|
||||
return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
|
||||
elif isinstance(action, IPythonRunCellAction):
|
||||
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
|
||||
elif isinstance(action, AgentDelegateAction):
|
||||
return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
|
||||
elif isinstance(action, MessageAction):
|
||||
return action.content
|
||||
return ''
|
||||
|
||||
|
||||
def get_action_message(action: Action) -> dict[str, str] | None:
|
||||
if (
|
||||
isinstance(action, AgentDelegateAction)
|
||||
or isinstance(action, CmdRunAction)
|
||||
or isinstance(action, IPythonRunCellAction)
|
||||
or isinstance(action, MessageAction)
|
||||
):
|
||||
return {
|
||||
'role': 'user' if action.source == 'user' else 'assistant',
|
||||
'content': action_to_str(action),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def get_observation_message(obs) -> dict[str, str] | None:
|
||||
max_message_chars = config.get_llm_config_from_agent(
|
||||
'CodeActAgent'
|
||||
).max_message_chars
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
|
||||
content += (
|
||||
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
|
||||
)
|
||||
return {'role': 'user', 'content': content}
|
||||
elif isinstance(obs, IPythonRunCellObservation):
|
||||
content = 'OBSERVATION:\n' + obs.content
|
||||
# replace base64 images with a placeholder
|
||||
splitted = content.split('\n')
|
||||
for i, line in enumerate(splitted):
|
||||
if ' already displayed to user'
|
||||
)
|
||||
content = '\n'.join(splitted)
|
||||
content = truncate_content(content, max_message_chars)
|
||||
return {'role': 'user', 'content': content}
|
||||
elif isinstance(obs, AgentDelegateObservation):
|
||||
content = 'OBSERVATION:\n' + truncate_content(
|
||||
str(obs.outputs), max_message_chars
|
||||
)
|
||||
return {'role': 'user', 'content': content}
|
||||
return None
|
||||
|
||||
|
||||
# FIXME: We can tweak these two settings to create MicroAgents specialized toward different area
|
||||
def get_system_message() -> str:
|
||||
if ENABLE_GITHUB:
|
||||
@@ -166,6 +110,61 @@ class CodeActAgent(Agent):
|
||||
super().__init__(llm)
|
||||
self.reset()
|
||||
|
||||
def action_to_str(self, action: Action) -> str:
|
||||
if isinstance(action, CmdRunAction):
|
||||
return (
|
||||
f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
|
||||
)
|
||||
elif isinstance(action, IPythonRunCellAction):
|
||||
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
|
||||
elif isinstance(action, AgentDelegateAction):
|
||||
return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
|
||||
elif isinstance(action, MessageAction):
|
||||
return action.content
|
||||
return ''
|
||||
|
||||
def get_action_message(self, action: Action) -> dict[str, str] | None:
|
||||
if (
|
||||
isinstance(action, AgentDelegateAction)
|
||||
or isinstance(action, CmdRunAction)
|
||||
or isinstance(action, IPythonRunCellAction)
|
||||
or isinstance(action, MessageAction)
|
||||
):
|
||||
return {
|
||||
'role': 'user' if action.source == 'user' else 'assistant',
|
||||
'content': self.action_to_str(action),
|
||||
}
|
||||
return None
|
||||
|
||||
def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
|
||||
max_message_chars = self.llm.config.max_message_chars
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
content = 'OBSERVATION:\n' + truncate_content(
|
||||
obs.content, max_message_chars
|
||||
)
|
||||
content += (
|
||||
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
|
||||
)
|
||||
return {'role': 'user', 'content': content}
|
||||
elif isinstance(obs, IPythonRunCellObservation):
|
||||
content = 'OBSERVATION:\n' + obs.content
|
||||
# replace base64 images with a placeholder
|
||||
splitted = content.split('\n')
|
||||
for i, line in enumerate(splitted):
|
||||
if ' already displayed to user'
|
||||
)
|
||||
content = '\n'.join(splitted)
|
||||
content = truncate_content(content, max_message_chars)
|
||||
return {'role': 'user', 'content': content}
|
||||
elif isinstance(obs, AgentDelegateObservation):
|
||||
content = 'OBSERVATION:\n' + truncate_content(
|
||||
str(obs.outputs), max_message_chars
|
||||
)
|
||||
return {'role': 'user', 'content': content}
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the CodeAct Agent."""
|
||||
super().reset()
|
||||
@@ -211,11 +210,12 @@ class CodeActAgent(Agent):
|
||||
|
||||
for event in state.history.get_events():
|
||||
# create a regular message from an event
|
||||
message = (
|
||||
get_action_message(event)
|
||||
if isinstance(event, Action)
|
||||
else get_observation_message(event)
|
||||
)
|
||||
if isinstance(event, Action):
|
||||
message = self.get_action_message(event)
|
||||
elif isinstance(event, Observation):
|
||||
message = self.get_observation_message(event)
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {type(event)}')
|
||||
|
||||
# add regular message
|
||||
if message:
|
||||
|
||||
@@ -7,7 +7,6 @@ from agenthub.codeact_swe_agent.prompt import (
|
||||
from agenthub.codeact_swe_agent.response_parser import CodeActSWEResponseParser
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import config
|
||||
from opendevin.events.action import (
|
||||
Action,
|
||||
AgentFinishAction,
|
||||
@@ -19,6 +18,7 @@ from opendevin.events.observation import (
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from opendevin.events.observation.observation import Observation
|
||||
from opendevin.events.serialization.event import truncate_content
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.runtime.plugins import (
|
||||
@@ -29,54 +29,6 @@ from opendevin.runtime.plugins import (
|
||||
from opendevin.runtime.tools import RuntimeTool
|
||||
|
||||
|
||||
def action_to_str(action: Action) -> str:
|
||||
if isinstance(action, CmdRunAction):
|
||||
return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
|
||||
elif isinstance(action, IPythonRunCellAction):
|
||||
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
|
||||
elif isinstance(action, MessageAction):
|
||||
return action.content
|
||||
return ''
|
||||
|
||||
|
||||
def get_action_message(action: Action) -> dict[str, str] | None:
|
||||
if (
|
||||
isinstance(action, CmdRunAction)
|
||||
or isinstance(action, IPythonRunCellAction)
|
||||
or isinstance(action, MessageAction)
|
||||
):
|
||||
return {
|
||||
'role': 'user' if action.source == 'user' else 'assistant',
|
||||
'content': action_to_str(action),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def get_observation_message(obs) -> dict[str, str] | None:
|
||||
max_message_chars = config.get_llm_config_from_agent(
|
||||
'CodeActSWEAgent'
|
||||
).max_message_chars
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
|
||||
content += (
|
||||
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
|
||||
)
|
||||
return {'role': 'user', 'content': content}
|
||||
elif isinstance(obs, IPythonRunCellObservation):
|
||||
content = 'OBSERVATION:\n' + obs.content
|
||||
# replace base64 images with a placeholder
|
||||
splitted = content.split('\n')
|
||||
for i, line in enumerate(splitted):
|
||||
if ' already displayed to user'
|
||||
)
|
||||
content = '\n'.join(splitted)
|
||||
content = truncate_content(content, max_message_chars)
|
||||
return {'role': 'user', 'content': content}
|
||||
return None
|
||||
|
||||
|
||||
def get_system_message() -> str:
|
||||
return f'{SYSTEM_PREFIX}\n\n{COMMAND_DOCS}\n\n{SYSTEM_SUFFIX}'
|
||||
|
||||
@@ -121,6 +73,53 @@ class CodeActSWEAgent(Agent):
|
||||
super().__init__(llm)
|
||||
self.reset()
|
||||
|
||||
def action_to_str(self, action: Action) -> str:
|
||||
if isinstance(action, CmdRunAction):
|
||||
return (
|
||||
f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
|
||||
)
|
||||
elif isinstance(action, IPythonRunCellAction):
|
||||
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
|
||||
elif isinstance(action, MessageAction):
|
||||
return action.content
|
||||
return ''
|
||||
|
||||
def get_action_message(self, action: Action) -> dict[str, str] | None:
|
||||
if (
|
||||
isinstance(action, CmdRunAction)
|
||||
or isinstance(action, IPythonRunCellAction)
|
||||
or isinstance(action, MessageAction)
|
||||
):
|
||||
return {
|
||||
'role': 'user' if action.source == 'user' else 'assistant',
|
||||
'content': self.action_to_str(action),
|
||||
}
|
||||
return None
|
||||
|
||||
def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
|
||||
max_message_chars = self.llm.config.max_message_chars
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
content = 'OBSERVATION:\n' + truncate_content(
|
||||
obs.content, max_message_chars
|
||||
)
|
||||
content += (
|
||||
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
|
||||
)
|
||||
return {'role': 'user', 'content': content}
|
||||
elif isinstance(obs, IPythonRunCellObservation):
|
||||
content = 'OBSERVATION:\n' + obs.content
|
||||
# replace base64 images with a placeholder
|
||||
splitted = content.split('\n')
|
||||
for i, line in enumerate(splitted):
|
||||
if ' already displayed to user'
|
||||
)
|
||||
content = '\n'.join(splitted)
|
||||
content = truncate_content(content, max_message_chars)
|
||||
return {'role': 'user', 'content': content}
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the CodeAct Agent."""
|
||||
super().reset()
|
||||
@@ -165,11 +164,12 @@ class CodeActSWEAgent(Agent):
|
||||
|
||||
for event in state.history.get_events():
|
||||
# create a regular message from an event
|
||||
message = (
|
||||
get_action_message(event)
|
||||
if isinstance(event, Action)
|
||||
else get_observation_message(event)
|
||||
)
|
||||
if isinstance(event, Action):
|
||||
message = self.get_action_message(event)
|
||||
elif isinstance(event, Observation):
|
||||
message = self.get_observation_message(event)
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {type(event)}')
|
||||
|
||||
# add regular message
|
||||
if message:
|
||||
|
||||
@@ -2,7 +2,6 @@ from jinja2 import BaseLoader, Environment
|
||||
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.utils import json
|
||||
from opendevin.events.action import Action
|
||||
from opendevin.events.serialization.action import action_from_dict
|
||||
@@ -27,32 +26,33 @@ def to_json(obj, **kwargs):
|
||||
return json.dumps(obj, **kwargs)
|
||||
|
||||
|
||||
def history_to_json(history: ShortTermHistory, max_events=20, **kwargs):
|
||||
"""Serialize and simplify history to str format"""
|
||||
# TODO: get agent specific llm config
|
||||
llm_config = config.get_llm_config()
|
||||
max_message_chars = llm_config.max_message_chars
|
||||
|
||||
processed_history = []
|
||||
event_count = 0
|
||||
|
||||
for event in history.get_events(reverse=True):
|
||||
if event_count >= max_events:
|
||||
break
|
||||
processed_history.append(event_to_memory(event, max_message_chars))
|
||||
event_count += 1
|
||||
|
||||
# history is in reverse order, let's fix it
|
||||
processed_history.reverse()
|
||||
|
||||
return json.dumps(processed_history, **kwargs)
|
||||
|
||||
|
||||
class MicroAgent(Agent):
|
||||
VERSION = '1.0'
|
||||
prompt = ''
|
||||
agent_definition: dict = {}
|
||||
|
||||
def history_to_json(
|
||||
self, history: ShortTermHistory, max_events: int = 20, **kwargs
|
||||
):
|
||||
"""
|
||||
Serialize and simplify history to str format
|
||||
"""
|
||||
processed_history = []
|
||||
event_count = 0
|
||||
|
||||
for event in history.get_events(reverse=True):
|
||||
if event_count >= max_events:
|
||||
break
|
||||
processed_history.append(
|
||||
event_to_memory(event, self.llm.config.max_message_chars)
|
||||
)
|
||||
event_count += 1
|
||||
|
||||
# history is in reverse order, let's fix it
|
||||
processed_history.reverse()
|
||||
|
||||
return json.dumps(processed_history, **kwargs)
|
||||
|
||||
def __init__(self, llm: LLM):
|
||||
super().__init__(llm)
|
||||
if 'name' not in self.agent_definition:
|
||||
@@ -66,7 +66,7 @@ class MicroAgent(Agent):
|
||||
state=state,
|
||||
instructions=instructions,
|
||||
to_json=to_json,
|
||||
history_to_json=history_to_json,
|
||||
history_to_json=self.history_to_json,
|
||||
delegates=self.delegates,
|
||||
latest_user_message=state.get_current_user_intent(),
|
||||
)
|
||||
|
||||
@@ -83,10 +83,7 @@ class MonologueAgent(Agent):
|
||||
self._add_initial_thoughts(task)
|
||||
self._initialized = True
|
||||
|
||||
def _add_initial_thoughts(self, task):
|
||||
max_message_chars = config.get_llm_config_from_agent(
|
||||
'MonologueAgent'
|
||||
).max_message_chars
|
||||
def _add_initial_thoughts(self, task: str):
|
||||
previous_action = ''
|
||||
for thought in INITIAL_THOUGHTS:
|
||||
thought = thought.replace('$TASK', task)
|
||||
@@ -103,7 +100,7 @@ class MonologueAgent(Agent):
|
||||
content=thought, url='', screenshot=''
|
||||
)
|
||||
self.initial_thoughts.append(
|
||||
event_to_memory(observation, max_message_chars)
|
||||
event_to_memory(observation, self.llm.config.max_message_chars)
|
||||
)
|
||||
previous_action = ''
|
||||
else:
|
||||
@@ -127,7 +124,9 @@ class MonologueAgent(Agent):
|
||||
previous_action = ActionType.BROWSE
|
||||
else:
|
||||
action = MessageAction(thought)
|
||||
self.initial_thoughts.append(event_to_memory(action, max_message_chars))
|
||||
self.initial_thoughts.append(
|
||||
event_to_memory(action, self.llm.config.max_message_chars)
|
||||
)
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
"""Modifies the current state by adding the most recent actions and observations, then prompts the model to think about it's next action to take using monologue, memory, and hint.
|
||||
@@ -138,9 +137,6 @@ class MonologueAgent(Agent):
|
||||
Returns:
|
||||
- Action: The next action to take based on LLM response
|
||||
"""
|
||||
max_message_chars = config.get_llm_config_from_agent(
|
||||
'MonologueAgent'
|
||||
).max_message_chars
|
||||
goal = state.get_current_user_intent()
|
||||
self._initialize(goal)
|
||||
|
||||
@@ -148,7 +144,9 @@ class MonologueAgent(Agent):
|
||||
|
||||
# add the events from state.history
|
||||
for event in state.history.get_events():
|
||||
recent_events.append(event_to_memory(event, max_message_chars))
|
||||
recent_events.append(
|
||||
event_to_memory(event, self.llm.config.max_message_chars)
|
||||
)
|
||||
|
||||
# add the last messages to long term memory
|
||||
if self.memory is not None:
|
||||
@@ -158,10 +156,12 @@ class MonologueAgent(Agent):
|
||||
# this should still work
|
||||
# we will need to do this differently: find out if there really is an action or an observation in this step
|
||||
if last_action:
|
||||
self.memory.add_event(event_to_memory(last_action, max_message_chars))
|
||||
self.memory.add_event(
|
||||
event_to_memory(last_action, self.llm.config.max_message_chars)
|
||||
)
|
||||
if last_observation:
|
||||
self.memory.add_event(
|
||||
event_to_memory(last_observation, max_message_chars)
|
||||
event_to_memory(last_observation, self.llm.config.max_message_chars)
|
||||
)
|
||||
|
||||
# the action prompt with initial thoughts and recent events
|
||||
|
||||
@@ -42,7 +42,7 @@ class PlannerAgent(Agent):
|
||||
'abandoned',
|
||||
]:
|
||||
return AgentFinishAction()
|
||||
prompt = get_prompt(state)
|
||||
prompt = get_prompt(state, self.llm.config.max_message_chars)
|
||||
messages = [{'content': prompt, 'role': 'user'}]
|
||||
resp = self.llm.completion(messages=messages)
|
||||
return self.response_parser.parse(resp)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import ActionType
|
||||
from opendevin.core.utils import json
|
||||
@@ -116,8 +115,9 @@ def get_hint(latest_action_id: str) -> str:
|
||||
return hints.get(latest_action_id, '')
|
||||
|
||||
|
||||
def get_prompt(state: State) -> str:
|
||||
def get_prompt(state: State, max_message_chars: int) -> str:
|
||||
"""Gets the prompt for the planner agent.
|
||||
|
||||
Formatted with the most recent action-observation pairs, current task, and hint based on last action
|
||||
|
||||
Parameters:
|
||||
@@ -126,10 +126,6 @@ def get_prompt(state: State) -> str:
|
||||
Returns:
|
||||
- str: The formatted string prompt with historical values
|
||||
"""
|
||||
max_message_chars = config.get_llm_config_from_agent(
|
||||
'PlannerAgent'
|
||||
).max_message_chars
|
||||
|
||||
# the plan
|
||||
plan_str = json.dumps(state.root_task.to_dict(), indent=2)
|
||||
|
||||
|
||||
@@ -248,7 +248,7 @@ class AgentController:
|
||||
async def start_delegate(self, action: AgentDelegateAction):
|
||||
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
|
||||
llm_config = config.get_llm_config_from_agent(action.agent)
|
||||
llm = LLM(llm_config=llm_config)
|
||||
llm = LLM(config=llm_config)
|
||||
delegate_agent = agent_cls(llm=llm)
|
||||
state = State(
|
||||
inputs=action.inputs or {},
|
||||
|
||||
@@ -198,7 +198,7 @@ class AppConfig(metaclass=Singleton):
|
||||
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
|
||||
"""
|
||||
|
||||
llms: dict = field(default_factory=dict)
|
||||
llms: dict[str, LLMConfig] = field(default_factory=dict)
|
||||
agents: dict = field(default_factory=dict)
|
||||
default_agent: str = 'CodeActAgent'
|
||||
sandbox: SandboxConfig = field(default_factory=SandboxConfig)
|
||||
|
||||
@@ -52,7 +52,7 @@ async def run_agent_controller(
|
||||
"""
|
||||
# Logging
|
||||
logger.info(
|
||||
f'Running agent {agent.name}, model {agent.llm.model_name}, with task: "{task_str}"'
|
||||
f'Running agent {agent.name}, model {agent.llm.config.model}, with task: "{task_str}"'
|
||||
)
|
||||
|
||||
# set up the event stream
|
||||
@@ -163,7 +163,7 @@ if __name__ == '__main__':
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
|
||||
config.set_llm_config(llm_config)
|
||||
llm = LLM(llm_config=config.get_llm_config_from_agent(args.agent_cls))
|
||||
llm = LLM(config=config.get_llm_config_from_agent(args.agent_cls))
|
||||
|
||||
# Create the agent
|
||||
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import copy
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
from opendevin.core.config import LLMConfig
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
import litellm
|
||||
@@ -21,7 +24,6 @@ from tenacity import (
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.logger import llm_prompt_logger, llm_response_logger
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.metrics import Metrics
|
||||
@@ -35,155 +37,71 @@ class LLM:
|
||||
"""The LLM class represents a Language Model instance.
|
||||
|
||||
Attributes:
|
||||
model_name (str): The name of the language model.
|
||||
api_key (str): The API key for accessing the language model.
|
||||
base_url (str): The base URL for the language model API.
|
||||
api_version (str): The version of the API to use.
|
||||
max_input_tokens (int): The maximum number of tokens to send to the LLM per task.
|
||||
max_output_tokens (int): The maximum number of tokens to receive from the LLM per task.
|
||||
llm_timeout (int): The maximum time to wait for a response in seconds.
|
||||
custom_llm_provider (str): A custom LLM provider.
|
||||
config: an LLMConfig object specifying the configuration of the LLM.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model=None,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
api_version=None,
|
||||
num_retries=None,
|
||||
retry_min_wait=None,
|
||||
retry_max_wait=None,
|
||||
llm_timeout=None,
|
||||
llm_temperature=None,
|
||||
llm_top_p=None,
|
||||
custom_llm_provider=None,
|
||||
max_input_tokens=None,
|
||||
max_output_tokens=None,
|
||||
llm_config=None,
|
||||
metrics=None,
|
||||
cost_metric_supported=True,
|
||||
input_cost_per_token=None,
|
||||
output_cost_per_token=None,
|
||||
config: LLMConfig,
|
||||
metrics: Metrics | None = None,
|
||||
):
|
||||
"""Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
|
||||
|
||||
Passing simple parameters always overrides config.
|
||||
|
||||
Args:
|
||||
model (str, optional): The name of the language model. Defaults to LLM_MODEL.
|
||||
api_key (str, optional): The API key for accessing the language model. Defaults to LLM_API_KEY.
|
||||
base_url (str, optional): The base URL for the language model API. Defaults to LLM_BASE_URL. Not necessary for OpenAI.
|
||||
api_version (str, optional): The version of the API to use. Defaults to LLM_API_VERSION. Not necessary for OpenAI.
|
||||
num_retries (int, optional): The number of retries for API calls. Defaults to LLM_NUM_RETRIES.
|
||||
retry_min_wait (int, optional): The minimum time to wait between retries in seconds. Defaults to LLM_RETRY_MIN_TIME.
|
||||
retry_max_wait (int, optional): The maximum time to wait between retries in seconds. Defaults to LLM_RETRY_MAX_TIME.
|
||||
max_input_tokens (int, optional): The maximum number of tokens to send to the LLM per task. Defaults to LLM_MAX_INPUT_TOKENS.
|
||||
max_output_tokens (int, optional): The maximum number of tokens to receive from the LLM per task. Defaults to LLM_MAX_OUTPUT_TOKENS.
|
||||
custom_llm_provider (str, optional): A custom LLM provider. Defaults to LLM_CUSTOM_LLM_PROVIDER.
|
||||
llm_timeout (int, optional): The maximum time to wait for a response in seconds. Defaults to LLM_TIMEOUT.
|
||||
llm_temperature (float, optional): The temperature for LLM sampling. Defaults to LLM_TEMPERATURE.
|
||||
metrics (Metrics, optional): The metrics object to use. Defaults to None.
|
||||
cost_metric_supported (bool, optional): Whether the cost metric is supported. Defaults to True.
|
||||
input_cost_per_token (float, optional): The cost per input token.
|
||||
output_cost_per_token (float, optional): The cost per output token.
|
||||
config: The LLM configuration
|
||||
"""
|
||||
if llm_config is None:
|
||||
llm_config = config.get_llm_config()
|
||||
model = model if model is not None else llm_config.model
|
||||
api_key = api_key if api_key is not None else llm_config.api_key
|
||||
base_url = base_url if base_url is not None else llm_config.base_url
|
||||
api_version = api_version if api_version is not None else llm_config.api_version
|
||||
num_retries = num_retries if num_retries is not None else llm_config.num_retries
|
||||
retry_min_wait = (
|
||||
retry_min_wait if retry_min_wait is not None else llm_config.retry_min_wait
|
||||
)
|
||||
retry_max_wait = (
|
||||
retry_max_wait if retry_max_wait is not None else llm_config.retry_max_wait
|
||||
)
|
||||
llm_timeout = llm_timeout if llm_timeout is not None else llm_config.timeout
|
||||
llm_temperature = (
|
||||
llm_temperature if llm_temperature is not None else llm_config.temperature
|
||||
)
|
||||
llm_top_p = llm_top_p if llm_top_p is not None else llm_config.top_p
|
||||
custom_llm_provider = (
|
||||
custom_llm_provider
|
||||
if custom_llm_provider is not None
|
||||
else llm_config.custom_llm_provider
|
||||
)
|
||||
max_input_tokens = (
|
||||
max_input_tokens
|
||||
if max_input_tokens is not None
|
||||
else llm_config.max_input_tokens
|
||||
)
|
||||
max_output_tokens = (
|
||||
max_output_tokens
|
||||
if max_output_tokens is not None
|
||||
else llm_config.max_output_tokens
|
||||
)
|
||||
input_cost_per_token = (
|
||||
input_cost_per_token
|
||||
if input_cost_per_token is not None
|
||||
else llm_config.input_cost_per_token
|
||||
)
|
||||
output_cost_per_token = (
|
||||
output_cost_per_token
|
||||
if output_cost_per_token is not None
|
||||
else llm_config.output_cost_per_token
|
||||
)
|
||||
metrics = metrics if metrics is not None else Metrics()
|
||||
|
||||
logger.info(f'Initializing LLM with model: {model}')
|
||||
self.model_name = model
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.api_version = api_version
|
||||
self.max_input_tokens = max_input_tokens
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.input_cost_per_token = input_cost_per_token
|
||||
self.output_cost_per_token = output_cost_per_token
|
||||
self.llm_timeout = llm_timeout
|
||||
self.custom_llm_provider = custom_llm_provider
|
||||
self.metrics = metrics
|
||||
self.cost_metric_supported = cost_metric_supported
|
||||
self.config = copy.deepcopy(config)
|
||||
self.metrics = metrics if metrics is not None else Metrics()
|
||||
self.cost_metric_supported = True
|
||||
|
||||
# litellm actually uses base Exception here for unknown model
|
||||
self.model_info = None
|
||||
try:
|
||||
if not self.model_name.startswith('openrouter'):
|
||||
self.model_info = litellm.get_model_info(self.model_name.split(':')[0])
|
||||
if not config.model.startswith('openrouter'):
|
||||
self.model_info = litellm.get_model_info(config.model.split(':')[0])
|
||||
else:
|
||||
self.model_info = litellm.get_model_info(self.model_name)
|
||||
self.model_info = litellm.get_model_info(config.model)
|
||||
# noinspection PyBroadException
|
||||
except Exception:
|
||||
logger.warning(f'Could not get model info for {self.model_name}')
|
||||
logger.warning(f'Could not get model info for {config.model}')
|
||||
|
||||
if self.max_input_tokens is None:
|
||||
if self.model_info is not None and 'max_input_tokens' in self.model_info:
|
||||
self.max_input_tokens = self.model_info['max_input_tokens']
|
||||
# Set the max tokens in an LM-specific way if not set
|
||||
if config.max_input_tokens is None:
|
||||
if (
|
||||
self.model_info is not None
|
||||
and 'max_input_tokens' in self.model_info
|
||||
and isinstance(self.model_info['max_input_tokens'], int)
|
||||
):
|
||||
self.config.max_input_tokens = self.model_info['max_input_tokens']
|
||||
else:
|
||||
# Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
|
||||
self.max_input_tokens = 4096
|
||||
self.config.max_input_tokens = 4096
|
||||
|
||||
if self.max_output_tokens is None:
|
||||
if self.model_info is not None and 'max_output_tokens' in self.model_info:
|
||||
self.max_output_tokens = self.model_info['max_output_tokens']
|
||||
if config.max_output_tokens is None:
|
||||
if (
|
||||
self.model_info is not None
|
||||
and 'max_output_tokens' in self.model_info
|
||||
and isinstance(self.model_info['max_output_tokens'], int)
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_output_tokens']
|
||||
else:
|
||||
# Enough tokens for most output actions, and not too many for a bad llm to get carried away responding
|
||||
# with thousands of unwanted tokens
|
||||
self.max_output_tokens = 1024
|
||||
# Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
|
||||
self.config.max_output_tokens = 1024
|
||||
|
||||
self._completion = partial(
|
||||
litellm_completion,
|
||||
model=self.model_name,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
api_version=self.api_version,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
max_tokens=self.max_output_tokens,
|
||||
timeout=self.llm_timeout,
|
||||
temperature=llm_temperature,
|
||||
top_p=llm_top_p,
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url,
|
||||
api_version=self.config.api_version,
|
||||
custom_llm_provider=self.config.custom_llm_provider,
|
||||
max_tokens=self.config.max_output_tokens,
|
||||
timeout=self.config.timeout,
|
||||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
)
|
||||
|
||||
completion_unwrapped = self._completion
|
||||
@@ -197,8 +115,10 @@ class LLM:
|
||||
|
||||
@retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(num_retries),
|
||||
wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
|
||||
stop=stop_after_attempt(config.num_retries),
|
||||
wait=wait_random_exponential(
|
||||
min=config.retry_min_wait, max=config.retry_max_wait
|
||||
),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
RateLimitError,
|
||||
@@ -267,7 +187,7 @@ class LLM:
|
||||
Returns:
|
||||
int: The number of tokens.
|
||||
"""
|
||||
return litellm.token_counter(model=self.model_name, messages=messages)
|
||||
return litellm.token_counter(model=self.config.model, messages=messages)
|
||||
|
||||
def is_local(self):
|
||||
"""Determines if the system is using a locally running LLM.
|
||||
@@ -275,12 +195,12 @@ class LLM:
|
||||
Returns:
|
||||
boolean: True if executing a local model.
|
||||
"""
|
||||
if self.base_url is not None:
|
||||
if self.config.base_url is not None:
|
||||
for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
|
||||
if substring in self.base_url:
|
||||
if substring in self.config.base_url:
|
||||
return True
|
||||
elif self.model_name is not None:
|
||||
if self.model_name.startswith('ollama'):
|
||||
elif self.config.model is not None:
|
||||
if self.config.model.startswith('ollama'):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -299,12 +219,12 @@ class LLM:
|
||||
|
||||
extra_kwargs = {}
|
||||
if (
|
||||
self.input_cost_per_token is not None
|
||||
and self.output_cost_per_token is not None
|
||||
self.config.input_cost_per_token is not None
|
||||
and self.config.output_cost_per_token is not None
|
||||
):
|
||||
cost_per_token = CostPerToken(
|
||||
input_cost_per_token=self.input_cost_per_token,
|
||||
output_cost_per_token=self.output_cost_per_token,
|
||||
input_cost_per_token=self.config.input_cost_per_token,
|
||||
output_cost_per_token=self.config.output_cost_per_token,
|
||||
)
|
||||
logger.info(f'Using custom cost per token: {cost_per_token}')
|
||||
extra_kwargs['custom_cost_per_token'] = cost_per_token
|
||||
@@ -322,11 +242,11 @@ class LLM:
|
||||
return 0.0
|
||||
|
||||
def __str__(self):
|
||||
if self.api_version:
|
||||
return f'LLM(model={self.model_name}, api_version={self.api_version}, base_url={self.base_url})'
|
||||
elif self.base_url:
|
||||
return f'LLM(model={self.model_name}, base_url={self.base_url})'
|
||||
return f'LLM(model={self.model_name})'
|
||||
if self.config.api_version:
|
||||
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
|
||||
elif self.config.base_url:
|
||||
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
|
||||
return f'LLM(model={self.config.model})'
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
@@ -97,7 +97,7 @@ class AgentSession:
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
llm = LLM(llm_config=config.get_llm_config_from_agent(agent_cls))
|
||||
llm = LLM(config=config.get_llm_config_from_agent(agent_cls))
|
||||
agent = Agent.get_cls(agent_cls)(llm)
|
||||
logger.info(f'Creating agent {agent.name} using LLM {llm}')
|
||||
if isinstance(agent, CodeActAgent):
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import parse_arguments
|
||||
from opendevin.core.config import LLMConfig, parse_arguments
|
||||
from opendevin.core.main import run_agent_controller
|
||||
from opendevin.core.schema import AgentState
|
||||
from opendevin.events.action import (
|
||||
@@ -44,20 +44,22 @@ print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
|
||||
os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
|
||||
reason='Manager agent is not capable of finishing this in reasonable steps yet',
|
||||
)
|
||||
def test_write_simple_script():
|
||||
def test_write_simple_script() -> None:
|
||||
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
|
||||
|
||||
final_state: State | None = asyncio.run(
|
||||
run_agent_controller(agent, task, exit_on_message=True)
|
||||
)
|
||||
assert final_state is not None
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.last_error is None
|
||||
|
||||
# Verify the script file exists
|
||||
assert workspace_base is not None
|
||||
script_path = os.path.join(workspace_base, 'hello.sh')
|
||||
assert os.path.exists(script_path), 'The file "hello.sh" does not exist'
|
||||
|
||||
@@ -103,7 +105,7 @@ def test_edits():
|
||||
shutil.copy(os.path.join(source_dir, file), dest_file)
|
||||
|
||||
# Create the agent
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
|
||||
|
||||
# Execute the task
|
||||
task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
|
||||
@@ -137,7 +139,7 @@ def test_ipython():
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
|
||||
|
||||
# Execute the task
|
||||
task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
|
||||
@@ -171,7 +173,7 @@ def test_simple_task_rejection():
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
|
||||
|
||||
# Give an impossible task to do: cannot write a commit message because
|
||||
# the workspace is not a git repo
|
||||
@@ -195,7 +197,7 @@ def test_ipython_module():
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
|
||||
|
||||
# Execute the task
|
||||
task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
|
||||
@@ -235,7 +237,7 @@ def test_browse_internet(http_server):
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
|
||||
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
|
||||
|
||||
# Execute the task
|
||||
task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from opendevin.core.config import config
|
||||
from opendevin.events.action import (
|
||||
Action,
|
||||
AddTaskAction,
|
||||
@@ -20,7 +19,9 @@ from opendevin.events.serialization import (
|
||||
)
|
||||
|
||||
|
||||
def serialization_deserialization(original_action_dict, cls):
|
||||
def serialization_deserialization(
|
||||
original_action_dict, cls, max_message_chars: int = 10000
|
||||
):
|
||||
action_instance = event_from_dict(original_action_dict)
|
||||
assert isinstance(
|
||||
action_instance, Action
|
||||
@@ -29,9 +30,7 @@ def serialization_deserialization(original_action_dict, cls):
|
||||
action_instance, cls
|
||||
), f'The action instance should be an instance of {cls.__name__}.'
|
||||
serialized_action_dict = event_to_dict(action_instance)
|
||||
serialized_action_memory = event_to_memory(
|
||||
action_instance, config.get_llm_config().max_message_chars
|
||||
)
|
||||
serialized_action_memory = event_to_memory(action_instance, max_message_chars)
|
||||
serialized_action_dict.pop('message')
|
||||
assert (
|
||||
serialized_action_dict == original_action_dict
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from opendevin.core.config import config
|
||||
from opendevin.events.observation import (
|
||||
CmdOutputObservation,
|
||||
Observation,
|
||||
@@ -10,7 +9,9 @@ from opendevin.events.serialization import (
|
||||
)
|
||||
|
||||
|
||||
def serialization_deserialization(original_observation_dict, cls):
|
||||
def serialization_deserialization(
|
||||
original_observation_dict, cls, max_message_chars: int = 10000
|
||||
):
|
||||
observation_instance = event_from_dict(original_observation_dict)
|
||||
assert isinstance(
|
||||
observation_instance, Observation
|
||||
@@ -20,7 +21,7 @@ def serialization_deserialization(original_observation_dict, cls):
|
||||
), 'The observation instance should be an instance of CmdOutputObservation.'
|
||||
serialized_observation_dict = event_to_dict(observation_instance)
|
||||
serialized_observation_memory = event_to_memory(
|
||||
observation_instance, config.get_llm_config().max_message_chars
|
||||
observation_instance, max_message_chars
|
||||
)
|
||||
assert (
|
||||
serialized_observation_dict == original_observation_dict
|
||||
|
||||
Reference in New Issue
Block a user