Refactor LLM config (#2953)

* Add max_message_chars to LLM

* Refactor LLM config

* Fix tests

* Made some functions class functions

* Fix regression

* Fixed comments
This commit is contained in:
Graham Neubig
2024-07-17 09:16:04 -04:00
committed by GitHub
parent 01ce1e35b5
commit c897791024
14 changed files with 236 additions and 318 deletions

View File

@@ -8,7 +8,6 @@ from agenthub.codeact_agent.prompt import (
)
from opendevin.controller.agent import Agent
from opendevin.controller.state.state import State
from opendevin.core.config import config
from opendevin.events.action import (
Action,
AgentDelegateAction,
@@ -22,6 +21,7 @@ from opendevin.events.observation import (
CmdOutputObservation,
IPythonRunCellObservation,
)
from opendevin.events.observation.observation import Observation
from opendevin.events.serialization.event import truncate_content
from opendevin.llm.llm import LLM
from opendevin.runtime.plugins import (
@@ -34,62 +34,6 @@ from opendevin.runtime.tools import RuntimeTool
ENABLE_GITHUB = True
def action_to_str(action: Action) -> str:
if isinstance(action, CmdRunAction):
return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
elif isinstance(action, IPythonRunCellAction):
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
elif isinstance(action, AgentDelegateAction):
return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
elif isinstance(action, MessageAction):
return action.content
return ''
def get_action_message(action: Action) -> dict[str, str] | None:
if (
isinstance(action, AgentDelegateAction)
or isinstance(action, CmdRunAction)
or isinstance(action, IPythonRunCellAction)
or isinstance(action, MessageAction)
):
return {
'role': 'user' if action.source == 'user' else 'assistant',
'content': action_to_str(action),
}
return None
def get_observation_message(obs) -> dict[str, str] | None:
max_message_chars = config.get_llm_config_from_agent(
'CodeActAgent'
).max_message_chars
if isinstance(obs, CmdOutputObservation):
content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
content += (
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
)
return {'role': 'user', 'content': content}
elif isinstance(obs, IPythonRunCellObservation):
content = 'OBSERVATION:\n' + obs.content
# replace base64 images with a placeholder
splitted = content.split('\n')
for i, line in enumerate(splitted):
if '![image](data:image/png;base64,' in line:
splitted[i] = (
'![image](data:image/png;base64, ...) already displayed to user'
)
content = '\n'.join(splitted)
content = truncate_content(content, max_message_chars)
return {'role': 'user', 'content': content}
elif isinstance(obs, AgentDelegateObservation):
content = 'OBSERVATION:\n' + truncate_content(
str(obs.outputs), max_message_chars
)
return {'role': 'user', 'content': content}
return None
# FIXME: We can tweak these two settings to create MicroAgents specialized toward different area
def get_system_message() -> str:
if ENABLE_GITHUB:
@@ -166,6 +110,61 @@ class CodeActAgent(Agent):
super().__init__(llm)
self.reset()
def action_to_str(self, action: Action) -> str:
if isinstance(action, CmdRunAction):
return (
f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
)
elif isinstance(action, IPythonRunCellAction):
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
elif isinstance(action, AgentDelegateAction):
return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
elif isinstance(action, MessageAction):
return action.content
return ''
def get_action_message(self, action: Action) -> dict[str, str] | None:
if (
isinstance(action, AgentDelegateAction)
or isinstance(action, CmdRunAction)
or isinstance(action, IPythonRunCellAction)
or isinstance(action, MessageAction)
):
return {
'role': 'user' if action.source == 'user' else 'assistant',
'content': self.action_to_str(action),
}
return None
def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
max_message_chars = self.llm.config.max_message_chars
if isinstance(obs, CmdOutputObservation):
content = 'OBSERVATION:\n' + truncate_content(
obs.content, max_message_chars
)
content += (
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
)
return {'role': 'user', 'content': content}
elif isinstance(obs, IPythonRunCellObservation):
content = 'OBSERVATION:\n' + obs.content
# replace base64 images with a placeholder
splitted = content.split('\n')
for i, line in enumerate(splitted):
if '![image](data:image/png;base64,' in line:
splitted[i] = (
'![image](data:image/png;base64, ...) already displayed to user'
)
content = '\n'.join(splitted)
content = truncate_content(content, max_message_chars)
return {'role': 'user', 'content': content}
elif isinstance(obs, AgentDelegateObservation):
content = 'OBSERVATION:\n' + truncate_content(
str(obs.outputs), max_message_chars
)
return {'role': 'user', 'content': content}
return None
def reset(self) -> None:
"""Resets the CodeAct Agent."""
super().reset()
@@ -211,11 +210,12 @@ class CodeActAgent(Agent):
for event in state.history.get_events():
# create a regular message from an event
message = (
get_action_message(event)
if isinstance(event, Action)
else get_observation_message(event)
)
if isinstance(event, Action):
message = self.get_action_message(event)
elif isinstance(event, Observation):
message = self.get_observation_message(event)
else:
raise ValueError(f'Unknown event type: {type(event)}')
# add regular message
if message:

View File

@@ -7,7 +7,6 @@ from agenthub.codeact_swe_agent.prompt import (
from agenthub.codeact_swe_agent.response_parser import CodeActSWEResponseParser
from opendevin.controller.agent import Agent
from opendevin.controller.state.state import State
from opendevin.core.config import config
from opendevin.events.action import (
Action,
AgentFinishAction,
@@ -19,6 +18,7 @@ from opendevin.events.observation import (
CmdOutputObservation,
IPythonRunCellObservation,
)
from opendevin.events.observation.observation import Observation
from opendevin.events.serialization.event import truncate_content
from opendevin.llm.llm import LLM
from opendevin.runtime.plugins import (
@@ -29,54 +29,6 @@ from opendevin.runtime.plugins import (
from opendevin.runtime.tools import RuntimeTool
def action_to_str(action: Action) -> str:
if isinstance(action, CmdRunAction):
return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
elif isinstance(action, IPythonRunCellAction):
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
elif isinstance(action, MessageAction):
return action.content
return ''
def get_action_message(action: Action) -> dict[str, str] | None:
if (
isinstance(action, CmdRunAction)
or isinstance(action, IPythonRunCellAction)
or isinstance(action, MessageAction)
):
return {
'role': 'user' if action.source == 'user' else 'assistant',
'content': action_to_str(action),
}
return None
def get_observation_message(obs) -> dict[str, str] | None:
max_message_chars = config.get_llm_config_from_agent(
'CodeActSWEAgent'
).max_message_chars
if isinstance(obs, CmdOutputObservation):
content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
content += (
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
)
return {'role': 'user', 'content': content}
elif isinstance(obs, IPythonRunCellObservation):
content = 'OBSERVATION:\n' + obs.content
# replace base64 images with a placeholder
splitted = content.split('\n')
for i, line in enumerate(splitted):
if '![image](data:image/png;base64,' in line:
splitted[i] = (
'![image](data:image/png;base64, ...) already displayed to user'
)
content = '\n'.join(splitted)
content = truncate_content(content, max_message_chars)
return {'role': 'user', 'content': content}
return None
def get_system_message() -> str:
return f'{SYSTEM_PREFIX}\n\n{COMMAND_DOCS}\n\n{SYSTEM_SUFFIX}'
@@ -121,6 +73,53 @@ class CodeActSWEAgent(Agent):
super().__init__(llm)
self.reset()
def action_to_str(self, action: Action) -> str:
if isinstance(action, CmdRunAction):
return (
f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
)
elif isinstance(action, IPythonRunCellAction):
return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
elif isinstance(action, MessageAction):
return action.content
return ''
def get_action_message(self, action: Action) -> dict[str, str] | None:
if (
isinstance(action, CmdRunAction)
or isinstance(action, IPythonRunCellAction)
or isinstance(action, MessageAction)
):
return {
'role': 'user' if action.source == 'user' else 'assistant',
'content': self.action_to_str(action),
}
return None
def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
max_message_chars = self.llm.config.max_message_chars
if isinstance(obs, CmdOutputObservation):
content = 'OBSERVATION:\n' + truncate_content(
obs.content, max_message_chars
)
content += (
f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
)
return {'role': 'user', 'content': content}
elif isinstance(obs, IPythonRunCellObservation):
content = 'OBSERVATION:\n' + obs.content
# replace base64 images with a placeholder
splitted = content.split('\n')
for i, line in enumerate(splitted):
if '![image](data:image/png;base64,' in line:
splitted[i] = (
'![image](data:image/png;base64, ...) already displayed to user'
)
content = '\n'.join(splitted)
content = truncate_content(content, max_message_chars)
return {'role': 'user', 'content': content}
return None
def reset(self) -> None:
"""Resets the CodeAct Agent."""
super().reset()
@@ -165,11 +164,12 @@ class CodeActSWEAgent(Agent):
for event in state.history.get_events():
# create a regular message from an event
message = (
get_action_message(event)
if isinstance(event, Action)
else get_observation_message(event)
)
if isinstance(event, Action):
message = self.get_action_message(event)
elif isinstance(event, Observation):
message = self.get_observation_message(event)
else:
raise ValueError(f'Unknown event type: {type(event)}')
# add regular message
if message:

View File

@@ -2,7 +2,6 @@ from jinja2 import BaseLoader, Environment
from opendevin.controller.agent import Agent
from opendevin.controller.state.state import State
from opendevin.core.config import config
from opendevin.core.utils import json
from opendevin.events.action import Action
from opendevin.events.serialization.action import action_from_dict
@@ -27,32 +26,33 @@ def to_json(obj, **kwargs):
return json.dumps(obj, **kwargs)
def history_to_json(history: ShortTermHistory, max_events=20, **kwargs):
"""Serialize and simplify history to str format"""
# TODO: get agent specific llm config
llm_config = config.get_llm_config()
max_message_chars = llm_config.max_message_chars
processed_history = []
event_count = 0
for event in history.get_events(reverse=True):
if event_count >= max_events:
break
processed_history.append(event_to_memory(event, max_message_chars))
event_count += 1
# history is in reverse order, let's fix it
processed_history.reverse()
return json.dumps(processed_history, **kwargs)
class MicroAgent(Agent):
VERSION = '1.0'
prompt = ''
agent_definition: dict = {}
def history_to_json(
self, history: ShortTermHistory, max_events: int = 20, **kwargs
):
"""
Serialize and simplify history to str format
"""
processed_history = []
event_count = 0
for event in history.get_events(reverse=True):
if event_count >= max_events:
break
processed_history.append(
event_to_memory(event, self.llm.config.max_message_chars)
)
event_count += 1
# history is in reverse order, let's fix it
processed_history.reverse()
return json.dumps(processed_history, **kwargs)
def __init__(self, llm: LLM):
super().__init__(llm)
if 'name' not in self.agent_definition:
@@ -66,7 +66,7 @@ class MicroAgent(Agent):
state=state,
instructions=instructions,
to_json=to_json,
history_to_json=history_to_json,
history_to_json=self.history_to_json,
delegates=self.delegates,
latest_user_message=state.get_current_user_intent(),
)

View File

@@ -83,10 +83,7 @@ class MonologueAgent(Agent):
self._add_initial_thoughts(task)
self._initialized = True
def _add_initial_thoughts(self, task):
max_message_chars = config.get_llm_config_from_agent(
'MonologueAgent'
).max_message_chars
def _add_initial_thoughts(self, task: str):
previous_action = ''
for thought in INITIAL_THOUGHTS:
thought = thought.replace('$TASK', task)
@@ -103,7 +100,7 @@ class MonologueAgent(Agent):
content=thought, url='', screenshot=''
)
self.initial_thoughts.append(
event_to_memory(observation, max_message_chars)
event_to_memory(observation, self.llm.config.max_message_chars)
)
previous_action = ''
else:
@@ -127,7 +124,9 @@ class MonologueAgent(Agent):
previous_action = ActionType.BROWSE
else:
action = MessageAction(thought)
self.initial_thoughts.append(event_to_memory(action, max_message_chars))
self.initial_thoughts.append(
event_to_memory(action, self.llm.config.max_message_chars)
)
def step(self, state: State) -> Action:
"""Modifies the current state by adding the most recent actions and observations, then prompts the model to think about it's next action to take using monologue, memory, and hint.
@@ -138,9 +137,6 @@ class MonologueAgent(Agent):
Returns:
- Action: The next action to take based on LLM response
"""
max_message_chars = config.get_llm_config_from_agent(
'MonologueAgent'
).max_message_chars
goal = state.get_current_user_intent()
self._initialize(goal)
@@ -148,7 +144,9 @@ class MonologueAgent(Agent):
# add the events from state.history
for event in state.history.get_events():
recent_events.append(event_to_memory(event, max_message_chars))
recent_events.append(
event_to_memory(event, self.llm.config.max_message_chars)
)
# add the last messages to long term memory
if self.memory is not None:
@@ -158,10 +156,12 @@ class MonologueAgent(Agent):
# this should still work
# we will need to do this differently: find out if there really is an action or an observation in this step
if last_action:
self.memory.add_event(event_to_memory(last_action, max_message_chars))
self.memory.add_event(
event_to_memory(last_action, self.llm.config.max_message_chars)
)
if last_observation:
self.memory.add_event(
event_to_memory(last_observation, max_message_chars)
event_to_memory(last_observation, self.llm.config.max_message_chars)
)
# the action prompt with initial thoughts and recent events

View File

@@ -42,7 +42,7 @@ class PlannerAgent(Agent):
'abandoned',
]:
return AgentFinishAction()
prompt = get_prompt(state)
prompt = get_prompt(state, self.llm.config.max_message_chars)
messages = [{'content': prompt, 'role': 'user'}]
resp = self.llm.completion(messages=messages)
return self.response_parser.parse(resp)

View File

@@ -1,5 +1,4 @@
from opendevin.controller.state.state import State
from opendevin.core.config import config
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import ActionType
from opendevin.core.utils import json
@@ -116,8 +115,9 @@ def get_hint(latest_action_id: str) -> str:
return hints.get(latest_action_id, '')
def get_prompt(state: State) -> str:
def get_prompt(state: State, max_message_chars: int) -> str:
"""Gets the prompt for the planner agent.
Formatted with the most recent action-observation pairs, current task, and hint based on last action
Parameters:
@@ -126,10 +126,6 @@ def get_prompt(state: State) -> str:
Returns:
- str: The formatted string prompt with historical values
"""
max_message_chars = config.get_llm_config_from_agent(
'PlannerAgent'
).max_message_chars
# the plan
plan_str = json.dumps(state.root_task.to_dict(), indent=2)

View File

@@ -248,7 +248,7 @@ class AgentController:
async def start_delegate(self, action: AgentDelegateAction):
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
llm_config = config.get_llm_config_from_agent(action.agent)
llm = LLM(llm_config=llm_config)
llm = LLM(config=llm_config)
delegate_agent = agent_cls(llm=llm)
state = State(
inputs=action.inputs or {},

View File

@@ -198,7 +198,7 @@ class AppConfig(metaclass=Singleton):
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
"""
llms: dict = field(default_factory=dict)
llms: dict[str, LLMConfig] = field(default_factory=dict)
agents: dict = field(default_factory=dict)
default_agent: str = 'CodeActAgent'
sandbox: SandboxConfig = field(default_factory=SandboxConfig)

View File

@@ -52,7 +52,7 @@ async def run_agent_controller(
"""
# Logging
logger.info(
f'Running agent {agent.name}, model {agent.llm.model_name}, with task: "{task_str}"'
f'Running agent {agent.name}, model {agent.llm.config.model}, with task: "{task_str}"'
)
# set up the event stream
@@ -163,7 +163,7 @@ if __name__ == '__main__':
if llm_config is None:
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
config.set_llm_config(llm_config)
llm = LLM(llm_config=config.get_llm_config_from_agent(args.agent_cls))
llm = LLM(config=config.get_llm_config_from_agent(args.agent_cls))
# Create the agent
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)

View File

@@ -1,6 +1,9 @@
import copy
import warnings
from functools import partial
from opendevin.core.config import LLMConfig
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
@@ -21,7 +24,6 @@ from tenacity import (
wait_random_exponential,
)
from opendevin.core.config import config
from opendevin.core.logger import llm_prompt_logger, llm_response_logger
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.metrics import Metrics
@@ -35,155 +37,71 @@ class LLM:
"""The LLM class represents a Language Model instance.
Attributes:
model_name (str): The name of the language model.
api_key (str): The API key for accessing the language model.
base_url (str): The base URL for the language model API.
api_version (str): The version of the API to use.
max_input_tokens (int): The maximum number of tokens to send to the LLM per task.
max_output_tokens (int): The maximum number of tokens to receive from the LLM per task.
llm_timeout (int): The maximum time to wait for a response in seconds.
custom_llm_provider (str): A custom LLM provider.
config: an LLMConfig object specifying the configuration of the LLM.
"""
def __init__(
self,
model=None,
api_key=None,
base_url=None,
api_version=None,
num_retries=None,
retry_min_wait=None,
retry_max_wait=None,
llm_timeout=None,
llm_temperature=None,
llm_top_p=None,
custom_llm_provider=None,
max_input_tokens=None,
max_output_tokens=None,
llm_config=None,
metrics=None,
cost_metric_supported=True,
input_cost_per_token=None,
output_cost_per_token=None,
config: LLMConfig,
metrics: Metrics | None = None,
):
"""Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
Passing simple parameters always overrides config.
Args:
model (str, optional): The name of the language model. Defaults to LLM_MODEL.
api_key (str, optional): The API key for accessing the language model. Defaults to LLM_API_KEY.
base_url (str, optional): The base URL for the language model API. Defaults to LLM_BASE_URL. Not necessary for OpenAI.
api_version (str, optional): The version of the API to use. Defaults to LLM_API_VERSION. Not necessary for OpenAI.
num_retries (int, optional): The number of retries for API calls. Defaults to LLM_NUM_RETRIES.
retry_min_wait (int, optional): The minimum time to wait between retries in seconds. Defaults to LLM_RETRY_MIN_TIME.
retry_max_wait (int, optional): The maximum time to wait between retries in seconds. Defaults to LLM_RETRY_MAX_TIME.
max_input_tokens (int, optional): The maximum number of tokens to send to the LLM per task. Defaults to LLM_MAX_INPUT_TOKENS.
max_output_tokens (int, optional): The maximum number of tokens to receive from the LLM per task. Defaults to LLM_MAX_OUTPUT_TOKENS.
custom_llm_provider (str, optional): A custom LLM provider. Defaults to LLM_CUSTOM_LLM_PROVIDER.
llm_timeout (int, optional): The maximum time to wait for a response in seconds. Defaults to LLM_TIMEOUT.
llm_temperature (float, optional): The temperature for LLM sampling. Defaults to LLM_TEMPERATURE.
metrics (Metrics, optional): The metrics object to use. Defaults to None.
cost_metric_supported (bool, optional): Whether the cost metric is supported. Defaults to True.
input_cost_per_token (float, optional): The cost per input token.
output_cost_per_token (float, optional): The cost per output token.
config: The LLM configuration
"""
if llm_config is None:
llm_config = config.get_llm_config()
model = model if model is not None else llm_config.model
api_key = api_key if api_key is not None else llm_config.api_key
base_url = base_url if base_url is not None else llm_config.base_url
api_version = api_version if api_version is not None else llm_config.api_version
num_retries = num_retries if num_retries is not None else llm_config.num_retries
retry_min_wait = (
retry_min_wait if retry_min_wait is not None else llm_config.retry_min_wait
)
retry_max_wait = (
retry_max_wait if retry_max_wait is not None else llm_config.retry_max_wait
)
llm_timeout = llm_timeout if llm_timeout is not None else llm_config.timeout
llm_temperature = (
llm_temperature if llm_temperature is not None else llm_config.temperature
)
llm_top_p = llm_top_p if llm_top_p is not None else llm_config.top_p
custom_llm_provider = (
custom_llm_provider
if custom_llm_provider is not None
else llm_config.custom_llm_provider
)
max_input_tokens = (
max_input_tokens
if max_input_tokens is not None
else llm_config.max_input_tokens
)
max_output_tokens = (
max_output_tokens
if max_output_tokens is not None
else llm_config.max_output_tokens
)
input_cost_per_token = (
input_cost_per_token
if input_cost_per_token is not None
else llm_config.input_cost_per_token
)
output_cost_per_token = (
output_cost_per_token
if output_cost_per_token is not None
else llm_config.output_cost_per_token
)
metrics = metrics if metrics is not None else Metrics()
logger.info(f'Initializing LLM with model: {model}')
self.model_name = model
self.api_key = api_key
self.base_url = base_url
self.api_version = api_version
self.max_input_tokens = max_input_tokens
self.max_output_tokens = max_output_tokens
self.input_cost_per_token = input_cost_per_token
self.output_cost_per_token = output_cost_per_token
self.llm_timeout = llm_timeout
self.custom_llm_provider = custom_llm_provider
self.metrics = metrics
self.cost_metric_supported = cost_metric_supported
self.config = copy.deepcopy(config)
self.metrics = metrics if metrics is not None else Metrics()
self.cost_metric_supported = True
# litellm actually uses base Exception here for unknown model
self.model_info = None
try:
if not self.model_name.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.model_name.split(':')[0])
if not config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(config.model.split(':')[0])
else:
self.model_info = litellm.get_model_info(self.model_name)
self.model_info = litellm.get_model_info(config.model)
# noinspection PyBroadException
except Exception:
logger.warning(f'Could not get model info for {self.model_name}')
logger.warning(f'Could not get model info for {config.model}')
if self.max_input_tokens is None:
if self.model_info is not None and 'max_input_tokens' in self.model_info:
self.max_input_tokens = self.model_info['max_input_tokens']
# Set the max tokens in an LM-specific way if not set
if config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
and isinstance(self.model_info['max_input_tokens'], int)
):
self.config.max_input_tokens = self.model_info['max_input_tokens']
else:
# Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
self.max_input_tokens = 4096
self.config.max_input_tokens = 4096
if self.max_output_tokens is None:
if self.model_info is not None and 'max_output_tokens' in self.model_info:
self.max_output_tokens = self.model_info['max_output_tokens']
if config.max_output_tokens is None:
if (
self.model_info is not None
and 'max_output_tokens' in self.model_info
and isinstance(self.model_info['max_output_tokens'], int)
):
self.config.max_output_tokens = self.model_info['max_output_tokens']
else:
# Enough tokens for most output actions, and not too many for a bad llm to get carried away responding
# with thousands of unwanted tokens
self.max_output_tokens = 1024
# Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
self.config.max_output_tokens = 1024
self._completion = partial(
litellm_completion,
model=self.model_name,
api_key=self.api_key,
base_url=self.base_url,
api_version=self.api_version,
custom_llm_provider=custom_llm_provider,
max_tokens=self.max_output_tokens,
timeout=self.llm_timeout,
temperature=llm_temperature,
top_p=llm_top_p,
model=self.config.model,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
max_tokens=self.config.max_output_tokens,
timeout=self.config.timeout,
temperature=self.config.temperature,
top_p=self.config.top_p,
)
completion_unwrapped = self._completion
@@ -197,8 +115,10 @@ class LLM:
@retry(
reraise=True,
stop=stop_after_attempt(num_retries),
wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
stop=stop_after_attempt(config.num_retries),
wait=wait_random_exponential(
min=config.retry_min_wait, max=config.retry_max_wait
),
retry=retry_if_exception_type(
(
RateLimitError,
@@ -267,7 +187,7 @@ class LLM:
Returns:
int: The number of tokens.
"""
return litellm.token_counter(model=self.model_name, messages=messages)
return litellm.token_counter(model=self.config.model, messages=messages)
def is_local(self):
"""Determines if the system is using a locally running LLM.
@@ -275,12 +195,12 @@ class LLM:
Returns:
boolean: True if executing a local model.
"""
if self.base_url is not None:
if self.config.base_url is not None:
for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
if substring in self.base_url:
if substring in self.config.base_url:
return True
elif self.model_name is not None:
if self.model_name.startswith('ollama'):
elif self.config.model is not None:
if self.config.model.startswith('ollama'):
return True
return False
@@ -299,12 +219,12 @@ class LLM:
extra_kwargs = {}
if (
self.input_cost_per_token is not None
and self.output_cost_per_token is not None
self.config.input_cost_per_token is not None
and self.config.output_cost_per_token is not None
):
cost_per_token = CostPerToken(
input_cost_per_token=self.input_cost_per_token,
output_cost_per_token=self.output_cost_per_token,
input_cost_per_token=self.config.input_cost_per_token,
output_cost_per_token=self.config.output_cost_per_token,
)
logger.info(f'Using custom cost per token: {cost_per_token}')
extra_kwargs['custom_cost_per_token'] = cost_per_token
@@ -322,11 +242,11 @@ class LLM:
return 0.0
def __str__(self):
if self.api_version:
return f'LLM(model={self.model_name}, api_version={self.api_version}, base_url={self.base_url})'
elif self.base_url:
return f'LLM(model={self.model_name}, base_url={self.base_url})'
return f'LLM(model={self.model_name})'
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
def __repr__(self):
return str(self)

View File

@@ -97,7 +97,7 @@ class AgentSession:
# TODO: override other LLM config & agent config groups (#2075)
llm = LLM(llm_config=config.get_llm_config_from_agent(agent_cls))
llm = LLM(config=config.get_llm_config_from_agent(agent_cls))
agent = Agent.get_cls(agent_cls)(llm)
logger.info(f'Creating agent {agent.name} using LLM {llm}')
if isinstance(agent, CodeActAgent):

View File

@@ -7,7 +7,7 @@ import pytest
from opendevin.controller.agent import Agent
from opendevin.controller.state.state import State
from opendevin.core.config import parse_arguments
from opendevin.core.config import LLMConfig, parse_arguments
from opendevin.core.main import run_agent_controller
from opendevin.core.schema import AgentState
from opendevin.events.action import (
@@ -44,20 +44,22 @@ print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
reason='Manager agent is not capable of finishing this in reasonable steps yet',
)
def test_write_simple_script():
def test_write_simple_script() -> None:
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
args = parse_arguments()
# Create the agent
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
final_state: State | None = asyncio.run(
run_agent_controller(agent, task, exit_on_message=True)
)
assert final_state is not None
assert final_state.agent_state == AgentState.STOPPED
assert final_state.last_error is None
# Verify the script file exists
assert workspace_base is not None
script_path = os.path.join(workspace_base, 'hello.sh')
assert os.path.exists(script_path), 'The file "hello.sh" does not exist'
@@ -103,7 +105,7 @@ def test_edits():
shutil.copy(os.path.join(source_dir, file), dest_file)
# Create the agent
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
# Execute the task
task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
@@ -137,7 +139,7 @@ def test_ipython():
args = parse_arguments()
# Create the agent
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
# Execute the task
task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
@@ -171,7 +173,7 @@ def test_simple_task_rejection():
args = parse_arguments()
# Create the agent
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
# Give an impossible task to do: cannot write a commit message because
# the workspace is not a git repo
@@ -195,7 +197,7 @@ def test_ipython_module():
args = parse_arguments()
# Create the agent
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
# Execute the task
task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
@@ -235,7 +237,7 @@ def test_browse_internet(http_server):
args = parse_arguments()
# Create the agent
agent = Agent.get_cls(args.agent_cls)(llm=LLM())
agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
# Execute the task
task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'

View File

@@ -1,4 +1,3 @@
from opendevin.core.config import config
from opendevin.events.action import (
Action,
AddTaskAction,
@@ -20,7 +19,9 @@ from opendevin.events.serialization import (
)
def serialization_deserialization(original_action_dict, cls):
def serialization_deserialization(
original_action_dict, cls, max_message_chars: int = 10000
):
action_instance = event_from_dict(original_action_dict)
assert isinstance(
action_instance, Action
@@ -29,9 +30,7 @@ def serialization_deserialization(original_action_dict, cls):
action_instance, cls
), f'The action instance should be an instance of {cls.__name__}.'
serialized_action_dict = event_to_dict(action_instance)
serialized_action_memory = event_to_memory(
action_instance, config.get_llm_config().max_message_chars
)
serialized_action_memory = event_to_memory(action_instance, max_message_chars)
serialized_action_dict.pop('message')
assert (
serialized_action_dict == original_action_dict

View File

@@ -1,4 +1,3 @@
from opendevin.core.config import config
from opendevin.events.observation import (
CmdOutputObservation,
Observation,
@@ -10,7 +9,9 @@ from opendevin.events.serialization import (
)
def serialization_deserialization(original_observation_dict, cls):
def serialization_deserialization(
original_observation_dict, cls, max_message_chars: int = 10000
):
observation_instance = event_from_dict(original_observation_dict)
assert isinstance(
observation_instance, Observation
@@ -20,7 +21,7 @@ def serialization_deserialization(original_observation_dict, cls):
), 'The observation instance should be an instance of CmdOutputObservation.'
serialized_observation_dict = event_to_dict(observation_instance)
serialized_observation_memory = event_to_memory(
observation_instance, config.get_llm_config().max_message_chars
observation_instance, max_message_chars
)
assert (
serialized_observation_dict == original_observation_dict