initial update

Yuhang Zhou 2025-03-05 12:26:04 +08:00
commit 62da328e7b
482 changed files with 43122 additions and 0 deletions

5
.gitignore vendored Normal file

@@ -0,0 +1,5 @@
.dist
deep-swarm/data
deep-swarm/tmp
deep-swarm/.env
deep-swarm/utils/__pycache__/

45
README.md Normal file

@@ -0,0 +1,45 @@
# DeepSwarm
## Overview
DeepSwarm is a multi-agent framework based on [camel](https://github.com/camel-ai/camel/). It achieves state-of-the-art performance among open-source systems on the [GAIA](https://huggingface.co/datasets/gaia-benchmark/GAIA) benchmark.
## Quickstart
We recommend running the code in a Linux environment.
To get started, follow these steps:
1. **Clone the GitHub repository:**
```bash
$ git clone xxx
```
2. **Set up Python Environment:**
```bash
$ conda create -n deepswarm python=3.11
$ conda activate deepswarm
```
3. **Install Dependencies:**
```bash
$ pip install -r requirements.txt
```
4. **Set API Keys:** We use `dotenv` to manage API keys. Copy the `.env.example` file to `.env` and fill in the necessary API keys (a minimal loading sketch is shown after this list).
5. **Run the Demo Code:**
```bash
$ python run.py
```
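For reference, here is a minimal sketch of how the keys can then be read inside the project with `python-dotenv` (assuming that is the `dotenv` package pinned in `requirements.txt`):
```python
import os

from dotenv import load_dotenv

# Load the variables defined in .env into the process environment.
load_dotenv()

# Individual keys are then available through os.getenv.
openai_api_key = os.getenv("OPENAI_API_KEY")
```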
## Reproduce the Results on GAIA
We provide a script to reproduce the results on GAIA. Check the `run_gaia_roleplaying.py` file and run the following command:
```bash
$ python run_gaia_roleplaying.py
```

25
deep-swarm/.env_template Normal file

@@ -0,0 +1,25 @@
# OPENAI API
OPENAI_API_KEY = ""
# Hugging Face API (https://huggingface.co/join)
HF_TOKEN=""
# Qwen API (https://help.aliyun.com/document_detail/611472.html)
QWEN_API_KEY=""
#===========================================
# Tools & Services API
#===========================================
# Google Search API (https://developers.google.com/custom-search/v1/overview)
GOOGLE_API_KEY=""
SEARCH_ENGINE_ID=""
# Chunkr API (https://chunkr.ai/)
CHUNKR_API_KEY=""
# Firecrawl API (https://www.firecrawl.dev/)
FIRECRAWL_API_KEY=""


@@ -0,0 +1,25 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.logger import disable_logging, enable_logging, set_log_level
__version__ = '0.2.11'
__all__ = [
'__version__',
'camel',
'disable_logging',
'enable_logging',
'set_log_level',
]



@@ -0,0 +1,44 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseAgent
from .chat_agent import ChatAgent
from .critic_agent import CriticAgent
from .embodied_agent import EmbodiedAgent
from .knowledge_graph_agent import KnowledgeGraphAgent
from .role_assignment_agent import RoleAssignmentAgent
from .search_agent import SearchAgent
from .task_agent import (
TaskCreationAgent,
TaskPlannerAgent,
TaskPrioritizationAgent,
TaskSpecifyAgent,
)
from .tool_agents.base import BaseToolAgent
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
__all__ = [
'BaseAgent',
'ChatAgent',
'TaskSpecifyAgent',
'TaskPlannerAgent',
'TaskCreationAgent',
'TaskPrioritizationAgent',
'CriticAgent',
'BaseToolAgent',
'HuggingFaceToolAgent',
'EmbodiedAgent',
'RoleAssignmentAgent',
'SearchAgent',
'KnowledgeGraphAgent',
]


@@ -0,0 +1,29 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
r"""An abstract base class for all CAMEL agents."""
@abstractmethod
def reset(self, *args: Any, **kwargs: Any) -> Any:
r"""Resets the agent to its initial state."""
pass
@abstractmethod
def step(self, *args: Any, **kwargs: Any) -> Any:
r"""Performs a single step of the agent."""
pass

File diff suppressed because it is too large.


@@ -0,0 +1,202 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
import warnings
from typing import Any, Dict, Optional, Sequence
from colorama import Fore
from camel.agents.chat_agent import ChatAgent
from camel.memories import AgentMemory
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import get_first_int, print_text_animated
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="CriticAgent")
class CriticAgent(ChatAgent):
r"""A class for the critic agent that assists in selecting an option.
Args:
system_message (BaseMessage): The system message for the critic
agent.
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
message_window_size (int, optional): The maximum number of previous
messages to include in the context window. If `None`, no windowing
is performed. (default: :obj:`6`)
retry_attempts (int, optional): The number of retry attempts if the
critic fails to return a valid option. (default: :obj:`2`)
verbose (bool, optional): Whether to print the critic's messages.
logger_color (Any): The color of the menu options displayed to the
user. (default: :obj:`Fore.MAGENTA`)
"""
def __init__(
self,
system_message: BaseMessage,
model: Optional[BaseModelBackend] = None,
memory: Optional[AgentMemory] = None,
message_window_size: int = 6,
retry_attempts: int = 2,
verbose: bool = False,
logger_color: Any = Fore.MAGENTA,
) -> None:
super().__init__(
system_message,
model=model,
memory=memory,
message_window_size=message_window_size,
)
self.options_dict: Dict[str, str] = dict()
self.retry_attempts = retry_attempts
self.verbose = verbose
self.logger_color = logger_color
def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
r"""Flattens the options to the critic.
Args:
messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
Returns:
str: A string containing the flattened options to the critic.
"""
options = [message.content for message in messages]
flatten_options = (
f"> Proposals from "
f"{messages[0].role_name} ({messages[0].role_type}). "
"Please choose an option:\n"
)
for index, option in enumerate(options):
flatten_options += f"Option {index + 1}:\n{option}\n\n"
self.options_dict[str(index + 1)] = option
format = (
f"Please first enter your choice ([1-{len(self.options_dict)}]) "
"and then your explanation and comparison: "
)
return flatten_options + format
def get_option(self, input_message: BaseMessage) -> str:
r"""Gets the option selected by the critic.
Args:
input_message (BaseMessage): A `BaseMessage` object representing
the input message.
Returns:
str: The option selected by the critic.
"""
# TODO: Add support for editing options by the critic.
msg_content = input_message.content
i = 0
while i < self.retry_attempts:
critic_response = self.step(input_message)
if critic_response.msgs is None or len(critic_response.msgs) == 0:
raise RuntimeError("Got None critic messages.")
if critic_response.terminated:
raise RuntimeError("Critic step failed.")
critic_msg = critic_response.msg
if self.verbose:
print_text_animated(
self.logger_color + "\n> Critic response: "
f"\x1b[3m{critic_msg.content}\x1b[0m\n"
)
choice = self.parse_critic(critic_msg)
if choice in self.options_dict:
return self.options_dict[choice]
else:
input_message = BaseMessage(
role_name=input_message.role_name,
role_type=input_message.role_type,
meta_dict=input_message.meta_dict,
content="> Invalid choice. Please choose again.\n"
+ msg_content,
)
i += 1
warnings.warn(
"Critic failed to get a valid option. "
f"After {self.retry_attempts} attempts. "
"Returning a random option."
)
return random.choice(list(self.options_dict.values()))
def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
r"""Parses the critic's message and extracts the choice.
Args:
critic_msg (BaseMessage): A `BaseMessage` object representing the
critic's response.
Returns:
Optional[str]: The critic's choice as a string, or None if the
message could not be parsed.
"""
choice = str(get_first_int(critic_msg.content))
return choice
def reduce_step(
self,
input_messages: Sequence[BaseMessage],
) -> ChatAgentResponse:
r"""Performs one step of the conversation by flattening options to the
critic, getting the option, and parsing the choice.
Args:
input_messages (Sequence[BaseMessage]): A list of BaseMessage
objects.
Returns:
ChatAgentResponse: A `ChatAgentResponse` object includes the
critic's choice.
"""
meta_chat_message = BaseMessage(
role_name=input_messages[0].role_name,
role_type=input_messages[0].role_type,
meta_dict=input_messages[0].meta_dict,
content="",
)
flatten_options = self.flatten_options(input_messages)
if self.verbose:
print_text_animated(
self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
)
input_msg = meta_chat_message.create_new_instance(flatten_options)
option = self.get_option(input_msg)
output_msg = meta_chat_message.create_new_instance(option)
# TODO: The return `info` can be improved.
return ChatAgentResponse(
msgs=[output_msg],
terminated=False,
info={},
)
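An illustrative usage sketch based on the signatures above (not part of the commit; role names and option texts are invented, and a configured model API key is assumed):
```python
from camel.agents import CriticAgent
from camel.messages import BaseMessage
from camel.types import RoleType

critic = CriticAgent(
    system_message=BaseMessage(
        role_name="Critic",
        role_type=RoleType.ASSISTANT,
        meta_dict=None,
        content="You help select the most promising option.",
    ),
    verbose=True,
)

# Two candidate proposals wrapped as BaseMessage objects.
proposals = [
    BaseMessage(role_name="Assistant", role_type=RoleType.ASSISTANT,
                meta_dict=None, content="Plan A: scrape the site directly."),
    BaseMessage(role_name="Assistant", role_type=RoleType.ASSISTANT,
                meta_dict=None, content="Plan B: use the official API."),
]

response = critic.reduce_step(proposals)
print(response.msg.content)  # The option the critic settled on.
```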


@@ -0,0 +1,303 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, List, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.logger import get_logger
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
logger = get_logger(__name__)
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="DeductiveReasonerAgent")
class DeductiveReasonerAgent(ChatAgent):
r"""An agent responsible for deductive reasoning. Model of deductive
reasoning:
- L: A ⊕ C -> q * B
- A represents the known starting state.
- B represents the known target state.
- C represents the conditions required to transition from A to B.
- Q represents the quality or effectiveness of the transition from
A to B.
- L represents the path or process from A to B.
Args:
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
"""
def __init__(
self,
model: Optional[BaseModelBackend] = None,
) -> None:
system_message = BaseMessage(
role_name="Insight Agent",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You assign roles based on tasks.",
)
super().__init__(system_message, model=model)
def deduce_conditions_and_quality(
self,
starting_state: str,
target_state: str,
role_descriptions_dict: Optional[Dict[str, str]] = None,
) -> Dict[str, Union[List[str], Dict[str, str]]]:
r"""Derives the conditions and quality from the starting state and the
target state based on the model of the deductive reasoning and the
knowledge base. It can optionally consider the roles involved in the
scenario, which allows tailoring the output more closely to the AI
agent's environment.
Args:
starting_state (str): The initial or starting state from which
conditions are deduced.
target_state (str): The target state of the task.
role_descriptions_dict (Optional[Dict[str, str]], optional): A
dictionary describing the roles involved in the scenario. It can
be used to provide context for the CAMEL role-playing, enabling
the generation of more relevant and tailored conditions and
quality assessments. This could be generated using a
`RoleAssignmentAgent()` or defined manually by the user.
(default: :obj:`None`)
Returns:
Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
extracted data from the message. The dictionary contains three
keys:
- 'conditions': A dictionary where each key is a condition ID and
each value is the corresponding condition text.
- 'labels': A list of label strings extracted from the message.
- 'evaluate_quality': A string containing the quality assessment
extracted from the message.
"""
self.reset()
deduce_prompt = """You are a deductive reasoner. You are tasked to
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
CONTENT to help you complete the TASK.
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
fill in the BLANKs, and DO NOT alter or modify any other part of the template
===== MODELING OF DEDUCTIVE REASONING =====
You are tasked with understanding a mathematical model based on the components
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
- $A$ represents the known starting state.
- $B$ represents the known target state.
- $C$ represents the conditions required to transition from $A$ to $B$.
- $Q$ represents the quality or effectiveness of the transition from $A$ to
$B$.
- $L$ represents the path or process from $A$ to $B$.
===== THOUGHT OF DEDUCTIVE REASONING =====
1. Define the Parameters of A and B:
- Characterization: Before delving into transitions, thoroughly understand
the nature and boundaries of both $A$ and $B$. This includes the type,
properties, constraints, and possible interactions between the two.
- Contrast and Compare: Highlight the similarities and differences between
$A$ and $B$. This comparative analysis will give an insight into what
needs changing and what remains constant.
2. Historical & Empirical Analysis:
- Previous Transitions according to the Knowledge Base of GPT: (if
applicable) Extract conditions and patterns from the historical instances
where a similar transition from a state comparable to $A$ moved towards
$B$.
- Scientific Principles: (if applicable) Consider the underlying
scientific principles governing or related to the states and their
transition. For example, if $A$ and $B$ are physical states, laws of
physics might apply.
3. Logical Deduction of Conditions ($C$):
- Direct Path Analysis: What are the immediate and direct conditions
required to move from $A$ to $B$?
- Intermediate States: Are there states between $A$ and $B$ that must be
traversed or can be used to make the transition smoother or more
efficient? If yes, what is the content?
- Constraints & Limitations: Identify potential barriers or restrictions
in moving from $A$ to $B$. These can be external (e.g., environmental
factors) or internal (properties of $A$ or $B$).
- Resource and Information Analysis: What resources and information are
required for the transition? This could be time, entity, factor, code
language, software platform, unknowns, etc.
- External Influences: Consider socio-economic, political, or
environmental factors (if applicable) that could influence the transition
conditions.
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
no matter how unconventional they might seem. Utilize analogies,
metaphors, or brainstorming techniques to envision possible conditions or
paths from $A$ to $B$.
- The conditions $C$ should be multiple but in one sentence. And each
condition should be concerned with one aspect/entity.
4. Entity/Label Recognition of Conditions ($C$):
- Identify and categorize entities of Conditions ($C$) such as the names,
locations, dates, specific technical terms or contextual parameters that
might be associated with events, innovations post-2022.
- The output of the entities/labels will be used as tags or labels for
semantic similarity searches. The entities/labels may be the words, or
phrases, each of them should contain valuable, high information entropy
information, and should be independent.
- Ensure that the identified entities are formatted in a manner suitable
for database indexing and retrieval. Organize the entities into
categories, and combine the category with its instance into a continuous
phrase, without using colons or other separators.
- Format these entities for database indexing: output the category rather
than its instance/content into a continuous phrase. For example, instead
of "Jan. 02", identify it as "Event time".
5. Quality Assessment ($Q$):
- Efficiency: How efficient is the transition from $A$ to $B$, which
measures the resources used versus the desired outcome?
- Effectiveness: Did the transition achieve the desired outcome or was the
target state achieved as intended?
- Safety & Risks: Assess any risks associated with the transition and the
measures to mitigate them.
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
and adjust the quality of transition, making it more adaptive.
6. Iterative Evaluation:
- Test & Refine: Based on the initially deduced conditions and assessed
quality, iterate the process to refine and optimize the transition. This
might involve tweaking conditions, employing different paths, or changing
resources.
- Feedback Integration: Use feedback to make improvements and increase the
quality of the transition.
7. Real-world scenarios often present challenges that may not be captured by
models and frameworks. While using the model, maintain an adaptive mindset:
- Scenario Exploration: Continuously imagine various possible scenarios,
both positive and negative, to prepare for unexpected events.
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
process ($L$) if unforeseen challenges arise.
- Feedback Integration: Rapidly integrate feedback from actual
implementations to adjust the model's application, ensuring relevancy and
effectiveness.
===== TASK =====
Given the starting state $A$ and the target state $B$, assuming that a path
$L$ always exists between $A$ and $B$, how can one deduce or identify the
necessary conditions $C$ and the quality $Q$ of the transition?
===== STARTING STATE $A$ =====
{starting_state}
===== TARGET STATE $B$ =====
{target_state}
{role_with_description_prompt}
===== ANSWER TEMPLATE =====
- Characterization and comparison of $A$ and $B$:\n<BLANK>
- Historical & Empirical Analysis:\n<BLANK>/None
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
condition <NUM>:
<BLANK>.
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
square brackets)
- Quality Assessment ($Q$) (do not use symbols):
<BLANK>.
- Iterative Evaluation:\n<BLANK>/None"""
if role_descriptions_dict is not None:
role_names = role_descriptions_dict.keys()
role_with_description_prompt = (
"===== ROLES WITH DESCRIPTIONS =====\n"
+ "\n".join(
f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
for role_name in role_names
)
+ "\n\n"
)
else:
role_with_description_prompt = ""
deduce_prompt = TextPrompt(deduce_prompt)
deduce = deduce_prompt.format(
starting_state=starting_state,
target_state=target_state,
role_with_description_prompt=role_with_description_prompt,
)
conditions_and_quality_generation_msg = BaseMessage.make_user_message(
role_name="Deductive Reasoner", content=deduce
)
response = self.step(
input_message=conditions_and_quality_generation_msg
)
if response.terminated:
raise RuntimeError(
"Deduction failed. Error:\n" + f"{response.info}"
)
msg: BaseMessage = response.msg
logger.info(f"Message content:\n{msg.content}")
# Extract the conditions from the message
conditions_dict = {
f"condition {i}": cdt.replace("<", "")
.replace(">", "")
.strip()
.strip('\n')
for i, cdt in re.findall(
r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
msg.content,
re.DOTALL,
)
}
# Extract the labels from the message
labels = [
label.strip().strip('\n').strip("\"'")
for label in re.findall(
r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
msg.content,
re.DOTALL,
)[0].split(",")
]
# Extract the quality from the message
quality = next(
q.strip().strip('\n')
for q in re.findall(
r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
r"\n(.+?)- Iterative",
msg.content,
re.DOTALL,
)
)
# Convert them into JSON format
conditions_and_quality_json: Dict[
str, Union[List[str], Dict[str, str]]
] = {}
conditions_and_quality_json["conditions"] = conditions_dict
conditions_and_quality_json["labels"] = labels
conditions_and_quality_json["evaluate_quality"] = quality
return conditions_and_quality_json
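An illustrative sketch (not part of the commit); the module path `camel.agents.deductive_reasoner_agent` and the example states are assumptions, and the default model backend is used when `model` is `None`:
```python
from camel.agents.deductive_reasoner_agent import DeductiveReasonerAgent

reasoner = DeductiveReasonerAgent()

result = reasoner.deduce_conditions_and_quality(
    starting_state="A folder of raw benchmark logs",
    target_state="A single CSV file with one row per run",
)

print(result["conditions"])        # dict: condition ID -> condition text
print(result["labels"])            # list of extracted entity labels
print(result["evaluate_quality"])  # quality assessment string
```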


@@ -0,0 +1,201 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional
from colorama import Fore
from camel.agents.chat_agent import ChatAgent
from camel.agents.tool_agents.base import BaseToolAgent
from camel.interpreters import (
BaseInterpreter,
InternalPythonInterpreter,
SubprocessInterpreter,
)
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import print_text_animated
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="EmbodiedAgent")
class EmbodiedAgent(ChatAgent):
r"""Class for managing conversations of CAMEL Embodied Agents.
Args:
system_message (BaseMessage): The system message for the chat agent.
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
message_window_size (int, optional): The maximum number of previous
messages to include in the context window. If `None`, no windowing
is performed. (default: :obj:`None`)
tool_agents (List[BaseToolAgent], optional): The tool agents to use in
the embodied agent. (default: :obj:`None`)
code_interpreter (BaseInterpreter, optional): The code interpreter to
execute codes. If `code_interpreter` and `tool_agent` are both
`None`, default to `SubProcessInterpreter`. If `code_interpreter`
is `None` and `tool_agents` is not `None`, default to
`InternalPythonInterpreter`. (default: :obj:`None`)
verbose (bool, optional): Whether to print the critic's messages.
logger_color (Any): The color of the logger displayed to the user.
(default: :obj:`Fore.MAGENTA`)
"""
def __init__(
self,
system_message: BaseMessage,
model: Optional[BaseModelBackend] = None,
message_window_size: Optional[int] = None,
tool_agents: Optional[List[BaseToolAgent]] = None,
code_interpreter: Optional[BaseInterpreter] = None,
verbose: bool = False,
logger_color: Any = Fore.MAGENTA,
) -> None:
self.tool_agents = tool_agents
self.code_interpreter: BaseInterpreter
if code_interpreter is not None:
self.code_interpreter = code_interpreter
elif self.tool_agents:
self.code_interpreter = InternalPythonInterpreter()
else:
self.code_interpreter = SubprocessInterpreter()
if self.tool_agents:
system_message = self._set_tool_agents(system_message)
self.verbose = verbose
self.logger_color = logger_color
super().__init__(
system_message=system_message,
model=model,
message_window_size=message_window_size,
)
def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
action_space_prompt = self._get_tool_agents_prompt()
result_message = system_message.create_new_instance(
content=system_message.content.format(
action_space=action_space_prompt
)
)
if self.tool_agents is not None:
self.code_interpreter.update_action_space(
{tool.name: tool for tool in self.tool_agents}
)
return result_message
def _get_tool_agents_prompt(self) -> str:
r"""Returns the action space prompt.
Returns:
str: The action space prompt.
"""
if self.tool_agents is not None:
return "\n".join(
[
f"*** {tool.name} ***:\n {tool.description}"
for tool in self.tool_agents
]
)
else:
return ""
def get_tool_agent_names(self) -> List[str]:
r"""Returns the names of tool agents.
Returns:
List[str]: The names of tool agents.
"""
if self.tool_agents is not None:
return [tool.name for tool in self.tool_agents]
else:
return []
# ruff: noqa: E501
def step(self, input_message: BaseMessage) -> ChatAgentResponse: # type: ignore[override]
r"""Performs a step in the conversation.
Args:
input_message (BaseMessage): The input message.
Returns:
ChatAgentResponse: A struct containing the output messages,
a boolean indicating whether the chat session has terminated,
and information about the chat session.
"""
response = super().step(input_message)
if response.msgs is None or len(response.msgs) == 0:
raise RuntimeError("Got None output messages.")
if response.terminated:
raise RuntimeError(f"{self.__class__.__name__} step failed.")
# NOTE: Only single output messages are supported
explanations, codes = response.msg.extract_text_and_code_prompts()
if self.verbose:
for explanation, code in zip(explanations, codes):
print_text_animated(
self.logger_color + f"> Explanation:\n{explanation}"
)
print_text_animated(self.logger_color + f"> Code:\n{code}")
if len(explanations) > len(codes):
print_text_animated(
self.logger_color + f"> Explanation:\n{explanations[-1]}"
)
content = response.msg.content
if codes is not None:
try:
content = "\n> Executed Results:\n"
for block_idx, code in enumerate(codes):
executed_output = self.code_interpreter.run(
code, code.code_type
)
content += (
f"Executing code block {block_idx}: {{\n"
+ executed_output
+ "}\n"
)
except InterruptedError as e:
content = (
f"\n> Running code fail: {e}\n"
"Please regenerate the code."
)
# TODO: Handle errors
content = input_message.content + f"\n> Embodied Actions:\n{content}"
message = BaseMessage(
input_message.role_name,
input_message.role_type,
input_message.meta_dict,
content,
)
return ChatAgentResponse(
msgs=[message],
terminated=response.terminated,
info=response.info,
)
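A usage sketch (not part of the commit). No tool agents are attached, so code blocks run through `SubprocessInterpreter`, which may ask for confirmation before executing; the role and task text are made up:
```python
from camel.agents import EmbodiedAgent
from camel.messages import BaseMessage
from camel.types import RoleType

agent = EmbodiedAgent(
    system_message=BaseMessage(
        role_name="Programmer",
        role_type=RoleType.ASSISTANT,
        meta_dict=None,
        content="You write and execute small Python snippets.",
    ),
    verbose=True,
)

user_msg = BaseMessage.make_user_message(
    role_name="User",
    content="Print the numbers from 1 to 3.",
)

response = agent.step(user_msg)
print(response.msg.content)  # Original request plus the executed results.
```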


@@ -0,0 +1,259 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from unstructured.documents.elements import Element
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.storages.graph_storages.graph_element import (
GraphElement,
Node,
Relationship,
)
from camel.types import RoleType
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
text_prompt = """
You are tasked with extracting nodes and relationships from the given content
and structuring them into Node and Relationship objects. Here's an outline of
what you need to do:
Content Extraction:
You should be able to process input content and identify entities mentioned
within it.
Entities can be any noun phrases or concepts that represent distinct entities
in the context of the given content.
Node Extraction:
For each identified entity, you should create a Node object.
Each Node object should have a unique identifier (id) and a type (type).
Additional properties associated with the node can also be extracted and
stored.
Relationship Extraction:
You should identify relationships between entities mentioned in the content.
For each relationship, create a Relationship object.
A Relationship object should have a subject (subj) and an object (obj) which
are Node objects representing the entities involved in the relationship.
Each relationship should also have a type (type), and additional properties if
applicable.
Output Formatting:
The extracted nodes and relationships should be formatted as instances of the
provided Node and Relationship classes.
Ensure that the extracted data adheres to the structure defined by the classes.
Output the structured data in a format that can be easily validated against
the provided code.
Instructions for you:
Read the provided content thoroughly.
Identify distinct entities mentioned in the content and categorize them as
nodes.
Determine relationships between these entities and represent them as directed
relationships.
Provide the extracted nodes and relationships in the specified format below.
Example for you:
Example Content:
"John works at XYZ Corporation. He is a software engineer. The company is
located in New York City."
Expected Output:
Nodes:
Node(id='John', type='Person')
Node(id='XYZ Corporation', type='Organization')
Node(id='New York City', type='Location')
Relationships:
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
Corporation', type='Organization'), type='WorksAt')
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
type='Location'), type='ResidesIn')
===== TASK =====
Please extract nodes and relationships from the given content and structure
them into Node and Relationship objects.
{task}
"""
@track_agent(name="KnowledgeGraphAgent")
class KnowledgeGraphAgent(ChatAgent):
r"""An agent that can extract node and relationship information for
different entities from given `Element` content.
Attributes:
task_prompt (TextPrompt): A prompt for the agent to extract node and
relationship information for different entities.
"""
def __init__(
self,
model: Optional[BaseModelBackend] = None,
) -> None:
r"""Initialize the `KnowledgeGraphAgent`.
Args:
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
"""
system_message = BaseMessage(
role_name="Graphify",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="Your mission is to transform unstructured content "
"into structured graph data. Extract nodes and relationships with "
"precision, and let the connections unfold. Your graphs will "
"illuminate the hidden connections within the chaos of "
"information.",
)
super().__init__(system_message, model=model)
def run(
self,
element: "Element",
parse_graph_elements: bool = False,
) -> Union[str, GraphElement]:
r"""Run the agent to extract node and relationship information.
Args:
element (Element): The input element.
parse_graph_elements (bool, optional): Whether to parse into
`GraphElement`. Defaults to `False`.
Returns:
Union[str, GraphElement]: The extracted node and relationship
information. If `parse_graph_elements` is `True` then return
`GraphElement`, else return `str`.
"""
self.reset()
self.element = element
knowledge_graph_prompt = TextPrompt(text_prompt)
knowledge_graph_generation = knowledge_graph_prompt.format(
task=str(element)
)
knowledge_graph_generation_msg = BaseMessage.make_user_message(
role_name="Graphify", content=knowledge_graph_generation
)
response = self.step(input_message=knowledge_graph_generation_msg)
content = response.msg.content
if parse_graph_elements:
content = self._parse_graph_elements(content)
return content
def _validate_node(self, node: Node) -> bool:
r"""Validate if the object is a valid Node.
Args:
node (Node): Object to be validated.
Returns:
bool: True if the object is a valid Node, False otherwise.
"""
return (
isinstance(node, Node)
and isinstance(node.id, (str, int))
and isinstance(node.type, str)
)
def _validate_relationship(self, relationship: Relationship) -> bool:
r"""Validate if the object is a valid Relationship.
Args:
relationship (Relationship): Object to be validated.
Returns:
bool: True if the object is a valid Relationship, False otherwise.
"""
return (
isinstance(relationship, Relationship)
and self._validate_node(relationship.subj)
and self._validate_node(relationship.obj)
and isinstance(relationship.type, str)
)
def _parse_graph_elements(self, input_string: str) -> GraphElement:
r"""Parses graph elements from given content.
Args:
input_string (str): The input content.
Returns:
GraphElement: The parsed graph elements.
"""
import re
# Regular expressions to extract nodes and relationships
node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
rel_pattern = (
r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
)
nodes = {}
relationships = []
# Extract nodes
for match in re.finditer(node_pattern, input_string):
id, type = match.groups()
properties = {'source': 'agent_created'}
if id not in nodes:
node = Node(id=id, type=type, properties=properties)
if self._validate_node(node):
nodes[id] = node
# Extract relationships
for match in re.finditer(rel_pattern, input_string):
subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
properties = {'source': 'agent_created'}
if subj_id in nodes and obj_id in nodes:
subj = nodes[subj_id]
obj = nodes[obj_id]
relationship = Relationship(
subj=subj, obj=obj, type=rel_type, properties=properties
)
if self._validate_relationship(relationship):
relationships.append(relationship)
return GraphElement(
nodes=list(nodes.values()),
relationships=relationships,
source=self.element,
)
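A usage sketch (not part of the commit), assuming the optional `unstructured` package is installed so that a `Text` element can stand in for the `Element` input:
```python
from unstructured.documents.elements import Text

from camel.agents import KnowledgeGraphAgent

kg_agent = KnowledgeGraphAgent()

element = Text(text="John works at XYZ Corporation in New York City.")

# parse_graph_elements=True returns a GraphElement instead of raw text.
graph = kg_agent.run(element, parse_graph_elements=True)
print(graph.nodes)
print(graph.relationships)
```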


@@ -0,0 +1,141 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="RoleAssignmentAgent")
class RoleAssignmentAgent(ChatAgent):
r"""An agent that generates role names based on the task prompt.
Args:
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
Attributes:
role_assignment_prompt (TextPrompt): A prompt for the agent to generate
role names.
"""
def __init__(
self,
model: Optional[BaseModelBackend] = None,
) -> None:
system_message = BaseMessage(
role_name="Role Assigner",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You assign roles based on tasks.",
)
super().__init__(system_message, model=model)
def run(
self,
task_prompt: Union[str, TextPrompt],
num_roles: int = 2,
) -> Dict[str, str]:
r"""Generate role names based on the input task prompt.
Args:
task_prompt (Union[str, TextPrompt]): The prompt
for the task based on which the roles are to be generated.
num_roles (int, optional): The number of roles to generate.
(default: :obj:`2`)
Returns:
Dict[str, str]: A dictionary mapping role names to their
descriptions.
"""
self.reset()
expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
f"Domain expert {i + 1}: <BLANK>\n"
f"Associated competencies, characteristics, duties "
f"and workflows: <BLANK>. End."
for i in range(num_roles or 0)
)
role_assignment_generation_prompt = TextPrompt(
"You are a role assignment agent, and you're in charge of "
+ "recruiting {num_roles} experts for the following task."
+ "\n==== TASK =====\n {task}\n\n"
+ "Identify the domain experts you'd recruit and detail their "
+ "associated competencies, characteristics, duties and workflows "
+ "to complete the task.\n "
+ "Your answer MUST adhere to the format of ANSWER PROMPT, and "
+ "ONLY answer the BLANKs.\n"
+ expert_prompt
)
role_assignment_generation = role_assignment_generation_prompt.format(
num_roles=num_roles, task=task_prompt
)
role_assignment_generation_msg = BaseMessage.make_user_message(
role_name="Role Assigner", content=role_assignment_generation
)
response = self.step(input_message=role_assignment_generation_msg)
msg = response.msg # type: BaseMessage
terminated = response.terminated
# Distribute the output completions into role names and descriptions
role_names = [
desc.replace("<|", "").replace("|>", "")
for desc in re.findall(
r"Domain expert \d: (.+?)\nAssociated competencies,",
msg.content,
re.DOTALL,
)
]
role_descriptions = [
desc.replace("<|", "").replace("|>", "")
for desc in re.findall(
r"Associated competencies, characteristics, "
r"duties and workflows: (.+?) End.",
msg.content,
re.DOTALL,
)
]
if len(role_names) != num_roles or len(role_descriptions) != num_roles:
raise RuntimeError(
"Got None or insufficient information of roles."
)
if terminated:
raise RuntimeError("Role assignment failed.")
role_descriptions_dict = {
role_name: description
for role_name, description in zip(role_names, role_descriptions)
}
return role_descriptions_dict
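A usage sketch (not part of the commit); the task prompt is a made-up example:
```python
from camel.agents import RoleAssignmentAgent

role_agent = RoleAssignmentAgent()

role_descriptions = role_agent.run(
    task_prompt="Build a small web crawler and summarize the collected pages.",
    num_roles=2,
)

for role_name, description in role_descriptions.items():
    print(f"{role_name}: {description}")
```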


@@ -0,0 +1,133 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
from camel.utils import create_chunks
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="SearchAgent")
class SearchAgent(ChatAgent):
r"""An agent that summarizes text based on a query and evaluates the
relevance of an answer.
Args:
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
"""
def __init__(
self,
model: Optional[BaseModelBackend] = None,
) -> None:
system_message = BaseMessage(
role_name="Assistant",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You are a helpful assistant.",
)
super().__init__(system_message, model=model)
def summarize_text(self, text: str, query: str) -> str:
r"""Summarize the information from the text, base on the query.
Args:
text (str): Text to summarize.
query (str): What information you want.
Returns:
str: Strings with information.
"""
self.reset()
summary_prompt = TextPrompt(
'''Gather information from this text that is relevant to the
question, but do not directly answer the question.\nquestion:
{query}\ntext '''
)
summary_prompt = summary_prompt.format(query=query)
# Max length of each chunk
max_len = 3000
results = ""
chunks = create_chunks(text, max_len)
# Summarize
for i, chunk in enumerate(chunks, start=1):
prompt = summary_prompt + str(i) + ": " + chunk
user_msg = BaseMessage.make_user_message(
role_name="User",
content=prompt,
)
result = self.step(user_msg).msg.content
results += result + "\n"
# Final summarization
final_prompt = TextPrompt(
'''Here are some summarized texts split from one text. Use
the information to answer the question. If you can't find the answer,
you must answer "I can not find the answer to the query" and
explain why.\n Query:\n{query}.\n\nText:\n'''
)
final_prompt = final_prompt.format(query=query)
prompt = final_prompt + results
user_msg = BaseMessage.make_user_message(
role_name="User",
content=prompt,
)
response = self.step(user_msg).msg.content
return response
def continue_search(self, query: str, answer: str) -> bool:
r"""Ask whether to continue search or not based on the provided answer.
Args:
query (str): The question.
answer (str): The answer to the question.
Returns:
bool: `True` if the search should continue, `False`
otherwise.
"""
prompt = TextPrompt(
"Do you think the ANSWER can answer the QUERY? "
"Use only 'yes' or 'no' to answer.\n"
"===== QUERY =====\n{query}\n\n"
"===== ANSWER =====\n{answer}"
)
prompt = prompt.format(query=query, answer=answer)
user_msg = BaseMessage.make_user_message(
role_name="User",
content=prompt,
)
response = self.step(user_msg).msg.content
if "yes" in str(response).lower():
return False
return True
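A usage sketch (not part of the commit); the text and query are placeholders supplied by the caller:
```python
from camel.agents import SearchAgent

search_agent = SearchAgent()

text = "..."  # e.g. the body of a web page fetched elsewhere
query = "When was the project first released?"

answer = search_agent.summarize_text(text, query)
print(answer)

# continue_search returns True when the answer is judged insufficient.
if search_agent.continue_search(query=query, answer=answer):
    print("Answer judged insufficient; keep searching.")
```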


@@ -0,0 +1,410 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType
from camel.utils import get_task_list
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="TaskSpecifyAgent")
class TaskSpecifyAgent(ChatAgent):
r"""An agent that specifies a given task prompt by prompting the user to
provide more details.
Attributes:
DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
task_specify_prompt (TextPrompt): The prompt for specifying the task.
Args:
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
task_type (TaskType, optional): The type of task for which to generate
a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
specifying the task. (default: :obj:`None`)
word_limit (int, optional): The word limit for the task prompt.
(default: :obj:`50`)
output_language (str, optional): The language to be output by the
agent. (default: :obj:`None`)
"""
DEFAULT_WORD_LIMIT = 50
def __init__(
self,
model: Optional[BaseModelBackend] = None,
task_type: TaskType = TaskType.AI_SOCIETY,
task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
word_limit: int = DEFAULT_WORD_LIMIT,
output_language: Optional[str] = None,
) -> None:
self.task_specify_prompt: Union[str, TextPrompt]
if task_specify_prompt is None:
task_specify_prompt_template = (
PromptTemplateGenerator().get_task_specify_prompt(task_type)
)
self.task_specify_prompt = task_specify_prompt_template.format(
word_limit=word_limit
)
else:
self.task_specify_prompt = TextPrompt(task_specify_prompt)
system_message = BaseMessage(
role_name="Task Specifier",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You can make a task more specific.",
)
super().__init__(
system_message,
model=model,
output_language=output_language,
)
def run(
self,
task_prompt: Union[str, TextPrompt],
meta_dict: Optional[Dict[str, Any]] = None,
) -> TextPrompt:
r"""Specify the given task prompt by providing more details.
Args:
task_prompt (Union[str, TextPrompt]): The original task
prompt.
meta_dict (Dict[str, Any], optional): A dictionary containing
additional information to include in the prompt.
(default: :obj:`None`)
Returns:
TextPrompt: The specified task prompt.
"""
self.reset()
task_specify_prompt = self.task_specify_prompt.format(task=task_prompt)
if meta_dict is not None:
task_specify_prompt = task_specify_prompt.format(**meta_dict)
task_msg = BaseMessage.make_user_message(
role_name="Task Specifier", content=task_specify_prompt
)
specifier_response = self.step(task_msg)
if specifier_response.terminated:
raise RuntimeError("Task specification failed.")
if len(specifier_response.msgs) == 0:
raise RuntimeError("Got no specification message.")
specified_task_msg = specifier_response.msgs[0]
return TextPrompt(specified_task_msg.content)
@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
r"""An agent that helps divide a task into subtasks based on the input
task prompt.
Attributes:
task_planner_prompt (TextPrompt): A prompt for the agent to divide
the task into subtasks.
Args:
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
output_language (str, optional): The language to be output by the
agent. (default: :obj:`None`)
"""
def __init__(
self,
model: Optional[BaseModelBackend] = None,
output_language: Optional[str] = None,
) -> None:
self.task_planner_prompt = TextPrompt(
"Divide this task into subtasks: {task}. Be concise."
)
system_message = BaseMessage(
role_name="Task Planner",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You are a helpful task planner.",
)
super().__init__(
system_message,
model=model,
output_language=output_language,
)
def run(
self,
task_prompt: Union[str, TextPrompt],
) -> TextPrompt:
r"""Generate subtasks based on the input task prompt.
Args:
task_prompt (Union[str, TextPrompt]): The prompt for the task to
be divided into subtasks.
Returns:
TextPrompt: A prompt for the subtasks generated by the agent.
"""
# TODO: Maybe include roles information.
self.reset()
task_planner_prompt = self.task_planner_prompt.format(task=task_prompt)
task_msg = BaseMessage.make_user_message(
role_name="Task Planner", content=task_planner_prompt
)
task_response = self.step(task_msg)
if task_response.terminated:
raise RuntimeError("Task planning failed.")
if len(task_response.msgs) == 0:
raise RuntimeError("Got no task planning message.")
sub_tasks_msg = task_response.msgs[0]
return TextPrompt(sub_tasks_msg.content)
@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
r"""An agent that helps create new tasks based on the objective
and last completed task. Compared to :obj:`TaskPlannerAgent`,
it's still a task planner, but it has more context information
like last task and incomplete task list. Modified from
`BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
Attributes:
task_creation_prompt (TextPrompt): A prompt for the agent to
create new tasks.
Args:
role_name (str): The role name of the Agent to create the task.
objective (Union[str, TextPrompt]): The objective of the Agent to
perform the task.
model (BaseModelBackend, optional): The LLM backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
output_language (str, optional): The language to be output by the
agent. (default: :obj:`None`)
message_window_size (int, optional): The maximum number of previous
messages to include in the context window. If `None`, no windowing
is performed. (default: :obj:`None`)
max_task_num (int, optional): The maximum number of planned
tasks in one round. (default: :obj:`3`)
"""
def __init__(
self,
role_name: str,
objective: Union[str, TextPrompt],
model: Optional[BaseModelBackend] = None,
output_language: Optional[str] = None,
message_window_size: Optional[int] = None,
max_task_num: Optional[int] = 3,
) -> None:
task_creation_prompt = TextPrompt(
"""Create new a task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The new created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:
#. First Task
#. Second Task
#. Third Task
You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
)
self.task_creation_prompt = task_creation_prompt.format(
objective=objective, role_name=role_name, max_task_num=max_task_num
)
self.objective = objective
system_message = BaseMessage(
role_name="Task Creator",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You are a helpful task creator.",
)
super().__init__(
system_message,
model=model,
output_language=output_language,
message_window_size=message_window_size,
)
def run(
self,
task_list: List[str],
) -> List[str]:
r"""Generate subtasks based on the previous task results and
incomplete task list.
Args:
task_list (List[str]): The completed or in-progress
tasks which should not overlap with new created tasks.
Returns:
List[str]: The new task list generated by the Agent.
"""
if len(task_list) > 0:
task_creation_prompt = self.task_creation_prompt.format(
task_list=task_list
)
else:
task_creation_prompt = self.task_creation_prompt.format(
task_list=""
)
task_msg = BaseMessage.make_user_message(
role_name="Task Creator", content=task_creation_prompt
)
task_response = self.step(task_msg)
if task_response.terminated:
raise RuntimeError("Task creation failed.")
if len(task_response.msgs) == 0:
raise RuntimeError("Got no task creation message.")
sub_tasks_msg = task_response.msgs[0]
return get_task_list(sub_tasks_msg.content)
@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
r"""An agent that helps re-prioritize the task list and
returns numbered prioritized list. Modified from
`BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
Attributes:
task_prioritization_prompt (TextPrompt): A prompt for the agent to
prioritize tasks.
Args:
objective (Union[str, TextPrompt]): The objective of the Agent to
perform the task.
model (BaseModelBackend, optional): The LLM backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
output_language (str, optional): The language to be output by the
agent. (default: :obj:`None`)
message_window_size (int, optional): The maximum number of previous
messages to include in the context window. If `None`, no windowing
is performed. (default: :obj:`None`)
"""
def __init__(
self,
objective: Union[str, TextPrompt],
model: Optional[BaseModelBackend] = None,
output_language: Optional[str] = None,
message_window_size: Optional[int] = None,
) -> None:
task_prioritization_prompt = TextPrompt(
"""Prioritize the following tasks : {task_list}.
Consider the ultimate objective of you: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:
#. First task
#. Second task
The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
)
self.task_prioritization_prompt = task_prioritization_prompt.format(
objective=objective
)
self.objective = objective
system_message = BaseMessage(
role_name="Task Prioritizer",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You are a helpful task prioritizer.",
)
super().__init__(
system_message,
model=model,
output_language=output_language,
message_window_size=message_window_size,
)
def run(
self,
task_list: List[str],
) -> List[str]:
r"""Prioritize the task list given the agent objective.
Args:
task_list (List[str]): The unprioritized tasks of agent.
Returns:
List[str]: The new prioritized task list generated by the Agent.
"""
task_prioritization_prompt = self.task_prioritization_prompt.format(
task_list=task_list
)
task_msg = BaseMessage.make_user_message(
role_name="Task Prioritizer", content=task_prioritization_prompt
)
task_response = self.step(task_msg)
if task_response.terminated:
raise RuntimeError("Task prioritization failed.")
if len(task_response.msgs) == 0:
raise RuntimeError("Got no task prioritization message.")
sub_tasks_msg = task_response.msgs[0]
return get_task_list(sub_tasks_msg.content)
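An illustrative sketch chaining `TaskSpecifyAgent` and `TaskPlannerAgent` (not part of the commit); a custom specify prompt is passed so the sketch does not depend on the default AI society template, and the task text is made up:
```python
from camel.agents import TaskPlannerAgent, TaskSpecifyAgent

task = "Design a board game about space exploration"

# Sharpen the vague task, then split it into subtasks.
specify_agent = TaskSpecifyAgent(
    task_specify_prompt="Make the following task more specific: {task}",
)
specified_task = specify_agent.run(task_prompt=task)

planner_agent = TaskPlannerAgent()
subtasks = planner_agent.run(task_prompt=specified_task)

print(specified_task)
print(subtasks)
```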


@@ -0,0 +1,20 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseToolAgent
from .hugging_face_tool_agent import HuggingFaceToolAgent
__all__ = [
'BaseToolAgent',
'HuggingFaceToolAgent',
]

View File

@ -0,0 +1,39 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.agents import BaseAgent
class BaseToolAgent(BaseAgent):
r"""Creates a :obj:`BaseToolAgent` object with the specified name and
description.
Args:
name (str): The name of the tool agent.
description (str): The description of the tool agent.
"""
def __init__(self, name: str, description: str) -> None:
self.name = name
self.description = description
def reset(self) -> None:
r"""Resets the agent to its initial state."""
pass
def step(self) -> None:
r"""Performs a single step of the agent."""
pass
def __str__(self) -> str:
return f"{self.name}: {self.description}"

View File

@ -0,0 +1,206 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Optional
from camel.agents.tool_agents.base import BaseToolAgent
# flake8: noqa :E501
class HuggingFaceToolAgent(BaseToolAgent):
r"""Tool agent for calling HuggingFace models. This agent is a wrapper
around agents from the `transformers` library. For more information
about the available models, please see the `transformers` documentation
at https://huggingface.co/docs/transformers/transformers_agents.
Args:
name (str): The name of the agent.
*args (Any): Additional positional arguments to pass to the underlying
Agent class.
remote (bool, optional): Flag indicating whether to run the agent
remotely. (default: :obj:`True`)
**kwargs (Any): Additional keyword arguments to pass to the underlying
Agent class.
"""
def __init__(
self,
name: str,
*args: Any,
remote: bool = True,
**kwargs: Any,
) -> None:
try:
# TODO: Support other tool agents
import transformers
from packaging import version
if version.parse(transformers.__version__) < version.parse(
"4.31.0"
):
raise ValueError(
"The version of \"transformers\" package should >= 4.31.0"
)
from transformers.tools import OpenAiAgent
from transformers.tools.agent_types import AgentImage
except (ImportError, ValueError):
raise ValueError(
"Could not import transformers tool agents. "
"Please setup the environment with "
"pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
)
self.agent_image_type = AgentImage
self.agent = OpenAiAgent(*args, **kwargs)
description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
- Text question answering: given a long text and a question, answer the question in the text
- Unconditional image captioning: Caption the image!
- Image question answering: given an image, answer a question on this image
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
- Text to speech: convert text to speech
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
- Text summarization: summarize a long text in one or a few sentences
- Translation: translate the text into a given language
- Text downloading: to download a text from a web URL
- Text to image: generate an image according to a prompt, leveraging stable diffusion
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- Text to video: generate a small video according to a prompt
Here are some python code examples of what you can do with this agent:
Single execution (step) mode, the single execution method is when using the step() method of the agent:
```
# Text to image
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")
# Text to image -> Image transformation
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
sea_add_island_image.save("./sea_add_island_image.png")
# If you'd like to keep a state across executions or to pass non-text objects to the agent,
# you can do so by specifying variables that you would like the agent to use. For example,
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
picture = {name}.step("Generate a picture of rivers and lakes.")
picture.save("./picture.png")
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
updated_picture.save("./updated_picture.png")
capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
capybara_sea_image.save("./capybara_sea_image.png")
# Document question answering
answer = {name}.step(
"In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
document=document,
)
print(answer)
# Text to image
boat_image = {name}.step("Generate an image of a boat in the water")
boat_image.save("./boat_image.png")
# Unconditional image captioning
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
print(boat_image_caption)
# Text to image -> Unconditional image captioning -> Text to speech
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")
# Text downloading
document = {name}.step("Download the text from http://hf.co")
print(document)
# Text summarization
summary = {name}.step("Summarize the following text: `document`", document=document)
print(summary)
# Text downloading -> Text summarization -> Text to speech
audio = {name}.step("Read out loud the summary of http://hf.co")
```
Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
```
# Clean the chat history
{name}.reset()
# Text to image
capybara_image = {name}.chat("Show me an image of a capybara")
capybara_image.save("./capybara_image.png")
# Image transformation
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
transformed_capybara_image.save("./transformed_capybara_image.png")
# Image segmentation
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
```
"""
super().__init__(name, description)
self.remote = remote
def reset(self) -> None:
r"""Resets the chat history of the agent."""
self.agent.prepare_for_new_chat()
def step(
self,
*args: Any,
remote: Optional[bool] = None,
**kwargs: Any,
) -> Any:
r"""Runs the agent in single execution mode.
Args:
*args (Any): Positional arguments to pass to the agent.
remote (bool, optional): Flag indicating whether to run the agent
remotely. Overrides the default setting. (default: :obj:`None`)
**kwargs (Any): Keyword arguments to pass to the agent.
Returns:
str: The response from the agent.
"""
if remote is None:
remote = self.remote
agent_output = self.agent.run(*args, remote=remote, **kwargs)
if isinstance(agent_output, self.agent_image_type):
agent_output = agent_output.to_raw()
return agent_output
def chat(
self,
*args: Any,
remote: Optional[bool] = None,
**kwargs: Any,
) -> Any:
r"""Runs the agent in a chat conversation mode.
Args:
*args (Any): Positional arguments to pass to the agent.
remote (bool, optional): Flag indicating whether to run the agent
remotely. Overrides the default setting. (default: :obj:`None`)
**kwargs (Any): Keyword arguments to pass to the agent.
Returns:
str: The response from the agent.
"""
if remote is None:
remote = self.remote
agent_output = self.agent.chat(*args, remote=remote, **kwargs)
if isinstance(agent_output, self.agent_image_type):
agent_output = agent_output.to_raw()
return agent_output
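A possible instantiation, assuming the pinned stack from the error message above (`transformers==4.31.0` etc.) is installed; the `model` and `api_key` keyword arguments are assumptions about `OpenAiAgent`'s interface and the key is a placeholder:

```python
from camel.agents.tool_agents import HuggingFaceToolAgent

hf_agent = HuggingFaceToolAgent(
    "hugging_face_agent",   # name used inside the generated description
    model="gpt-4",          # forwarded to transformers' OpenAiAgent (assumed)
    api_key="sk-...",       # placeholder; use your own key
    remote=True,
)

image = hf_agent.step("Draw me a picture of rivers and lakes.")
# Image outputs are converted from AgentImage to a raw image by the wrapper.
image.save("./rivers_and_lakes.png")
```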

View File

@ -0,0 +1,17 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseBenchmark
__all__ = ["BaseBenchmark"]

View File

@ -0,0 +1,152 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
from camel.agents import ChatAgent
logger = logging.getLogger(__name__)
class BaseBenchmark(ABC):
r"""Base class for benchmarks.
Attributes:
name (str): Name of the benchmark.
data_dir (str): Path to the data directory.
save_to (str): Path to save the results.
processes (int): Number of processes to use for parallel
processing. (default: :obj:`1`)
"""
def __init__(
self, name: str, data_dir: str, save_to: str, processes: int = 1
):
r"""Initialize the benchmark.
Args:
name (str): Name of the benchmark.
data_dir (str): Path to the data directory.
save_to (str): Path to save the results.
processes (int): Number of processes to use for parallel
processing. (default: :obj:`1`)
"""
self.name = name
self.data_dir = Path(data_dir)
self.processes = processes
self.save_to = save_to
if not self.data_dir.exists():
logger.info(
f"Data directory {data_dir} does not exist. Creating it."
)
self.data_dir.mkdir(parents=True, exist_ok=True)
if not self.data_dir.is_dir():
raise NotADirectoryError(
f"Data directory {data_dir} is not a directory"
)
self._data: Dict[str, List[Dict[str, Any]]] = dict()
self._results: List[Dict[str, Any]] = []
@abstractmethod
def download(self) -> "BaseBenchmark":
r"""Download the benchmark data.
Returns:
BaseBenchmark: The benchmark instance.
"""
pass
@abstractmethod
def load(self, force_download: bool = False) -> "BaseBenchmark":
r"""Load the benchmark data.
Args:
force_download (bool): Whether to force download the data.
Returns:
BaseBenchmark: The benchmark instance.
"""
pass
@property
def train(self) -> List[Dict[str, Any]]:
r"""Get the training data.
Returns:
List[Dict[str, Any]]: The training data.
"""
if not self._data:
logger.info("Data not loaded. Loading data.")
self.load()
return self._data["train"]
@property
def valid(self) -> List[Dict[str, Any]]:
r"""Get the validation data.
Returns:
List[Dict[str, Any]]: The validation data.
"""
if not self._data:
logger.info("Data not loaded. Loading data.")
self.load()
return self._data["valid"]
@property
def test(self) -> List[Dict[str, Any]]:
r"""Get the test data.
Returns:
List[Dict[str, Any]]: The test data.
"""
if not self._data:
logger.info("Data not loaded. Loading data.")
self.load()
return self._data["test"]
@abstractmethod
def run(
self,
agent: ChatAgent,
on: Literal["train", "valid", "test"],
randomize: bool = False,
subset: Optional[int] = None,
*args,
**kwargs,
) -> "BaseBenchmark":
r"""Run the benchmark.
Args:
agent (ChatAgent): The chat agent.
on (str): The data split to run the benchmark on.
randomize (bool): Whether to randomize the data.
subset (int): The subset of the data to run the benchmark on.
Returns:
BaseBenchmark: The benchmark instance.
"""
pass
@property
def results(self) -> List[Dict[str, Any]]:
r"""Get the results.
Returns:
List[Dict[str, Any]]: The results.
"""
return self._results
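To make the contract concrete, here is a skeletal subclass sketch. The JSON split files and the question/answer keys are assumptions for illustration only:

```python
import json
from typing import Literal, Optional

from camel.agents import ChatAgent
from camel.benchmarks import BaseBenchmark
from camel.messages import BaseMessage


class ToyBenchmark(BaseBenchmark):
    r"""Hypothetical benchmark reading pre-downloaded JSON splits."""

    def download(self) -> "ToyBenchmark":
        # Nothing to fetch; data is assumed to already live in data_dir.
        return self

    def load(self, force_download: bool = False) -> "ToyBenchmark":
        for split in ("train", "valid", "test"):
            path = self.data_dir / f"{split}.json"
            self._data[split] = (
                json.loads(path.read_text()) if path.exists() else []
            )
        return self

    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"] = "test",
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "ToyBenchmark":
        # Assumes load() has already been called (the split properties above
        # would trigger it automatically).
        for example in self._data.get(on, [])[:subset]:
            msg = BaseMessage.make_user_message(
                role_name="User", content=example["question"]
            )
            response = agent.step(msg)
            self._results.append(
                {"question": example["question"],
                 "answer": response.msg.content}
            )
        return self
```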

View File

@ -0,0 +1,34 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .discord_app import DiscordApp
from .slack.models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackAuthProfile,
SlackEventBody,
SlackEventProfile,
)
from .slack.slack_app import SlackApp
from .telegram_bot import TelegramBot
__all__ = [
'DiscordApp',
'SlackApp',
'SlackAppMentionEventBody',
'SlackAppMentionEventProfile',
'SlackAuthProfile',
'SlackEventBody',
'SlackEventProfile',
'TelegramBot',
]

View File

@ -0,0 +1,138 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, List, Optional
from camel.utils import dependencies_required
if TYPE_CHECKING:
from discord import Message
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DiscordApp:
r"""A class representing a Discord app that uses the `discord.py` library
to interact with Discord servers.
This bot can respond to messages in specific channels and only reacts to
messages that mention the bot.
Attributes:
channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
provided, the bot will only respond to messages in these channels.
token (Optional[str]): The Discord bot token used for authentication.
"""
@dependencies_required('discord')
def __init__(
self,
channel_ids: Optional[List[int]] = None,
token: Optional[str] = None,
) -> None:
r"""Initialize the DiscordApp instance by setting up the Discord client
and event handlers.
Args:
channel_ids (Optional[List[int]]): A list of allowed channel IDs.
The bot will only respond to messages in these channels if
provided.
token (Optional[str]): The Discord bot token for authentication.
If not provided, the token will be retrieved from the
environment variable `DISCORD_TOKEN`.
Raises:
ValueError: If the `DISCORD_TOKEN` is not found in environment
variables.
"""
self.token = token or os.getenv('DISCORD_TOKEN')
self.channel_ids = channel_ids
if not self.token:
raise ValueError(
"`DISCORD_TOKEN` not found in environment variables. Get it"
" here: `https://discord.com/developers/applications`."
)
import discord
intents = discord.Intents.default()
intents.message_content = True
self._client = discord.Client(intents=intents)
# Register event handlers
self._client.event(self.on_ready)
self._client.event(self.on_message)
async def start(self):
r"""Asynchronously start the Discord bot using its token.
This method starts the bot and logs into Discord asynchronously using
the provided token. It should be awaited when used in an async
environment.
"""
await self._client.start(self.token)
def run(self) -> None:
r"""Start the Discord bot using its token.
This method starts the bot and logs into Discord synchronously using
the provided token. It blocks execution and keeps the bot running.
"""
self._client.run(self.token) # type: ignore[arg-type]
async def on_ready(self) -> None:
r"""Event handler that is called when the bot has successfully
connected to the Discord server.
When the bot is ready and logged into Discord, it prints a message
displaying the bot's username.
"""
logger.info(f'We have logged in as {self._client.user}')
async def on_message(self, message: 'Message') -> None:
r"""Event handler for processing incoming messages.
This method is called whenever a new message is received by the bot. It
will ignore messages sent by the bot itself, only respond to messages
in allowed channels (if specified), and only to messages that mention
the bot.
Args:
message (discord.Message): The message object received from
Discord.
"""
# If the message author is the bot itself,
# do not respond to this message
if message.author == self._client.user:
return
# If allowed channel IDs are provided,
# only respond to messages in those channels
if self.channel_ids and message.channel.id not in self.channel_ids:
return
# Only respond to messages that mention the bot
if not self._client.user or not self._client.user.mentioned_in(
message
):
return
logger.info(f"Received message: {message.content}")
@property
def client(self):
return self._client
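Running it is essentially a one-liner once `DISCORD_TOKEN` is exported; the channel ID below is a placeholder:

```python
from camel.bots import DiscordApp

app = DiscordApp(channel_ids=[123456789012345678])  # placeholder channel ID
app.run()  # blocking; inside an async app, `await app.start()` instead
```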

View File

@ -0,0 +1,30 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackAuthProfile,
SlackEventBody,
SlackEventProfile,
)
from .slack_app import SlackApp
__all__ = [
'SlackApp',
'SlackAppMentionEventBody',
'SlackAppMentionEventProfile',
'SlackAuthProfile',
'SlackEventBody',
'SlackEventProfile',
]

View File

@ -0,0 +1,158 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional
from pydantic import BaseModel
class SlackAuthProfile(BaseModel):
r"""Represents the authorization profile within a Slack event.
Events will contain a single, compact authorizations field that shows one
installation of your app that the event is visible to.
In other words, lists of authorizations will be truncated to one element.
If there's more than one installing party that your app is keeping track
of, it's best not to rely on the single party listed in authorizations to
be any particular one.
To get a full list of who can see events, call the apps.event.
authorizations.list method after obtaining an app-level token. Read more on
the changes here; they have taken effect for existing apps as of
February 24, 2021.
References:
- https://api.slack.com/apis/events-api#authorizations
- https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
"""
enterprise_id: Optional[str] = None
"""The ID of the enterprise associated with the authorization."""
team_id: str
"""The ID of the team associated with the authorization."""
user_id: str
"""The ID of the user associated with the authorization."""
is_bot: bool
"""Whether the authorized user is a bot."""
is_enterprise_install: bool
"""Whether the authorization is for an enterprise installation."""
class SlackEventProfile(BaseModel):
r"""Represents the detailed profile of a Slack event, including user,
message, and context data.
"""
user: str
"""The ID of the user associated with the event."""
type: str
"""The type of the event (e.g., 'message')."""
ts: str
"""A timestamp representing when the event was triggered."""
thread_ts: Optional[str] = None
"""The timestamp of the parent message in a thread."""
client_msg_id: str
"""A unique ID generated by the client for the message (if available)."""
text: str
"""The message content text."""
team: str
"""The ID of the team that the event is associated with."""
blocks: list
"""The list of message blocks, providing structured information."""
channel: str
"""The ID of the Slack channel where the event happened."""
event_ts: str
"""The event-specific timestamp when it occurred."""
channel_type: Optional[str]
"""The type of Slack channel (e.g., 'channel', 'im')."""
class SlackEventBody(BaseModel):
r"""Represents the entire body of a Slack event, including the event
profile, authorization, and context.
"""
token: str
"""The token to verify the source of the event."""
team_id: str
"""The ID of the team where the event is happening."""
context_team_id: Optional[str]
"""The team ID for the shared channel context, if applicable."""
context_enterprise_id: Optional[str] = None
"""The enterprise ID for the shared channel context, if applicable."""
api_app_id: str
"""The unique identifier for the Slack app that received the event."""
event: SlackEventProfile
"""A detailed profile of the event"""
type: str
"""The overall type of event received (e.g., 'event_callback')."""
event_id: str
"""A unique identifier assigned to this event by Slack."""
event_time: int
"""The timestamp (in seconds) representing when the event was triggered."""
authorizations: Optional[list[SlackAuthProfile]] = None
"""An optional list of authorizations that describe which installation can
see the event."""
is_ext_shared_channel: bool
"""Indicates if the event is part of a shared channel between different
organizations."""
event_context: str
"""A unique string representing the context of the event."""
class SlackAppMentionEventProfile(SlackEventProfile):
r"""Represents the detailed profile of a Slack event where the app was
mentioned in a message.
"""
channel_type: Optional[str] = None
"""The type of Slack channel. it's None for app mentions."""
class SlackAppMentionEventBody(SlackEventBody):
r"""Represents the entire body of a Slack event where the app was mentioned
in a message.
"""
context_team_id: Optional[str] = None
"""A detailed profile of the event. it's None for app mentions."""
event: SlackAppMentionEventProfile
"""A detailed profile of the event"""

View File

@ -0,0 +1,255 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, Optional
from slack_sdk.oauth.installation_store.async_installation_store import (
AsyncInstallationStore,
)
from starlette import requests, responses
from camel.bots.slack.models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackEventBody,
SlackEventProfile,
)
from camel.utils import dependencies_required
if TYPE_CHECKING:
from slack_bolt.context.async_context import AsyncBoltContext
from slack_bolt.context.say.async_say import AsyncSay
from slack_sdk.web.async_client import AsyncWebClient
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SlackApp:
r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.
This class is responsible for initializing and managing the Slack
application by setting up event handlers, running the app server, and
handling events such as messages and mentions from Slack.
Args:
token (Optional[str]): Slack API token for authentication.
scopes (Optional[str]): Slack app scopes for permissions.
signing_secret (Optional[str]): Signing secret for verifying Slack
requests.
client_id (Optional[str]): Slack app client ID.
client_secret (Optional[str]): Slack app client secret.
redirect_uri_path (str): The URI path for OAuth redirect, defaults to
"/slack/oauth_redirect".
installation_store (Optional[AsyncInstallationStore]): The installation
store for handling OAuth installations.
"""
@dependencies_required('slack_bolt')
def __init__(
self,
token: Optional[str] = None,
scopes: Optional[str] = None,
signing_secret: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
redirect_uri_path: str = "/slack/oauth_redirect",
installation_store: Optional[AsyncInstallationStore] = None,
) -> None:
r"""Initializes the SlackApp instance by setting up the Slack Bolt app
and configuring event handlers and OAuth settings.
Args:
token (Optional[str]): The Slack API token.
scopes (Optional[str]): The scopes for Slack app permissions.
signing_secret (Optional[str]): The signing secret for verifying
requests.
client_id (Optional[str]): The Slack app client ID.
client_secret (Optional[str]): The Slack app client secret.
redirect_uri_path (str): The URI path for handling OAuth redirects
(default is "/slack/oauth_redirect").
installation_store (Optional[AsyncInstallationStore]): An optional
installation store for OAuth installations.
"""
from slack_bolt.adapter.starlette.async_handler import (
AsyncSlackRequestHandler,
)
from slack_bolt.app.async_app import AsyncApp
from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings
self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
self.signing_secret: Optional[str] = signing_secret or os.getenv(
"SLACK_SIGNING_SECRET"
)
self.client_id: Optional[str] = client_id or os.getenv(
"SLACK_CLIENT_ID"
)
self.client_secret: Optional[str] = client_secret or os.getenv(
"SLACK_CLIENT_SECRET"
)
if not all([self.token, self.scopes, self.signing_secret]):
raise ValueError(
"`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
"environment variables must be set. Get it here: "
"`https://api.slack.com/apps`."
)
# Setup OAuth settings if client ID and secret are provided
if self.client_id and self.client_secret:
self._app = AsyncApp(
oauth_settings=AsyncOAuthSettings(
client_id=self.client_id,
client_secret=self.client_secret,
scopes=self.scopes,
redirect_uri_path=redirect_uri_path,
),
logger=logger,
signing_secret=self.signing_secret,
installation_store=installation_store,
token=self.token,
)
else:
# Initialize Slack Bolt AsyncApp with settings
self._app = AsyncApp(
logger=logger,
signing_secret=self.signing_secret,
installation_store=installation_store,
token=self.token,
)
self._handler = AsyncSlackRequestHandler(self._app)
self.setup_handlers()
def setup_handlers(self) -> None:
r"""Sets up the event handlers for Slack events, such as `app_mention`
and `message`.
This method registers the `app_mention` and `on_message` event handlers
with the Slack Bolt app to respond to Slack events.
"""
self._app.event("app_mention")(self.app_mention)
self._app.event("message")(self.on_message)
def run(
self,
port: int = 3000,
path: str = "/slack/events",
host: Optional[str] = None,
) -> None:
r"""Starts the Slack Bolt app server to listen for incoming Slack
events.
Args:
port (int): The port on which the server should run (default is
3000).
path (str): The endpoint path for receiving Slack events (default
is "/slack/events").
host (Optional[str]): The hostname to bind the server (default is
None).
"""
self._app.start(port=port, path=path, host=host)
async def handle_request(
self, request: requests.Request
) -> responses.Response:
r"""Handles incoming requests from Slack through the request handler.
Args:
request (Request): A Starlette request object representing the
incoming request.
Returns:
The response generated by the Slack Bolt handler.
"""
return await self._handler.handle(request)
async def app_mention(
self,
context: "AsyncBoltContext",
client: "AsyncWebClient",
event: Dict[str, Any],
body: Dict[str, Any],
say: "AsyncSay",
) -> None:
r"""Event handler for `app_mention` events.
This method is triggered when someone mentions the app in Slack.
Args:
context (AsyncBoltContext): The Slack Bolt context for the event.
client (AsyncWebClient): The Slack Web API client.
event (Dict[str, Any]): The event data for the app mention.
body (Dict[str, Any]): The full request body from Slack.
say (AsyncSay): A function to send a response back to the channel.
"""
event_profile = SlackAppMentionEventProfile(**event)
event_body = SlackAppMentionEventBody(**body)
logger.info(f"app_mention, context: {context}")
logger.info(f"app_mention, client: {client}")
logger.info(f"app_mention, event_profile: {event_profile}")
logger.info(f"app_mention, event_body: {event_body}")
logger.info(f"app_mention, say: {say}")
async def on_message(
self,
context: "AsyncBoltContext",
client: "AsyncWebClient",
event: Dict[str, Any],
body: Dict[str, Any],
say: "AsyncSay",
) -> None:
r"""Event handler for `message` events.
This method is triggered when the app receives a message in Slack.
Args:
context (AsyncBoltContext): The Slack Bolt context for the event.
client (AsyncWebClient): The Slack Web API client.
event (Dict[str, Any]): The event data for the message.
body (Dict[str, Any]): The full request body from Slack.
say (AsyncSay): A function to send a response back to the channel.
"""
await context.ack()
event_profile = SlackEventProfile(**event)
event_body = SlackEventBody(**body)
logger.info(f"on_message, context: {context}")
logger.info(f"on_message, client: {client}")
logger.info(f"on_message, event_profile: {event_profile}")
logger.info(f"on_message, event_body: {event_body}")
logger.info(f"on_message, say: {say}")
logger.info(f"Received message: {event_profile.text}")
def mention_me(
self, context: "AsyncBoltContext", body: SlackEventBody
) -> bool:
r"""Check if the bot is mentioned in the message.
Args:
context (AsyncBoltContext): The Slack Bolt context for the event.
body (SlackEventBody): The body of the Slack event.
Returns:
bool: True if the bot is mentioned in the message, False otherwise.
"""
message = body.event.text
bot_user_id = context.bot_user_id
mention = f"<@{bot_user_id}>"
return mention in message
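A minimal launch, assuming `SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` are exported; for embedding into an existing ASGI app, `handle_request` can be mounted on a route instead:

```python
from camel.bots import SlackApp

app = SlackApp()  # credentials are read from the environment
app.run(port=3000, path="/slack/events")
```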

View File

@ -0,0 +1,82 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import TYPE_CHECKING, Optional
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.utils import dependencies_required
# Conditionally import telebot types only for type checking
if TYPE_CHECKING:
from telebot.types import ( # type: ignore[import-untyped]
Message,
)
class TelegramBot:
r"""Represents a Telegram bot that is powered by an agent.
Attributes:
chat_agent (ChatAgent): Chat agent that will power the bot.
telegram_token (str, optional): The bot token.
"""
@dependencies_required('telebot')
def __init__(
self,
chat_agent: ChatAgent,
telegram_token: Optional[str] = None,
) -> None:
self.chat_agent = chat_agent
if not telegram_token:
self.token = os.getenv('TELEGRAM_TOKEN')
if not self.token:
raise ValueError(
"`TELEGRAM_TOKEN` not found in environment variables. "
"Get it from t.me/BotFather."
)
else:
self.token = telegram_token
import telebot # type: ignore[import-untyped]
self.bot = telebot.TeleBot(token=self.token)
# Register the message handler within the constructor
self.bot.message_handler(func=lambda message: True)(self.on_message)
def run(self) -> None:
r"""Start the Telegram bot."""
print("Telegram bot is running...")
self.bot.infinity_polling()
def on_message(self, message: 'Message') -> None:
r"""Handles incoming messages from the user.
Args:
message (types.Message): The incoming message object.
"""
self.chat_agent.reset()
if not message.text:
return
user_msg = BaseMessage.make_user_message(
role_name="User", content=message.text
)
assistant_response = self.chat_agent.step(user_msg)
self.bot.reply_to(message, assistant_response.msg.content)
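Usage is essentially two lines once `TELEGRAM_TOKEN` is exported; the system message below is only an example:

```python
from camel.agents import ChatAgent
from camel.bots import TelegramBot
from camel.messages import BaseMessage

sys_msg = BaseMessage.make_assistant_message(
    role_name="Assistant", content="You are a concise, helpful assistant."
)
bot = TelegramBot(chat_agent=ChatAgent(system_message=sys_msg))
bot.run()  # starts infinity polling and blocks
```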

View File

@ -0,0 +1,76 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .cohere_config import COHERE_API_PARAMS, CohereConfig
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
from .qwen_config import QWEN_API_PARAMS, QwenConfig
from .reka_config import REKA_API_PARAMS, RekaConfig
from .samba_config import (
SAMBA_CLOUD_API_PARAMS,
SAMBA_VERSE_API_PARAMS,
SambaCloudAPIConfig,
SambaVerseAPIConfig,
)
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
from .yi_config import YI_API_PARAMS, YiConfig
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig
__all__ = [
'BaseConfig',
'ChatGPTConfig',
'OPENAI_API_PARAMS',
'AnthropicConfig',
'ANTHROPIC_API_PARAMS',
'GROQ_API_PARAMS',
'GroqConfig',
'LiteLLMConfig',
'LITELLM_API_PARAMS',
'NvidiaConfig',
'NVIDIA_API_PARAMS',
'OllamaConfig',
'OLLAMA_API_PARAMS',
'ZhipuAIConfig',
'ZHIPUAI_API_PARAMS',
'GeminiConfig',
'Gemini_API_PARAMS',
'VLLMConfig',
'VLLM_API_PARAMS',
'MistralConfig',
'MISTRAL_API_PARAMS',
'RekaConfig',
'REKA_API_PARAMS',
'SambaVerseAPIConfig',
'SAMBA_VERSE_API_PARAMS',
'SambaCloudAPIConfig',
'SAMBA_CLOUD_API_PARAMS',
'TogetherAIConfig',
'TOGETHERAI_API_PARAMS',
'CohereConfig',
'COHERE_API_PARAMS',
'YiConfig',
'YI_API_PARAMS',
'QwenConfig',
'QWEN_API_PARAMS',
'DeepSeekConfig',
'DEEPSEEK_API_PARAMS',
]

View File

@ -0,0 +1,69 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class AnthropicConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Anthropic API.
See: https://docs.anthropic.com/claude/reference/complete_post
Args:
max_tokens (int, optional): The maximum number of tokens to
generate before stopping. Note that Anthropic models may stop
before reaching this maximum. This parameter only specifies the
absolute maximum number of tokens to generate.
(default: :obj:`256`)
stop_sequences (List[str], optional): Sequences that will cause the
model to stop generating completion text. Anthropic models stop
on "\n\nHuman:", and may include additional built-in stop sequences
in the future. By providing the stop_sequences parameter, you may
include additional strings that will cause the model to stop
generating.
temperature (float, optional): Amount of randomness injected into the
response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
for analytical / multiple choice, and closer to 1 for creative
and generative tasks.
(default: :obj:`1`)
top_p (float, optional): Use nucleus sampling. In nucleus sampling, we
compute the cumulative distribution over all the options for each
subsequent token in decreasing probability order and cut it off
once it reaches a particular probability specified by `top_p`.
You should either alter `temperature` or `top_p`,
but not both.
(default: :obj:`0.7`)
top_k (int, optional): Only sample from the top K options for each
subsequent token. Used to remove "long tail" low probability
responses.
(default: :obj:`5`)
metadata: An object describing metadata about the request.
stream (bool, optional): Whether to incrementally stream the response
using server-sent events. (default: :obj:`False`)
"""
max_tokens: int = 256
stop_sequences: Union[List[str], NotGiven] = NOT_GIVEN
temperature: float = 1
top_p: Union[float, NotGiven] = NOT_GIVEN
top_k: Union[int, NotGiven] = NOT_GIVEN
metadata: NotGiven = NOT_GIVEN
stream: bool = False
ANTHROPIC_API_PARAMS = {param for param in AnthropicConfig.model_fields.keys()}
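As a quick sketch of how such a config is consumed downstream (the filtering step is just one plausible pattern for using the `*_API_PARAMS` sets): unset fields stay as the `NOT_GIVEN` sentinel, and the set above lists every accepted key.

```python
from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig

config = AnthropicConfig(temperature=0.3, max_tokens=512)
payload = {
    k: v for k, v in config.as_dict().items() if k in ANTHROPIC_API_PARAMS
}
print(payload["max_tokens"])            # -> 512
print("top_k" in ANTHROPIC_API_PARAMS)  # -> True
```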

View File

@ -0,0 +1,89 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from abc import ABC
from typing import Any, List, Optional
from pydantic import BaseModel, ConfigDict, field_validator
class BaseConfig(ABC, BaseModel):
r"""Base configuration class for all models.
This class provides a common interface for all models, ensuring that all
models have a consistent set of attributes and methods.
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
frozen=True,
# UserWarning: conflict with protected namespace "model_"
protected_namespaces=(),
)
tools: Optional[List[Any]] = None
"""A list of tools the model may
call. Currently, only functions are supported as a tool. Use this
to provide a list of functions the model may generate JSON inputs
for. A max of 128 functions are supported.
"""
@field_validator("tools", mode="before")
@classmethod
def fields_type_checking(cls, tools):
r"""Validate the type of tools in the configuration.
This method ensures that the tools provided in the configuration are
instances of `FunctionTool`. If any tool is not an instance of
`FunctionTool`, it raises a ValueError.
"""
if tools is not None:
from camel.toolkits import FunctionTool
for tool in tools:
if not isinstance(tool, FunctionTool):
raise ValueError(
f"The tool {tool} should "
"be an instance of `FunctionTool`."
)
return tools
def as_dict(self) -> dict[str, Any]:
r"""Convert the current configuration to a dictionary.
This method converts the current configuration object to a dictionary
representation, which can be used for serialization or other purposes.
Returns:
dict[str, Any]: A dictionary representation of the current
configuration.
"""
config_dict = self.model_dump()
tools_schema = None
if self.tools:
from camel.toolkits import FunctionTool
tools_schema = []
for tool in self.tools:
if not isinstance(tool, FunctionTool):
raise ValueError(
f"The tool {tool} should "
"be an instance of `FunctionTool`."
)
tools_schema.append(tool.get_openai_tool_schema())
config_dict["tools"] = tools_schema
return config_dict
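The effect of `as_dict` on `tools` is easiest to see with a concrete subclass; `ChatGPTConfig` (exported from this package) is used here, and the `add` function is of course made up:

```python
from camel.configs import ChatGPTConfig
from camel.toolkits import FunctionTool


def add(a: int, b: int) -> int:
    r"""Adds two integers.

    Args:
        a (int): First addend.
        b (int): Second addend.

    Returns:
        int: The sum of a and b.
    """
    return a + b


config = ChatGPTConfig(tools=[FunctionTool(add)])
dumped = config.as_dict()
# `tools` is now a list of OpenAI-style tool schemas rather than objects.
print(dumped["tools"][0]["function"]["name"])  # -> "add"
```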

View File

@ -0,0 +1,76 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional
from camel.configs.base_config import BaseConfig
class CohereConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Cohere API.
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
documents (list, optional): A list of relevant documents that the
model can cite to generate a more accurate reply. Each document is
either a string or document object with content and metadata.
(default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens the model
will generate as part of the response. (default: :obj:`None`)
stop_sequences (List(str), optional): A list of up to 5 strings that
the model will use to stop generation. If the model generates a
string that matches any of the strings in the list, it will stop
generating tokens and return the generated text up to that point
not including the stop sequence. (default: :obj:`None`)
seed (int, optional): If specified, the backend will make a best
effort to sample tokens deterministically, such that repeated
requests with the same seed and parameters should return the same
result. However, determinism cannot be totally guaranteed.
(default: :obj:`None`)
frequency_penalty (float, optional): Min value of `0.0`, max value of
`1.0`. Used to reduce repetitiveness of generated tokens. The
higher the value, the stronger a penalty is applied to previously
present tokens, proportional to how many times they have already
appeared in the prompt or prior generation. (default: :obj:`0.0`)
presence_penalty (float, optional): Min value of `0.0`, max value of
`1.0`. Used to reduce repetitiveness of generated tokens. Similar
to `frequency_penalty`, except that this penalty is applied
equally to all tokens that have already appeared, regardless of
their exact frequencies. (default: :obj:`0.0`)
k (int, optional): Ensures only the top k most likely tokens are
considered for generation at each step. Min value of `0`, max
value of `500`. (default: :obj:`0`)
p (float, optional): Ensures that only the most likely tokens, with
total probability mass of `p`, are considered for generation at
each step. If both k and p are enabled, `p` acts after `k`. Min
value of `0.01`, max value of `0.99`. (default: :obj:`0.75`)
"""
temperature: Optional[float] = 0.2
documents: Optional[list] = None
max_tokens: Optional[int] = None
stop_sequences: Optional[List[str]] = None
seed: Optional[int] = None
frequency_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = 0.0
k: Optional[int] = 0
p: Optional[float] = 0.75
COHERE_API_PARAMS = {param for param in CohereConfig.model_fields.keys()}

View File

@ -0,0 +1,134 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class DeepSeekConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
DeepSeek API.
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): Controls the diversity and focus of the
generated results. Higher values make the output more diverse,
while lower values make it more focused. (default: :obj:`1.0`)
response_format (object, optional): Specifies the format of the
returned content. The available values are `{"type": "text"}` or
`{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
will output a standard JSON string.
(default: :obj:`{"type": "text"}`)
stream (bool, optional): If set, partial message deltas will be sent.
Tokens will be sent as data-only server-sent events (SSE) as
they become available, with the stream terminated by a
data: [DONE] message. (default: :obj:`False`)
stop (Union[str, list[str]], optional): Up to 16 sequences where
the API will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens that can
be generated in the chat completion. The total length of input
tokens and generated tokens is limited by the model's context
length. (default: :obj:`None`)
presence_penalty (float, optional): Number between -2.0 and 2.0.
Positive values penalize new tokens based on whether they
appear in the text so far, increasing the model's likelihood
to talk about new topics. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between -2.0 and 2.0.
Positive values penalize new tokens based on their existing
frequency in the text so far, decreasing the model's likelihood
to repeat the same line verbatim. (default: :obj:`0`)
tools (list[FunctionTool], optional): A list of tools the model may
call. Currently, only functions are supported as a tool. Use
this to provide a list of functions the model may generate JSON
inputs for. A max of 128 functions are supported.
(default: :obj:`None`)
tool_choice (Union[dict[str, str], str], optional): Controls which
(if any) tool is called by the model. "none" means the model
will not call any tool and instead generates a message. "auto"
means the model can pick between generating a message or calling
one or more tools. "required" means the model must call one or
more tools. Specifying a particular tool via
{"type": "function", "function": {"name": "my_function"}} forces
the model to call that tool. "none" is the default when no tools
are present. "auto" is the default if tools are present.
(default: :obj:`"auto"`)
logprobs (bool, optional): Whether to return log probabilities of
the output tokens or not. If true, returns the log probabilities
of each output token returned in the content of message.
(default: :obj:`False`)
top_logprobs (int, optional): An integer between 0 and 20 specifying
the number of most likely tokens to return at each token
position, each with an associated log probability. logprobs
must be set to true if this parameter is used.
(default: :obj:`None`)
include_usage (bool, optional): When streaming, specifies whether to
include usage information in `stream_options`. (default:
:obj:`True`)
"""
temperature: float = 0.2 # deepseek default: 1.0
top_p: float = 1.0
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
presence_penalty: float = 0.0
response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
frequency_penalty: float = 0.0
tool_choice: Optional[Union[dict[str, str], str]] = None
logprobs: bool = False
top_logprobs: Optional[int] = None
def __init__(self, include_usage: bool = True, **kwargs):
super().__init__(**kwargs)
# Only set stream_options when stream is True
# Otherwise, it will raise error when calling the API
if self.stream:
self.stream_options = {"include_usage": include_usage}
def as_dict(self) -> dict[str, Any]:
r"""Convert the current configuration to a dictionary.
This method converts the current configuration object to a dictionary
representation, which can be used for serialization or other purposes.
Returns:
dict[str, Any]: A dictionary representation of the current
configuration.
"""
config_dict = self.model_dump()
if self.tools:
from camel.toolkits import FunctionTool
tools_schema = []
for tool in self.tools:
if not isinstance(tool, FunctionTool):
raise ValueError(
f"The tool {tool} should "
"be an instance of `FunctionTool`."
)
tools_schema.append(tool.get_openai_tool_schema())
config_dict["tools"] = NOT_GIVEN
return config_dict
DEEPSEEK_API_PARAMS = {param for param in DeepSeekConfig.model_fields.keys()}
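Note that, unlike the base class shown earlier, this `as_dict` override validates the tools and then drops them from the outgoing dict. A small, made-up check of that behaviour, with `NOT_GIVEN` being the sentinel imported at the top of this file:

```python
from camel.configs import DeepSeekConfig
from camel.toolkits import FunctionTool


def get_greeting(name: str) -> str:
    r"""Builds a greeting.

    Args:
        name (str): Who to greet.

    Returns:
        str: The greeting text.
    """
    return f"Hello, {name}!"


config = DeepSeekConfig(tools=[FunctionTool(get_greeting)])
print(config.as_dict()["tools"])  # -> NOT_GIVEN; schemas are discarded
```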

View File

@ -0,0 +1,114 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class GeminiConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Gemini API.
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
n (int, optional): How many chat completion choices to generate for
each input message. (default: :obj:`1`)
response_format (object, optional): An object specifying the format
that the model must output. Compatible with GPT-4 Turbo and all
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
{"type": "json_object"} enables JSON mode, which guarantees the
message the model generates is valid JSON. Important: when using
JSON mode, you must also instruct the model to produce JSON
yourself via a system or user message. Without this, the model
may generate an unending stream of whitespace until the generation
reaches the token limit, resulting in a long-running and seemingly
"stuck" request. Also note that the message content may be
partially cut off if finish_reason="length", which indicates the
generation exceeded max_tokens or the conversation exceeded the
max context length.
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
tools (list[FunctionTool], optional): A list of tools the model may
call. Currently, only functions are supported as a tool. Use this
to provide a list of functions the model may generate JSON inputs
for. A max of 128 functions are supported.
tool_choice (Union[dict[str, str], str], optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"required"` means the
model must call one or more tools. Specifying a particular tool
via {"type": "function", "function": {"name": "my_function"}}
forces the model to call that tool. :obj:`"none"` is the default
when no tools are present. :obj:`"auto"` is the default if tools
are present.
"""
temperature: float = 0.2 # openai default: 1.0
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
tool_choice: Optional[Union[dict[str, str], str]] = None
def as_dict(self) -> dict[str, Any]:
r"""Convert the current configuration to a dictionary.
This method converts the current configuration object to a dictionary
representation, which can be used for serialization or other purposes.
Returns:
dict[str, Any]: A dictionary representation of the current
configuration.
"""
config_dict = self.model_dump()
if self.tools:
from camel.toolkits import FunctionTool
tools_schema = []
for tool in self.tools:
if not isinstance(tool, FunctionTool):
raise ValueError(
f"The tool {tool} should "
"be an instance of `FunctionTool`."
)
tools_schema.append(tool.get_openai_tool_schema())
config_dict["tools"] = NOT_GIVEN
return config_dict
Gemini_API_PARAMS = {param for param in GeminiConfig.model_fields.keys()}
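
A minimal usage sketch (illustrative, not part of this commit), assuming `camel.configs` re-exports the names defined above as in upstream CAMEL. It shows the `as_dict()` behaviour: attached tools are validated as `FunctionTool` instances and then reset to `NOT_GIVEN` in the returned dictionary.

```python
# Hypothetical usage of GeminiConfig; the import path is an assumption.
from camel.configs import Gemini_API_PARAMS, GeminiConfig

config = GeminiConfig(temperature=0.5, max_tokens=1024)
payload = config.as_dict()

# Keep only the keys this config actually defines (illustrative filtering).
payload = {k: v for k, v in payload.items() if k in Gemini_API_PARAMS}
print(payload["temperature"], payload["max_tokens"])  # 0.5 1024
```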

View File

@ -0,0 +1,104 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class GroqConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using OpenAI
compatibility.
Reference: https://console.groq.com/docs/openai
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
n (int, optional): How many chat completion choices to generate for
each input message. (default: :obj:`1`)
response_format (object, optional): An object specifying the format
that the model must output. Compatible with GPT-4 Turbo and all
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
{"type": "json_object"} enables JSON mode, which guarantees the
message the model generates is valid JSON. Important: when using
JSON mode, you must also instruct the model to produce JSON
yourself via a system or user message. Without this, the model
may generate an unending stream of whitespace until the generation
reaches the token limit, resulting in a long-running and seemingly
"stuck" request. Also note that the message content may be
partially cut off if finish_reason="length", which indicates the
generation exceeded max_tokens or the conversation exceeded the
max context length.
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
user (str, optional): A unique identifier representing your end-user,
which can help OpenAI to monitor and detect abuse.
(default: :obj:`""`)
tools (list[FunctionTool], optional): A list of tools the model may
call. Currently, only functions are supported as a tool. Use this
to provide a list of functions the model may generate JSON inputs
for. A max of 128 functions are supported.
tool_choice (Union[dict[str, str], str], optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"required"` means the
model must call one or more tools. Specifying a particular tool
via {"type": "function", "function": {"name": "my_function"}}
forces the model to call that tool. :obj:`"none"` is the default
when no tools are present. :obj:`"auto"` is the default if tools
are present.
"""
temperature: float = 0.2 # openai default: 1.0
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
presence_penalty: float = 0.0
response_format: Union[dict, NotGiven] = NOT_GIVEN
frequency_penalty: float = 0.0
user: str = ""
tool_choice: Optional[Union[dict[str, str], str]] = "auto"
GROQ_API_PARAMS = {param for param in GroqConfig.model_fields.keys()}
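
A short, hypothetical sketch of turning this config into request kwargs for Groq's OpenAI-compatible endpoint; the model name in the comment is an example and not taken from this commit.

```python
# Illustrative only; assumes camel.configs re-exports these names.
from camel.configs import GROQ_API_PARAMS, GroqConfig

config = GroqConfig(temperature=0.7, max_tokens=512)
kwargs = {k: v for k, v in config.as_dict().items() if k in GROQ_API_PARAMS}
# The kwargs could then be forwarded to an OpenAI-compatible client, e.g.
# client.chat.completions.create(model="llama3-70b-8192", messages=msgs, **kwargs)
print(sorted(kwargs))
```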

View File

@ -0,0 +1,97 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional, Union
from camel.configs.base_config import BaseConfig
class LiteLLMConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
LiteLLM API.
Args:
timeout (Optional[Union[float, str]], optional): Request timeout.
(default: None)
temperature (Optional[float], optional): Temperature parameter for
controlling randomness. (default: None)
top_p (Optional[float], optional): Top-p parameter for nucleus
sampling. (default: None)
n (Optional[int], optional): Number of completions to generate.
(default: None)
stream (Optional[bool], optional): Whether to return a streaming
response. (default: None)
stream_options (Optional[dict], optional): Options for the streaming
response. (default: None)
stop (Optional[Union[str, List[str]]], optional): Sequences where the
API will stop generating further tokens. (default: None)
max_tokens (Optional[int], optional): Maximum number of tokens to
generate. (default: None)
presence_penalty (Optional[float], optional): Penalize new tokens
based on their existence in the text so far. (default: None)
frequency_penalty (Optional[float], optional): Penalize new tokens
based on their frequency in the text so far. (default: None)
logit_bias (Optional[dict], optional): Modify the probability of
specific tokens appearing in the completion. (default: None)
user (Optional[str], optional): A unique identifier representing the
end-user. (default: None)
response_format (Optional[dict], optional): Response format
parameters. (default: None)
seed (Optional[int], optional): Random seed. (default: None)
tools (Optional[List], optional): List of tools. (default: None)
tool_choice (Optional[Union[str, dict]], optional): Tool choice
parameters. (default: None)
logprobs (Optional[bool], optional): Whether to return log
probabilities of the output tokens. (default: None)
top_logprobs (Optional[int], optional): Number of most likely tokens
to return at each token position. (default: None)
deployment_id (Optional[str], optional): Deployment ID. (default: None)
extra_headers (Optional[dict], optional): Additional headers for the
request. (default: None)
api_version (Optional[str], optional): API version. (default: None)
mock_response (Optional[str], optional): Mock completion response for
testing or debugging. (default: None)
custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
provider. (default: None)
max_retries (Optional[int], optional): Maximum number of retries.
(default: None)
"""
timeout: Optional[Union[float, str]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stream: Optional[bool] = None
stream_options: Optional[dict] = None
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[dict] = None
user: Optional[str] = None
response_format: Optional[dict] = None
seed: Optional[int] = None
tool_choice: Optional[Union[str, dict]] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
deployment_id: Optional[str] = None
extra_headers: Optional[dict] = None
api_version: Optional[str] = None
mock_response: Optional[str] = None
custom_llm_provider: Optional[str] = None
max_retries: Optional[int] = None
LITELLM_API_PARAMS = {param for param in LiteLLMConfig.model_fields.keys()}
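
Because every field here defaults to `None`, a caller can forward only the values that were explicitly set. A hedged sketch, assuming the import path used elsewhere in this commit; the provider string is an example.

```python
from camel.configs import LiteLLMConfig

config = LiteLLMConfig(temperature=0.2, max_tokens=256, custom_llm_provider="ollama")
kwargs = {k: v for k, v in config.model_dump().items() if v is not None}
print(kwargs)  # expected: just the explicitly set parameters
```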

View File

@ -0,0 +1,79 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Dict, Optional, Union
from pydantic import field_validator
from camel.configs.base_config import BaseConfig
class MistralConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Mistral API.
reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195
#TODO: Support stream mode
Args:
temperature (Optional[float], optional): the temperature to use for
sampling, e.g. 0.5. Defaults to None.
top_p (Optional[float], optional): the cumulative probability of
tokens to generate, e.g. 0.9. Defaults to None.
max_tokens (Optional[int], optional): the maximum number of tokens to
generate, e.g. 100. Defaults to None.
stop (Optional[Union[str, list[str]]], optional): Stop generation if this
token is detected, or if one of the tokens in the provided list is
detected.
random_seed (Optional[int], optional): the random seed to use for
sampling, e.g. 42. Defaults to None.
safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
Defaults to False.
response_format (Union[Dict[str, str], ResponseFormat], optional): format
of the response.
tool_choice (str, optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"any"` means the
model must call one or more tools. :obj:`"auto"` is the default
value.
"""
temperature: Optional[float] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
stop: Optional[Union[str, list[str]]] = None
random_seed: Optional[int] = None
safe_prompt: bool = False
response_format: Optional[Union[Dict[str, str], Any]] = None
tool_choice: Optional[str] = "auto"
@field_validator("response_format", mode="before")
@classmethod
def fields_type_checking(cls, response_format):
if response_format and not isinstance(response_format, dict):
from mistralai.models import ResponseFormat
if not isinstance(response_format, ResponseFormat):
raise ValueError(
f"The tool {response_format} should be an instance "
"of `mistralai.models.ResponseFormat`."
)
return response_format
MISTRAL_API_PARAMS = {param for param in MistralConfig.model_fields.keys()}
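
A brief sketch of the `response_format` validator above: a plain dict passes through unchanged, while any other type is checked against `mistralai.models.ResponseFormat` (which requires the `mistralai` package to be installed). Illustrative only; the import path is an assumption.

```python
from camel.configs import MistralConfig

config = MistralConfig(temperature=0.5, response_format={"type": "json_object"})
print(config.response_format)  # {'type': 'json_object'}
print(config.tool_choice)      # 'auto', the default declared above
```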

View File

@ -0,0 +1,70 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class NvidiaConfig(BaseConfig):
r"""Configuration class for NVIDIA API models.
This class defines the configuration parameters for NVIDIA's language
models, including temperature, sampling parameters, and response format
settings.
Args:
stream (bool, optional): Whether to stream the response.
(default: :obj:`False`)
temperature (float, optional): Controls randomness in the response.
Higher values make output more random, lower values make it more
deterministic. Range: [0.0, 2.0]. (default: :obj:`0.7`)
top_p (float, optional): Controls diversity via nucleus sampling.
Range: [0.0, 1.0]. (default: :obj:`0.95`)
presence_penalty (float, optional): Penalizes new tokens based on
whether they appear in the text so far. Range: [-2.0, 2.0].
(default: :obj:`0.0`)
frequency_penalty (float, optional): Penalizes new tokens based on
their frequency in the text so far. Range: [-2.0, 2.0].
(default: :obj:`0.0`)
max_tokens (Union[int, NotGiven], optional): Maximum number of tokens
to generate. If not provided, model will use its default maximum.
(default: :obj:`NOT_GIVEN`)
seed (Optional[int], optional): Random seed for deterministic sampling.
(default: :obj:`None`)
tools (Optional[List[Dict]], optional): List of tools available to the
model. This includes tools such as a text editor, a calculator, or
a search engine. (default: :obj:`None`)
tool_choice (Optional[str], optional): Tool choice configuration.
(default: :obj:`None`)
stop (Optional[List[str]], optional): List of stop sequences.
(default: :obj:`None`)
"""
stream: bool = Field(default=False)
temperature: float = Field(default=0.7)
top_p: float = Field(default=0.95)
presence_penalty: float = Field(default=0.0)
frequency_penalty: float = Field(default=0.0)
max_tokens: Union[int, NotGiven] = Field(default=NOT_GIVEN)
seed: Optional[int] = Field(default=None)
tool_choice: Optional[str] = Field(default=None)
stop: Optional[List[str]] = Field(default=None)
NVIDIA_API_PARAMS = {param for param in NvidiaConfig.model_fields.keys()}
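
A hedged sketch of building request kwargs from this config: `NOT_GIVEN` marks parameters that should be omitted from the request rather than sent as `None`.

```python
# Illustrative only; import paths mirror the ones used in this commit.
from camel.configs import NvidiaConfig
from camel.types import NotGiven

config = NvidiaConfig(temperature=0.5, seed=42)
kwargs = {
    k: v
    for k, v in config.as_dict().items()
    if v is not None and not isinstance(v, NotGiven)
}
print(kwargs["temperature"], kwargs["seed"])  # 0.5 42
```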

View File

@ -0,0 +1,82 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Sequence, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class OllamaConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using OpenAI
compatibility
Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
response_format (object, optional): An object specifying the format
that the model must output. Compatible with GPT-4 Turbo and all
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
{"type": "json_object"} enables JSON mode, which guarantees the
message the model generates is valid JSON. Important: when using
JSON mode, you must also instruct the model to produce JSON
yourself via a system or user message. Without this, the model
may generate an unending stream of whitespace until the generation
reaches the token limit, resulting in a long-running and seemingly
"stuck" request. Also note that the message content may be
partially cut off if finish_reason="length", which indicates the
generation exceeded max_tokens or the conversation exceeded the
max context length.
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
"""
temperature: float = 0.2
top_p: float = 1.0
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
presence_penalty: float = 0.0
response_format: Union[dict, NotGiven] = NOT_GIVEN
frequency_penalty: float = 0.0
OLLAMA_API_PARAMS = {param for param in OllamaConfig.model_fields.keys()}
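
A hypothetical end-to-end sketch against a local Ollama server; the base URL, model name, and the use of the `openai` client are assumptions for illustration, not part of this commit.

```python
from openai import OpenAI

from camel.configs import OLLAMA_API_PARAMS, OllamaConfig
from camel.types import NotGiven

config = OllamaConfig(temperature=0.2, max_tokens=128)
kwargs = {
    k: v
    for k, v in config.as_dict().items()
    if k in OLLAMA_API_PARAMS and v is not None and not isinstance(v, NotGiven)
}

client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
response = client.chat.completions.create(
    model="llama3.1",  # example model name
    messages=[{"role": "user", "content": "Say hi in one word."}],
    **kwargs,
)
print(response.choices[0].message.content)
```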

View File

@ -0,0 +1,139 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, Union
from pydantic import BaseModel, Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class ChatGPTConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
OpenAI API.
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
n (int, optional): How many chat completion choices to generate for
each input message. (default: :obj:`1`)
response_format (object, optional): An object specifying the format
that the model must output. Compatible with GPT-4 Turbo and all
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
{"type": "json_object"} enables JSON mode, which guarantees the
message the model generates is valid JSON. Important: when using
JSON mode, you must also instruct the model to produce JSON
yourself via a system or user message. Without this, the model
may generate an unending stream of whitespace until the generation
reaches the token limit, resulting in a long-running and seemingly
"stuck" request. Also note that the message content may be
partially cut off if finish_reason="length", which indicates the
generation exceeded max_tokens or the conversation exceeded the
max context length.
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
logit_bias (dict, optional): Modify the likelihood of specified tokens
appearing in the completion. Accepts a json object that maps tokens
(specified by their token ID in the tokenizer) to an associated
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
is added to the logits generated by the model prior to sampling.
The exact effect will vary per model, but values between :obj:`-1`
and :obj:`1` should decrease or increase likelihood of selection;
values like :obj:`-100` or :obj:`100` should result in a ban or
exclusive selection of the relevant token. (default: :obj:`{}`)
user (str, optional): A unique identifier representing your end-user,
which can help OpenAI to monitor and detect abuse.
(default: :obj:`""`)
tools (list[FunctionTool], optional): A list of tools the model may
call. Currently, only functions are supported as a tool. Use this
to provide a list of functions the model may generate JSON inputs
for. A max of 128 functions are supported.
tool_choice (Union[dict[str, str], str], optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"required"` means the
model must call one or more tools. Specifying a particular tool
via {"type": "function", "function": {"name": "my_function"}}
forces the model to call that tool. :obj:`"none"` is the default
when no tools are present. :obj:`"auto"` is the default if tools
are present.
"""
temperature: float = 0.2 # openai default: 1.0
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
presence_penalty: float = 0.0
response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
frequency_penalty: float = 0.0
logit_bias: dict = Field(default_factory=dict)
user: str = ""
tool_choice: Optional[Union[dict[str, str], str]] = None
def as_dict(self) -> dict[str, Any]:
r"""Convert the current configuration to a dictionary.
This method converts the current configuration object to a dictionary
representation, which can be used for serialization or other purposes.
Returns:
dict[str, Any]: A dictionary representation of the current
configuration.
"""
config_dict = self.model_dump()
if self.tools:
from camel.toolkits import FunctionTool
tools_schema = []
for tool in self.tools:
if not isinstance(tool, FunctionTool):
raise ValueError(
f"The tool {tool} should "
"be an instance of `FunctionTool`."
)
tools_schema.append(tool.get_openai_tool_schema())
config_dict["tools"] = NOT_GIVEN
return config_dict
OPENAI_API_PARAMS = {param for param in ChatGPTConfig.model_fields.keys()}
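
A minimal sketch, assuming `camel.configs` re-exports these names: convert the config to request kwargs and keep only the keys enumerated in `OPENAI_API_PARAMS`.

```python
from camel.configs import OPENAI_API_PARAMS, ChatGPTConfig

config = ChatGPTConfig(temperature=0.0, max_tokens=256, user="demo-user")
kwargs = {k: v for k, v in config.as_dict().items() if k in OPENAI_API_PARAMS}
print(kwargs["temperature"], kwargs["max_tokens"], kwargs["user"])
```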

View File

@ -0,0 +1,91 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import ClassVar, Optional, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class QwenConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Qwen API. You can refer to the following link for more details:
https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api
Args:
stream (bool, optional): Whether to stream the response.
(default: :obj:`False`)
temperature (float, optional): Controls the diversity and focus of
the generated results. Lower values make the output more focused,
while higher values make it more diverse. (default: :obj:`0.3`)
top_p (float, optional): Controls the diversity and focus of the
generated results. Higher values make the output more diverse,
while lower values make it more focused. (default: :obj:`0.9`)
presence_penalty (float, optional): Controls the repetition of
content in the generated results. Positive values reduce the
repetition of content, while negative values increase it.
(default: :obj:`0.0`)
response_format (object, optional): Specifies the format of the
returned content. The available values are `{"type": "text"}` or
`{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
will output a standard JSON string.
(default: :obj:`{"type": "text"}`)
max_tokens (Union[int, NotGiven], optional): Allows the model to
generate the maximum number of tokens.
(default: :obj:`NOT_GIVEN`)
seed (int, optional): Sets the seed parameter to make the text
generation process more deterministic, typically used to ensure
that the results are consistent across model runs. By passing the
same seed value (specified by you) in each model call while
keeping other parameters unchanged, the model is likely to return
the same result.
(default: :obj:`None`)
stop (str or list, optional): The model stops generating text when it is
about to produce the specified string or token_id. This can be used to
filter unwanted content from the output, e.g. by passing sensitive
words. (default: :obj:`None`)
(default: :obj:`None`)
tools (list, optional): Specifies an array of tools that the model can
call. It can contain one or more tool objects. During a function
call process, the model will select one tool from the array.
(default: :obj:`None`)
extra_body (dict, optional): Additional parameters to be sent to the
Qwen API. If you want to enable internet search, you can set this
parameter to `{"enable_search": True}`.
(default: :obj:`{"enable_search": False}`)
include_usage (bool, optional): When streaming, specifies whether to
include usage information in `stream_options`. (default:
:obj:`True`)
"""
stream: bool = False
temperature: float = 0.3
top_p: float = 0.9
presence_penalty: float = 0.0
response_format: ClassVar[dict] = {"type": "text"}
max_tokens: Union[int, NotGiven] = NOT_GIVEN
seed: Optional[int] = None
stop: Optional[Union[str, list]] = None
extra_body: ClassVar[dict] = {"enable_search": False}
def __init__(self, include_usage: bool = True, **kwargs):
super().__init__(**kwargs)
# Only set stream_options when stream is True
# Otherwise, it will raise error when calling the API
if self.stream:
self.stream_options = {"include_usage": include_usage}
QWEN_API_PARAMS = {param for param in QwenConfig.model_fields.keys()}
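
A sketch of the constructor logic above: `stream_options` is attached only when streaming is enabled, so non-streaming requests never carry it. Illustrative, and it assumes the base config permits the attribute assignment used in `__init__`.

```python
from camel.configs import QwenConfig

off = QwenConfig(stream=False)
on = QwenConfig(stream=True, include_usage=True)
print(getattr(off, "stream_options", None))  # None
print(getattr(on, "stream_options", None))   # {'include_usage': True}
```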

View File

@ -0,0 +1,74 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Union
from camel.configs.base_config import BaseConfig
class RekaConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Reka API.
Reference: https://docs.reka.ai/api-reference/chat/create
Args:
temperature (Optional[float], optional): the temperature to use for
sampling, e.g. 0.5. Defaults to None.
top_p (Optional[float], optional): the cumulative probability of
tokens to generate, e.g. 0.9. Defaults to None.
top_k (Optional[int], optional): Parameter which forces the model to
only consider the tokens with the `top_k` highest probabilities at
the next step. Defaults to 1024.
max_tokens (Optional[int], optional): the maximum number of tokens to
generate, e.g. 100. Defaults to None.
stop (Optional[Union[str, list[str]]], optional): Stop generation if this
token is detected, or if one of the tokens in the provided list is
detected.
seed (Optional[int], optional): the random seed to use for sampling,
e.g. 42. Defaults to None.
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
use_search_engine (Optional[bool]): Whether to consider using search
engine to complete the request. Note that even if this is set to
`True`, the model might decide to not use search.
"""
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
max_tokens: Optional[int] = None
stop: Optional[Union[str, list[str]]] = None
seed: Optional[int] = None
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
use_search_engine: Optional[bool] = False
def as_dict(self) -> dict[str, Any]:
config_dict = super().as_dict()
if "tools" in config_dict:
del config_dict["tools"] # Reka does not support tool calling
return config_dict
REKA_API_PARAMS = {param for param in RekaConfig.model_fields.keys()}
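
A small sketch of the `as_dict()` override above: Reka does not support tool calling, so the inherited `tools` entry is removed before a request is built. Illustrative only; the import path is assumed.

```python
from camel.configs import RekaConfig

config = RekaConfig(temperature=0.7, top_k=50)
payload = config.as_dict()
print("tools" in payload)                        # False
print(payload["temperature"], payload["top_k"])  # 0.7 50
```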

View File

@ -0,0 +1,170 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class SambaVerseAPIConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
SambaVerse API.
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.7`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`0.95`)
top_k (int, optional): Only sample from the top K options for each
subsequent token. Used to remove "long tail" low probability
responses.
(default: :obj:`50`)
max_tokens (Optional[int], optional): The maximum number of tokens to
generate, e.g. 100.
(default: :obj:`2048`)
repetition_penalty (Optional[float], optional): The parameter for
repetition penalty. 1.0 means no penalty.
(default: :obj:`1.0`)
stop (Optional[Union[str, list[str]]], optional): Stop generation if this
token is detected, or if one of the tokens in the provided list is
detected. (default: :obj:`""`)
stream (Optional[bool]): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
Currently SambaVerse API doesn't support stream mode.
(default: :obj:`False`)
"""
temperature: Optional[float] = 0.7
top_p: Optional[float] = 0.95
top_k: Optional[int] = 50
max_tokens: Optional[int] = 2048
repetition_penalty: Optional[float] = 1.0
stop: Optional[Union[str, list[str]]] = ""
stream: Optional[bool] = False
def as_dict(self) -> dict[str, Any]:
config_dict = super().as_dict()
if "tools" in config_dict:
del config_dict["tools"] # SambaNova does not support tool calling
return config_dict
SAMBA_VERSE_API_PARAMS = {
param for param in SambaVerseAPIConfig.model_fields.keys()
}
class SambaCloudAPIConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
OpenAI API.
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
n (int, optional): How many chat completion choices to generate for
each input message. (default: :obj:`1`)
response_format (object, optional): An object specifying the format
that the model must output. Compatible with GPT-4 Turbo and all
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
{"type": "json_object"} enables JSON mode, which guarantees the
message the model generates is valid JSON. Important: when using
JSON mode, you must also instruct the model to produce JSON
yourself via a system or user message. Without this, the model
may generate an unending stream of whitespace until the generation
reaches the token limit, resulting in a long-running and seemingly
"stuck" request. Also note that the message content may be
partially cut off if finish_reason="length", which indicates the
generation exceeded max_tokens or the conversation exceeded the
max context length.
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
logit_bias (dict, optional): Modify the likelihood of specified tokens
appearing in the completion. Accepts a json object that maps tokens
(specified by their token ID in the tokenizer) to an associated
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
is added to the logits generated by the model prior to sampling.
The exact effect will vary per model, but values between :obj:`-1`
and :obj:`1` should decrease or increase likelihood of selection;
values like :obj:`-100` or :obj:`100` should result in a ban or
exclusive selection of the relevant token. (default: :obj:`{}`)
user (str, optional): A unique identifier representing your end-user,
which can help OpenAI to monitor and detect abuse.
(default: :obj:`""`)
tools (list[FunctionTool], optional): A list of tools the model may
call. Currently, only functions are supported as a tool. Use this
to provide a list of functions the model may generate JSON inputs
for. A max of 128 functions are supported.
tool_choice (Union[dict[str, str], str], optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"required"` means the
model must call one or more tools. Specifying a particular tool
via {"type": "function", "function": {"name": "my_function"}}
forces the model to call that tool. :obj:`"none"` is the default
when no tools are present. :obj:`"auto"` is the default if tools
are present.
"""
temperature: float = 0.2 # openai default: 1.0
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
presence_penalty: float = 0.0
response_format: Union[dict, NotGiven] = NOT_GIVEN
frequency_penalty: float = 0.0
logit_bias: dict = Field(default_factory=dict)
user: str = ""
tool_choice: Optional[Union[dict[str, str], str]] = None
SAMBA_CLOUD_API_PARAMS = {
param for param in SambaCloudAPIConfig.model_fields.keys()
}
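
A hedged sketch contrasting the two SambaNova configs: the SambaVerse variant strips `tools` in `as_dict()` because tool calling is unsupported, while the Cloud variant keeps the OpenAI-style fields as declared. Import path assumed.

```python
from camel.configs import SambaCloudAPIConfig, SambaVerseAPIConfig

verse = SambaVerseAPIConfig(temperature=0.7, max_tokens=512)
cloud = SambaCloudAPIConfig(temperature=0.2)
print("tools" in verse.as_dict())  # False, removed by the override above
print("tools" in cloud.as_dict())  # True, inherited from the base config
```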

View File

@ -0,0 +1,107 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class TogetherAIConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
OpenAI API.
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
n (int, optional): How many chat completion choices to generate for
each input message. (default: :obj:`1`)
response_format (object, optional): An object specifying the format
that the model must output. Compatible with GPT-4 Turbo and all
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
{"type": "json_object"} enables JSON mode, which guarantees the
message the model generates is valid JSON. Important: when using
JSON mode, you must also instruct the model to produce JSON
yourself via a system or user message. Without this, the model
may generate an unending stream of whitespace until the generation
reaches the token limit, resulting in a long-running and seemingly
"stuck" request. Also note that the message content may be
partially cut off if finish_reason="length", which indicates the
generation exceeded max_tokens or the conversation exceeded the
max context length.
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
logit_bias (dict, optional): Modify the likelihood of specified tokens
appearing in the completion. Accepts a json object that maps tokens
(specified by their token ID in the tokenizer) to an associated
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
is added to the logits generated by the model prior to sampling.
The exact effect will vary per model, but values between :obj:`-1`
and :obj:`1` should decrease or increase likelihood of selection;
values like :obj:`-100` or :obj:`100` should result in a ban or
exclusive selection of the relevant token. (default: :obj:`{}`)
user (str, optional): A unique identifier representing your end-user,
which can help OpenAI to monitor and detect abuse.
(default: :obj:`""`)
"""
temperature: float = 0.2 # openai default: 1.0
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
presence_penalty: float = 0.0
response_format: Union[dict, NotGiven] = NOT_GIVEN
frequency_penalty: float = 0.0
logit_bias: dict = Field(default_factory=dict)
user: str = ""
def as_dict(self) -> dict[str, Any]:
config_dict = super().as_dict()
if "tools" in config_dict:
del config_dict["tools"] # Currently does not support tool calling
return config_dict
TOGETHERAI_API_PARAMS = {
param for param in TogetherAIConfig.model_fields.keys()
}
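
A brief, hypothetical sketch; the model name in the comment is an example and not taken from this commit.

```python
from camel.configs import TOGETHERAI_API_PARAMS, TogetherAIConfig

config = TogetherAIConfig(temperature=0.8, stop=["</s>"], max_tokens=200)
kwargs = {k: v for k, v in config.as_dict().items() if k in TOGETHERAI_API_PARAMS}
# kwargs could be forwarded to Together's OpenAI-compatible chat endpoint, e.g.
# client.chat.completions.create(model="meta-llama/Llama-3-8b-chat-hf", ...)
print(kwargs["stop"], kwargs["max_tokens"])  # ['</s>'] 200
```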

View File

@ -0,0 +1,111 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
# flake8: noqa: E501
class VLLMConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
OpenAI API.
Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
n (int, optional): How many chat completion choices to generate for
each input message. (default: :obj:`1`)
response_format (object, optional): An object specifying the format
that the model must output. Compatible with GPT-4 Turbo and all
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
{"type": "json_object"} enables JSON mode, which guarantees the
message the model generates is valid JSON. Important: when using
JSON mode, you must also instruct the model to produce JSON
yourself via a system or user message. Without this, the model
may generate an unending stream of whitespace until the generation
reaches the token limit, resulting in a long-running and seemingly
"stuck" request. Also note that the message content may be
partially cut off if finish_reason="length", which indicates the
generation exceeded max_tokens or the conversation exceeded the
max context length.
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
logit_bias (dict, optional): Modify the likelihood of specified tokens
appearing in the completion. Accepts a json object that maps tokens
(specified by their token ID in the tokenizer) to an associated
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
is added to the logits generated by the model prior to sampling.
The exact effect will vary per model, but values between :obj:`-1`
and :obj:`1` should decrease or increase likelihood of selection;
values like :obj:`-100` or :obj:`100` should result in a ban or
exclusive selection of the relevant token. (default: :obj:`{}`)
user (str, optional): A unique identifier representing your end-user,
which can help OpenAI to monitor and detect abuse.
(default: :obj:`""`)
logprobs: Whether to return log probabilities of the output tokens or
not. If true, returns the log probabilities of each output token
returned in the `content` of `message`. (default: :obj:`None`)
top_logprobs: An integer between 0 and 20 specifying the number of
most likely tokens to return at each token position, each with an
associated log probability. `logprobs` must be set to `true` if
this parameter is used. (default: :obj:`None`)
"""
temperature: float = 0.2 # openai default: 1.0
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
presence_penalty: float = 0.0
response_format: Union[dict, NotGiven] = NOT_GIVEN
frequency_penalty: float = 0.0
logit_bias: dict = Field(default_factory=dict)
user: str = ""
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
VLLM_API_PARAMS = {param for param in VLLMConfig.model_fields.keys()}
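
A short sketch of requesting token log probabilities through this config; illustrative only, with the import path assumed.

```python
from camel.configs import VLLMConfig

config = VLLMConfig(temperature=0.0, logprobs=True, top_logprobs=5)
payload = config.as_dict()
print(payload["logprobs"], payload["top_logprobs"])  # True 5
```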

View File

@ -0,0 +1,58 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class YiConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Yi API. You can refer to the following link for more details:
https://platform.lingyiwanwu.com/docs/api-reference
Args:
tool_choice (Union[dict[str, str], str], optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"required"` or
specifying a particular tool via
{"type": "function", "function": {"name": "some_function"}}
can be used to guide the model to use tools more strongly.
(default: :obj:`None`)
max_tokens (int, optional): Specifies the maximum number of tokens
the model can generate. This sets an upper limit, but does not
guarantee that this number will always be reached.
(default: :obj:`5000`)
top_p (float, optional): Controls the randomness of the generated
results. Lower values lead to less randomness, while higher
values increase randomness. (default: :obj:`0.9`)
temperature (float, optional): Controls the diversity and focus of
the generated results. Lower values make the output more focused,
while higher values make it more diverse. (default: :obj:`0.3`)
stream (bool, optional): If True, enables streaming output.
(default: :obj:`False`)
"""
tool_choice: Optional[Union[dict[str, str], str]] = None
max_tokens: Union[int, NotGiven] = NOT_GIVEN
top_p: float = 0.9
temperature: float = 0.3
stream: bool = False
YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
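
A minimal sketch: an unset `max_tokens` stays as a `NotGiven` sentinel so the Yi API applies its own default, while explicitly set values pass through unchanged. Import paths mirror the ones used in this commit.

```python
from camel.configs import YiConfig
from camel.types import NotGiven

default = YiConfig()
tuned = YiConfig(max_tokens=2048, temperature=0.1)
print(isinstance(default.max_tokens, NotGiven))  # True
print(tuned.max_tokens, tuned.temperature)       # 2048 0.1
```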

View File

@ -0,0 +1,71 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class ZhipuAIConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using OpenAI
compatibility
Reference: https://open.bigmodel.cn/dev/api#glm-4v
Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`0.6`)
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
tools (list[FunctionTool], optional): A list of tools the model may
call. Currently, only functions are supported as a tool. Use this
to provide a list of functions the model may generate JSON inputs
for. A max of 128 functions are supported.
tool_choice (Union[dict[str, str], str], optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"required"` means the
model must call one or more tools. Specifying a particular tool
via {"type": "function", "function": {"name": "my_function"}}
forces the model to call that tool. :obj:`"none"` is the default
when no tools are present. :obj:`"auto"` is the default if tools
are present.
"""
temperature: float = 0.2
top_p: float = 0.6
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
tool_choice: Optional[Union[dict[str, str], str]] = None
ZHIPUAI_API_PARAMS = {param for param in ZhipuAIConfig.model_fields.keys()}
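
Along the same lines, a hedged sketch for the ZhipuAI config showing stop sequences and a forced tool choice; the tool name is purely illustrative and the import path is assumed to mirror the other config modules.

```python
from camel.configs import ZhipuAIConfig

# Stop at a custom delimiter and force a specific (hypothetical) tool call.
config = ZhipuAIConfig(
    temperature=0.2,
    stop=["<|endofblock|>"],
    tool_choice={"type": "function", "function": {"name": "search_web"}},
)
print(config.model_dump(exclude_unset=True))
```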

View File

@ -0,0 +1,23 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseDatasetManager
from .huggingface import HuggingFaceDatasetManager
from .models import Record
__all__ = [
"BaseDatasetManager",
"Record",
"HuggingFaceDatasetManager",
]

View File

@ -0,0 +1,136 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any, List
from camel.datahubs.models import Record
class BaseDatasetManager(ABC):
r"""Abstract base class for dataset managers."""
@abstractmethod
def create_dataset(self, name: str, **kwargs: Any) -> str:
r"""Creates a new dataset.
Args:
name (str): The name of the dataset.
kwargs (Any): Additional keyword arguments.
Returns:
str: The URL of the created dataset.
"""
pass
@abstractmethod
def list_datasets(
self, username: str, limit: int = 100, **kwargs: Any
) -> List[str]:
r"""Lists all datasets for the current user.
Args:
username (str): The username of the user whose datasets to list.
limit (int): The maximum number of datasets to list.
(default: :obj:`100`)
kwargs (Any): Additional keyword arguments.
Returns:
List[str]: A list of dataset ids.
"""
pass
@abstractmethod
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
r"""Deletes a dataset.
Args:
dataset_name (str): The name of the dataset to delete.
kwargs (Any): Additional keyword arguments.
"""
pass
@abstractmethod
def add_records(
self,
dataset_name: str,
records: List[Record],
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Adds records to a dataset.
Args:
dataset_name (str): The name of the dataset.
records (List[Record]): A list of records to add to the dataset.
filepath (str): The path to the file containing the records.
(default::obj:`"records/records.json"`)
kwargs (Any): Additional keyword arguments.
"""
pass
@abstractmethod
def update_records(
self,
dataset_name: str,
records: List[Record],
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Updates records in a dataset.
Args:
dataset_name (str): The name of the dataset.
records (List[Record]): A list of records to update in the dataset.
filepath (str): The path to the file containing the records.
(default::obj:`"records/records.json"`)
kwargs (Any): Additional keyword arguments.
"""
pass
@abstractmethod
def list_records(
self,
dataset_name: str,
filepath: str = "records/records.json",
**kwargs: Any,
) -> List[Record]:
r"""Lists records in a dataset.
Args:
dataset_name (str): The name of the dataset.
filepath (str): The path to the file containing the records.
(default::obj:`"records/records.json"`)
kwargs (Any): Additional keyword arguments.
"""
pass
# New method for record deletion
@abstractmethod
def delete_record(
self,
dataset_name: str,
record_id: str,
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Deletes a record from the dataset.
Args:
dataset_name (str): The name of the dataset.
record_id (str): The ID of the record to delete.
filepath (str): The path to the file containing the records.
(default::obj:`"records/records.json"`)
kwargs (Any): Additional keyword arguments.
"""
pass
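
To illustrate the contract this abstract class defines, here is a hedged sketch of a purely in-memory implementation; `InMemoryDatasetManager` is invented for this example and is not part of the package.

```python
from typing import Any, Dict, List

from camel.datahubs.base import BaseDatasetManager
from camel.datahubs.models import Record


class InMemoryDatasetManager(BaseDatasetManager):
    """Keeps datasets in a dict; useful for tests, not for persistence."""

    def __init__(self) -> None:
        self._store: Dict[str, List[Record]] = {}

    def create_dataset(self, name: str, **kwargs: Any) -> str:
        self._store.setdefault(name, [])
        return f"memory://{name}"

    def list_datasets(
        self, username: str, limit: int = 100, **kwargs: Any
    ) -> List[str]:
        return list(self._store)[:limit]

    def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
        self._store.pop(dataset_name, None)

    def add_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        if self._store.get(dataset_name):
            raise ValueError("Records already exist; use `update_records`.")
        self._store[dataset_name] = list(records)

    def update_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        # Merge by record id, letting the new records win.
        merged = {r.id: r for r in self._store.get(dataset_name, [])}
        merged.update({r.id: r for r in records})
        self._store[dataset_name] = list(merged.values())

    def list_records(
        self,
        dataset_name: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> List[Record]:
        return list(self._store.get(dataset_name, []))

    def delete_record(
        self,
        dataset_name: str,
        record_id: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        self._store[dataset_name] = [
            r for r in self._store.get(dataset_name, []) if r.id != record_id
        ]


manager = InMemoryDatasetManager()
manager.create_dataset("demo")
manager.add_records("demo", [Record(id="r1", content={"text": "hello"})])
print(manager.list_records("demo"))
```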

View File

@ -0,0 +1,433 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import os
import tempfile
from typing import Any, List, Optional
from camel.datahubs.base import BaseDatasetManager
from camel.datahubs.models import Record
from camel.logger import get_logger
from camel.types import HuggingFaceRepoType
from camel.utils import api_keys_required, dependencies_required
logger = get_logger(__name__)
class HuggingFaceDatasetManager(BaseDatasetManager):
r"""A dataset manager for Hugging Face datasets. This class provides
methods to create, add, update, delete, and list records in a dataset on
the Hugging Face Hub.
Args:
token (str): The Hugging Face API token. If not provided, the token
will be read from the environment variable `HUGGING_FACE_TOKEN`.
"""
@api_keys_required("HUGGING_FACE_TOKEN")
@dependencies_required('huggingface_hub')
def __init__(self, token: Optional[str] = None):
from huggingface_hub import HfApi
self._api_key = token or os.getenv("HUGGING_FACE_TOKEN")
self.api = HfApi(token=self._api_key)
def create_dataset_card(
self,
dataset_name: str,
description: str,
license: Optional[str] = None,
version: Optional[str] = None,
tags: Optional[List[str]] = None,
authors: Optional[List[str]] = None,
size_category: Optional[List[str]] = None,
language: Optional[List[str]] = None,
task_categories: Optional[List[str]] = None,
content: Optional[str] = None,
) -> None:
r"""Creates and uploads a dataset card to the Hugging Face Hub in YAML
format.
Args:
dataset_name (str): The name of the dataset.
description (str): A description of the dataset.
license (str): The license of the dataset. (default: :obj:`None`)
version (str): The version of the dataset. (default: :obj:`None`)
tags (list): A list of tags for the dataset. (default: :obj:`None`)
authors (list): A list of authors of the dataset. (default:
:obj:`None`)
size_category (list): A size category for the dataset. (default:
:obj:`None`)
language (list): A list of languages the dataset is in. (default:
:obj:`None`)
task_categories (list): A list of task categories. (default:
:obj:`None`)
content (str): Custom markdown content that the user wants to add
to the dataset card. (default: :obj:`None`)
"""
import yaml
metadata = {
"license": license,
"authors": authors,
"task_categories": task_categories,
"language": language,
"tags": tags,
"pretty_name": dataset_name,
"size_categories": size_category,
"version": version,
"description": description,
}
# Remove keys with None values
metadata = {k: v for k, v in metadata.items() if v}
card_content = (
"---\n"
+ yaml.dump(metadata, default_flow_style=False, allow_unicode=True)
+ "\n---"
)
if content:
card_content += f"\n\n# Additional Information\n{content}\n"
self._upload_file(
file_content=card_content,
dataset_name=dataset_name,
filepath="README.md",
file_type="md",
)
def create_dataset(
self, name: str, private: bool = False, **kwargs: Any
) -> str:
r"""Creates a new dataset on the Hugging Face Hub.
Args:
name (str): The name of the dataset.
private (bool): Whether the dataset should be private.
(default: :obj:`False`)
kwargs (Any): Additional keyword arguments.
Returns:
str: The URL of the created dataset.
"""
from huggingface_hub.errors import RepositoryNotFoundError
try:
self.api.repo_info(
repo_id=name,
repo_type=HuggingFaceRepoType.DATASET.value,
**kwargs,
)
except RepositoryNotFoundError:
self.api.create_repo(
repo_id=name,
repo_type=HuggingFaceRepoType.DATASET.value,
private=private,
)
return f"https://huggingface.co/datasets/{name}"
def list_datasets(
self, username: str, limit: int = 100, **kwargs: Any
) -> List[str]:
r"""Lists all datasets for the current user.
Args:
username (str): The username of the user whose datasets to list.
limit (int): The maximum number of datasets to list.
(default: :obj:`100`)
kwargs (Any): Additional keyword arguments.
Returns:
List[str]: A list of dataset ids.
"""
try:
return [
dataset.id
for dataset in self.api.list_datasets(
author=username, limit=limit, **kwargs
)
]
except Exception as e:
logger.error(f"Error listing datasets: {e}")
return []
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
r"""Deletes a dataset from the Hugging Face Hub.
Args:
dataset_name (str): The name of the dataset to delete.
kwargs (Any): Additional keyword arguments.
"""
try:
self.api.delete_repo(
repo_id=dataset_name,
repo_type=HuggingFaceRepoType.DATASET.value,
**kwargs,
)
logger.info(f"Dataset '{dataset_name}' deleted successfully.")
except Exception as e:
logger.error(f"Error deleting dataset '{dataset_name}': {e}")
raise
def add_records(
self,
dataset_name: str,
records: List[Record],
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Adds records to a dataset on the Hugging Face Hub.
Args:
dataset_name (str): The name of the dataset.
records (List[Record]): A list of records to add to the dataset.
filepath (str): The path to the file containing the records.
kwargs (Any): Additional keyword arguments.
Raises:
ValueError: If the dataset already has a records file.
"""
existing_records = self._download_records(
dataset_name=dataset_name, filepath=filepath, **kwargs
)
if existing_records:
raise ValueError(
f"Dataset '{filepath}' already exists. "
f"Use `update_records` to modify."
)
self._upload_records(
records=records,
dataset_name=dataset_name,
filepath=filepath,
**kwargs,
)
def update_records(
self,
dataset_name: str,
records: List[Record],
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Updates records in a dataset on the Hugging Face Hub.
Args:
dataset_name (str): The name of the dataset.
records (List[Record]): A list of records to update in the dataset.
filepath (str): The path to the file containing the records.
kwargs (Any): Additional keyword arguments.
Raises:
ValueError: If the dataset does not have an existing file to update
records in.
"""
existing_records = self._download_records(
dataset_name=dataset_name, filepath=filepath, **kwargs
)
if not existing_records:
logger.warning(
f"Dataset '{dataset_name}' does not have existing "
"records. Adding new records."
)
self._upload_records(
records=records,
dataset_name=dataset_name,
filepath=filepath,
**kwargs,
)
return
old_dict = {record.id: record for record in existing_records}
new_dict = {record.id: record for record in records}
merged_dict = old_dict.copy()
merged_dict.update(new_dict)
self._upload_records(
records=list(merged_dict.values()),
dataset_name=dataset_name,
filepath=filepath,
**kwargs,
)
def delete_record(
self,
dataset_name: str,
record_id: str,
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Deletes a record from the dataset.
Args:
dataset_name (str): The name of the dataset.
record_id (str): The ID of the record to delete.
filepath (str): The path to the file containing the records.
kwargs (Any): Additional keyword arguments.
Raises:
ValueError: If the dataset does not have an existing file to delete
records from.
"""
existing_records = self._download_records(
dataset_name=dataset_name, filepath=filepath, **kwargs
)
if not existing_records:
raise ValueError(
f"Dataset '{dataset_name}' does not have an existing file to "
f"delete records from."
)
filtered_records = [
record for record in existing_records if record.id != record_id
]
self._upload_records(
records=filtered_records,
dataset_name=dataset_name,
filepath=filepath,
**kwargs,
)
def list_records(
self,
dataset_name: str,
filepath: str = "records/records.json",
**kwargs: Any,
) -> List[Record]:
r"""Lists all records in a dataset.
Args:
dataset_name (str): The name of the dataset.
filepath (str): The path to the file containing the records.
kwargs (Any): Additional keyword arguments.
Returns:
List[Record]: A list of records in the dataset.
"""
return self._download_records(
dataset_name=dataset_name, filepath=filepath, **kwargs
)
def _download_records(
self, dataset_name: str, filepath: str, **kwargs: Any
) -> List[Record]:
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
try:
downloaded_file_path = hf_hub_download(
repo_id=dataset_name,
filename=filepath,
repo_type=HuggingFaceRepoType.DATASET.value,
token=self._api_key,
**kwargs,
)
with open(downloaded_file_path, "r") as f:
records_data = json.load(f)
return [Record(**record) for record in records_data]
except EntryNotFoundError:
logger.info(f"No records found for dataset '{dataset_name}'.")
return []
except Exception as e:
logger.error(f"Error downloading or processing records: {e}")
raise e
def _upload_records(
self,
records: List[Record],
dataset_name: str,
filepath: str,
**kwargs: Any,
):
with tempfile.NamedTemporaryFile(
delete=False, mode="w", newline="", encoding="utf-8"
) as f:
json.dump([record.model_dump() for record in records], f)
temp_file_path = f.name
try:
self.api.upload_file(
path_or_fileobj=temp_file_path,
path_in_repo=filepath,
repo_id=dataset_name,
repo_type=HuggingFaceRepoType.DATASET.value,
**kwargs,
)
except Exception as e:
logger.error(f"Error uploading records file: {e}")
raise
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
def _upload_file(
self,
file_content: str,
dataset_name: str,
filepath: str,
file_type: str = "json",
**kwargs: Any,
):
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=f".{file_type}"
) as f:
if file_type == "json":
if isinstance(file_content, str):
try:
json_content = json.loads(file_content)
except json.JSONDecodeError:
raise ValueError(
"Invalid JSON string provided for file_content."
)
else:
try:
json.dumps(file_content)
json_content = file_content
except (TypeError, ValueError):
raise ValueError(
"file_content is not JSON serializable."
)
json.dump(json_content, f)
elif file_type == "md" or file_type == "txt":
f.write(file_content)
else:
raise ValueError(f"Unsupported file type: {file_type}")
temp_file_path = f.name
try:
self.api.upload_file(
path_or_fileobj=temp_file_path,
path_in_repo=filepath,
repo_id=dataset_name,
repo_type=HuggingFaceRepoType.DATASET.value,
**kwargs,
)
logger.info(f"File uploaded successfully: {filepath}")
except Exception as e:
logger.error(f"Error uploading file: {e}")
raise
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
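
A hedged end-to-end sketch of how the manager above might be used; the repository id is a placeholder and only methods defined in this file are called.

```python
import os

from camel.datahubs.huggingface import HuggingFaceDatasetManager
from camel.datahubs.models import Record

# Placeholder repository id; replace with a namespace you control.
repo_id = "your-username/demo-records"

manager = HuggingFaceDatasetManager(token=os.getenv("HUGGING_FACE_TOKEN"))
print(manager.create_dataset(repo_id, private=True))

manager.add_records(
    repo_id,
    records=[
        Record(id="r1", content={"question": "What is GAIA?", "answer": "A benchmark."}),
        Record(id="r2", content={"question": "What is CAMEL?", "answer": "A framework."}),
    ],
)

print([record.id for record in manager.list_records(repo_id)])
manager.delete_record(repo_id, record_id="r2")
```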

View File

@ -0,0 +1,22 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, Optional
from pydantic import BaseModel
class Record(BaseModel):
id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
content: Dict[str, Any]

View File

@ -0,0 +1,28 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseEmbedding
from .mistral_embedding import MistralEmbedding
from .openai_compatible_embedding import OpenAICompatibleEmbedding
from .openai_embedding import OpenAIEmbedding
from .sentence_transformers_embeddings import SentenceTransformerEncoder
from .vlm_embedding import VisionLanguageEmbedding
__all__ = [
"BaseEmbedding",
"OpenAIEmbedding",
"SentenceTransformerEncoder",
"VisionLanguageEmbedding",
"MistralEmbedding",
"OpenAICompatibleEmbedding",
]

View File

@ -0,0 +1,67 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
T = TypeVar('T')
class BaseEmbedding(ABC, Generic[T]):
r"""Abstract base class for text embedding functionalities."""
@abstractmethod
def embed_list(
self,
objs: list[T],
**kwargs: Any,
) -> list[list[float]]:
r"""Generates embeddings for the given texts.
Args:
objs (list[T]): The objects for which to generate the embeddings.
**kwargs (Any): Extra kwargs passed to the embedding API.
Returns:
list[list[float]]: A list that represents the
generated embedding as a list of floating-point numbers.
"""
pass
def embed(
self,
obj: T,
**kwargs: Any,
) -> list[float]:
r"""Generates an embedding for the given text.
Args:
obj (T): The object for which to generate the embedding.
**kwargs (Any): Extra kwargs passed to the embedding API.
Returns:
list[float]: A list of floating-point numbers representing the
generated embedding.
"""
return self.embed_list([obj], **kwargs)[0]
@abstractmethod
def get_output_dim(self) -> int:
r"""Returns the output dimension of the embeddings.
Returns:
int: The dimensionality of the embedding for the current model.
"""
pass
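
A hedged sketch of a concrete subclass satisfying this interface, using a trivial character-count "embedding" invented purely for illustration.

```python
from typing import Any

from camel.embeddings.base import BaseEmbedding


class CharFrequencyEmbedding(BaseEmbedding[str]):
    """Toy embedding: counts of the 26 ASCII letters, for illustration only."""

    def embed_list(self, objs: list[str], **kwargs: Any) -> list[list[float]]:
        vectors = []
        for text in objs:
            counts = [0.0] * 26
            for ch in text.lower():
                if "a" <= ch <= "z":
                    counts[ord(ch) - ord("a")] += 1.0
            vectors.append(counts)
        return vectors

    def get_output_dim(self) -> int:
        return 26


encoder = CharFrequencyEmbedding()
print(encoder.embed("hello world"))  # single-object helper from the base class
print(encoder.get_output_dim())      # 26
```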

View File

@ -0,0 +1,89 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
import os
from typing import Any
from camel.embeddings.base import BaseEmbedding
from camel.types import EmbeddingModelType
from camel.utils import api_keys_required
class MistralEmbedding(BaseEmbedding[str]):
r"""Provides text embedding functionalities using Mistral's models.
Args:
model_type (EmbeddingModelType, optional): The model type to be
used for text embeddings.
(default: :obj:`MISTRAL_EMBED`)
api_key (str, optional): The API key for authenticating with the
Mistral service. (default: :obj:`None`)
dimensions (int, optional): The text embedding output dimensions.
(default: :obj:`None`)
Raises:
ValueError: If an unsupported model type is specified.
"""
def __init__(
self,
model_type: EmbeddingModelType = (EmbeddingModelType.MISTRAL_EMBED),
api_key: str | None = None,
dimensions: int | None = None,
) -> None:
from mistralai import Mistral
if not model_type.is_mistral:
raise ValueError("Invalid Mistral embedding model type.")
self.model_type = model_type
if dimensions is None:
self.output_dim = model_type.output_dim
else:
assert isinstance(dimensions, int)
self.output_dim = dimensions
self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
self._client = Mistral(api_key=self._api_key)
@api_keys_required("MISTRAL_API_KEY")
def embed_list(
self,
objs: list[str],
**kwargs: Any,
) -> list[list[float]]:
r"""Generates embeddings for the given texts.
Args:
objs (list[str]): The texts for which to generate the embeddings.
**kwargs (Any): Extra kwargs passed to the embedding API.
Returns:
list[list[float]]: A list that represents the generated embedding
as a list of floating-point numbers.
"""
# TODO: count tokens
response = self._client.embeddings.create(
inputs=objs,
model=self.model_type.value,
**kwargs,
)
return [data.embedding for data in response.data] # type: ignore[misc,union-attr]
def get_output_dim(self) -> int:
r"""Returns the output dimension of the embeddings.
Returns:
int: The dimensionality of the embedding for the current model.
"""
return self.output_dim
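
A brief usage sketch, assuming the `mistralai` package is installed and `MISTRAL_API_KEY` is set; the input sentences are placeholders.

```python
from camel.embeddings import MistralEmbedding

# Relies on MISTRAL_API_KEY being present in the environment.
embedding = MistralEmbedding()
vectors = embedding.embed_list(
    ["DeepSwarm targets the GAIA benchmark.", "Embeddings turn text into vectors."]
)
print(len(vectors), embedding.get_output_dim())
```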

View File

@ -0,0 +1,91 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
import os
from typing import Any, Optional
from openai import OpenAI
from camel.embeddings.base import BaseEmbedding
from camel.utils import api_keys_required
class OpenAICompatibleEmbedding(BaseEmbedding[str]):
r"""Provides text embedding functionalities supporting OpenAI
compatibility.
Args:
model_type (str): The model type to be used for text embeddings.
api_key (str): The API key for authenticating with the model service.
url (str): The url to the model service.
"""
def __init__(
self,
model_type: str,
api_key: Optional[str] = None,
url: Optional[str] = None,
) -> None:
self.model_type = model_type
self.output_dim: Optional[int] = None
self._api_key = api_key or os.environ.get(
"OPENAI_COMPATIBILIY_API_KEY"
)
self._url = url or os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL")
self._client = OpenAI(
timeout=60,
max_retries=3,
api_key=self._api_key,
base_url=self._url,
)
@api_keys_required("OPENAI_COMPATIBILIY_API_KEY")
def embed_list(
self,
objs: list[str],
**kwargs: Any,
) -> list[list[float]]:
r"""Generates embeddings for the given texts.
Args:
objs (list[str]): The texts for which to generate the embeddings.
**kwargs (Any): Extra kwargs passed to the embedding API.
Returns:
list[list[float]]: A list that represents the generated embedding
as a list of floating-point numbers.
"""
response = self._client.embeddings.create(
input=objs,
model=self.model_type,
**kwargs,
)
self.output_dim = len(response.data[0].embedding)
return [data.embedding for data in response.data]
def get_output_dim(self) -> int:
r"""Returns the output dimension of the embeddings.
Returns:
int: The dimensionality of the embedding for the current model.
"""
if self.output_dim is None:
raise ValueError(
"Output dimension is not yet determined. Call "
"'embed_list' first."
)
return self.output_dim
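
A hedged sketch pointing this class at a self-hosted, OpenAI-compatible endpoint; the URL and model name are placeholders, and depending on how `api_keys_required` is enforced, `OPENAI_COMPATIBILIY_API_KEY` may also need to be set in the environment.

```python
from camel.embeddings import OpenAICompatibleEmbedding

# Placeholder endpoint and model name for a self-hosted embedding server.
embedding = OpenAICompatibleEmbedding(
    model_type="bge-m3",
    api_key="not-needed-for-local-servers",
    url="http://localhost:8000/v1",
)
vectors = embedding.embed_list(["multi-agent systems", "tool calling"])
# The output dimension is only known after the first embed_list call.
print(embedding.get_output_dim())
```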

View File

@ -0,0 +1,99 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
import os
from typing import Any
from openai import OpenAI
from camel.embeddings.base import BaseEmbedding
from camel.types import NOT_GIVEN, EmbeddingModelType, NotGiven
from camel.utils import api_keys_required
class OpenAIEmbedding(BaseEmbedding[str]):
r"""Provides text embedding functionalities using OpenAI's models.
Args:
model_type (EmbeddingModelType, optional): The model type to be
used for text embeddings.
(default: :obj:`TEXT_EMBEDDING_3_SMALL`)
api_key (str, optional): The API key for authenticating with the
OpenAI service. (default: :obj:`None`)
dimensions (int, optional): The text embedding output dimensions.
(default: :obj:`NOT_GIVEN`)
Raises:
ValueError: If an unsupported model type is specified.
"""
def __init__(
self,
model_type: EmbeddingModelType = (
EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
),
api_key: str | None = None,
dimensions: int | NotGiven = NOT_GIVEN,
) -> None:
if not model_type.is_openai:
raise ValueError("Invalid OpenAI embedding model type.")
self.model_type = model_type
if dimensions == NOT_GIVEN:
self.output_dim = model_type.output_dim
else:
assert isinstance(dimensions, int)
self.output_dim = dimensions
self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
self.client = OpenAI(timeout=60, max_retries=3, api_key=self._api_key)
@api_keys_required("OPENAI_API_KEY")
def embed_list(
self,
objs: list[str],
**kwargs: Any,
) -> list[list[float]]:
r"""Generates embeddings for the given texts.
Args:
objs (list[str]): The texts for which to generate the embeddings.
**kwargs (Any): Extra kwargs passed to the embedding API.
Returns:
list[list[float]]: A list that represents the generated embedding
as a list of floating-point numbers.
"""
# TODO: count tokens
if self.model_type == EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
response = self.client.embeddings.create(
input=objs,
model=self.model_type.value,
**kwargs,
)
else:
response = self.client.embeddings.create(
input=objs,
model=self.model_type.value,
dimensions=self.output_dim,
**kwargs,
)
return [data.embedding for data in response.data]
def get_output_dim(self) -> int:
r"""Returns the output dimension of the embeddings.
Returns:
int: The dimensionality of the embedding for the current model.
"""
return self.output_dim
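
A short hedged sketch; it assumes `OPENAI_API_KEY` is set and shows the `dimensions` override supported by the text-embedding-3 models.

```python
from camel.embeddings import OpenAIEmbedding
from camel.types import EmbeddingModelType

# Shrink text-embedding-3-small output to 256 dimensions.
embedding = OpenAIEmbedding(
    model_type=EmbeddingModelType.TEXT_EMBEDDING_3_SMALL,
    dimensions=256,
)
vector = embedding.embed("role-playing agents")
print(len(vector), embedding.get_output_dim())  # both should be 256
```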

View File

@ -0,0 +1,80 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any
from numpy import ndarray
from camel.embeddings.base import BaseEmbedding
class SentenceTransformerEncoder(BaseEmbedding[str]):
r"""This class provides functionalities to generate text
embeddings using `Sentence Transformers`.
References:
https://www.sbert.net/
"""
def __init__(
self,
model_name: str = "intfloat/e5-large-v2",
**kwargs,
):
r"""Initializes the: obj: `SentenceTransformerEmbedding` class
with the specified transformer model.
Args:
model_name (str, optional): The name of the model to use.
(default: :obj:`intfloat/e5-large-v2`)
**kwargs (optional): Additional arguments of
:class:`SentenceTransformer`, such as :obj:`prompts` etc.
"""
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model_name, **kwargs)
def embed_list(
self,
objs: list[str],
**kwargs: Any,
) -> list[list[float]]:
r"""Generates embeddings for the given texts using the model.
Args:
objs (list[str]): The texts for which to generate the
embeddings.
Returns:
list[list[float]]: A list that represents the generated embedding
as a list of floating-point numbers.
"""
if not objs:
raise ValueError("Input text list is empty")
embeddings = self.model.encode(
objs, normalize_embeddings=True, **kwargs
)
assert isinstance(embeddings, ndarray)
return embeddings.tolist()
def get_output_dim(self) -> int:
r"""Returns the output dimension of the embeddings.
Returns:
int: The dimensionality of the embeddings.
"""
output_dim = self.model.get_sentence_embedding_dimension()
assert isinstance(output_dim, int)
return output_dim
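
A local, API-free usage sketch; it assumes `sentence-transformers` is installed and will download the default model on first use.

```python
from camel.embeddings import SentenceTransformerEncoder

encoder = SentenceTransformerEncoder(model_name="intfloat/e5-large-v2")
vectors = encoder.embed_list(
    ["query: multi-agent planning", "passage: agents coordinate via messages"]
)
print(len(vectors[0]) == encoder.get_output_dim())  # True
```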

View File

@ -0,0 +1,149 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional, Union
from PIL import Image
from camel.embeddings import BaseEmbedding
from camel.logger import get_logger
logger = get_logger(__name__)
class VisionLanguageEmbedding(BaseEmbedding[Union[str, Image.Image]]):
r"""Provides image embedding functionalities using multimodal model.
Args:
model_name : The model type to be used for generating embeddings.
And the default value is: obj:`openai/clip-vit-base-patch32`.
Raises:
RuntimeError: If the specified model or processor cannot be loaded.
"""
def __init__(
self, model_name: str = "openai/clip-vit-base-patch32"
) -> None:
r"""Initializes the: obj: `VisionLanguageEmbedding` class with a
specified model and return the dimension of embeddings.
Args:
model_name (str, optional): The version name of the model to use.
(default: :obj:`openai/clip-vit-base-patch32`)
"""
from transformers import AutoModel, AutoProcessor
try:
self.model = AutoModel.from_pretrained(model_name)
self.processor = AutoProcessor.from_pretrained(model_name)
except Exception as e:
raise RuntimeError(f"Failed to load model '{model_name}': {e}")
self.valid_processor_kwargs = []
self.valid_model_kwargs = []
try:
self.valid_processor_kwargs = (
self.processor.image_processor._valid_processor_keys
)
self.valid_model_kwargs = [
"pixel_values",
"return_dict",
"interpolate_pos_encoding",
]
except Exception:
logger.warning("not typically processor and model structure")
pass
self.dim: Optional[int] = None
def embed_list(
self, objs: List[Union[Image.Image, str]], **kwargs: Any
) -> List[List[float]]:
"""Generates embeddings for the given images or texts.
Args:
objs (List[Image.Image|str]): The list of images or texts for
which to generate the embeddings.
image_processor_kwargs: Extra kwargs passed to the image processor.
tokenizer_kwargs: Extra kwargs passed to the text tokenizer
(processor).
model_kwargs: Extra kwargs passed to the main model.
Returns:
List[List[float]]: A list that represents the generated embedding
as a list of floating-point numbers.
Raises:
ValueError: If the input type is not `Image.Image` or `str`.
"""
if not objs:
raise ValueError("Input objs list is empty.")
image_processor_kwargs: Optional[dict] = kwargs.get(
'image_processor_kwargs', {}
)
tokenizer_kwargs: Optional[dict] = kwargs.get('tokenizer_kwargs', {})
model_kwargs: Optional[dict] = kwargs.get('model_kwargs', {})
result_list = []
for obj in objs:
if isinstance(obj, Image.Image):
image_input = self.processor(
images=obj,
return_tensors="pt",
padding=True,
**image_processor_kwargs,
)
image_feature = (
self.model.get_image_features(
**image_input, **model_kwargs
)
.squeeze(dim=0)
.tolist()
)
result_list.append(image_feature)
elif isinstance(obj, str):
text_input = self.processor(
text=obj,
return_tensors="pt",
padding=True,
**tokenizer_kwargs,
)
text_feature = (
self.model.get_text_features(**text_input, **model_kwargs)
.squeeze(dim=0)
.tolist()
)
result_list.append(text_feature)
else:
raise ValueError("Input type is not image nor text.")
self.dim = len(result_list[0])
if any(len(result) != self.dim for result in result_list):
raise ValueError("Dimensionality is not consistent.")
return result_list
def get_output_dim(self) -> int:
r"""Returns the output dimension of the embeddings.
Returns:
int: The dimensionality of the embedding for the current model.
"""
if self.dim is None:
text = 'dimension'
inputs = self.processor(text=[text], return_tensors="pt")
self.dim = self.model.get_text_features(**inputs).shape[1]
return self.dim
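
A hedged sketch that embeds one image and one caption into the shared CLIP space; the image path is a placeholder and `transformers` (with a torch backend) must be installed.

```python
from PIL import Image

from camel.embeddings import VisionLanguageEmbedding

embedding = VisionLanguageEmbedding()  # defaults to openai/clip-vit-base-patch32

image = Image.open("example.jpg")  # placeholder path
image_vec, text_vec = embedding.embed_list([image, "a photo of a robot"])
print(len(image_vec), len(text_vec), embedding.get_output_dim())
```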

View File

@ -0,0 +1,375 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Dict, Generator, List, Optional, Set, Tuple
from camel.messages import BaseMessage
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType
class SystemMessageGenerator:
r"""System message generator for agents.
Args:
task_type (TaskType, optional): The task type.
(default: :obj:`TaskType.AI_SOCIETY`)
sys_prompts (Optional[Dict[RoleType, str]], optional): The prompts of
the system messages for each role type. (default: :obj:`None`)
sys_msg_meta_dict_keys (Optional[Set[str]], optional): The set of keys
of the meta dictionary used to fill the prompts.
(default: :obj:`None`)
"""
def __init__(
self,
task_type: TaskType = TaskType.AI_SOCIETY,
sys_prompts: Optional[Dict[RoleType, str]] = None,
sys_msg_meta_dict_keys: Optional[Set[str]] = None,
) -> None:
self.sys_prompts: Dict[RoleType, str]
if sys_prompts is not None:
self.sys_prompts = sys_prompts
self.sys_msg_meta_dict_keys = sys_msg_meta_dict_keys or set()
else:
assistant_prompt_template = (
PromptTemplateGenerator().get_system_prompt(
task_type,
RoleType.ASSISTANT,
)
)
user_prompt_template = PromptTemplateGenerator().get_system_prompt(
task_type,
RoleType.USER,
)
critic_prompt_template = (
PromptTemplateGenerator().get_system_prompt(
task_type,
RoleType.CRITIC,
)
)
embodiment_prompt_template = (
PromptTemplateGenerator().get_system_prompt(
task_type,
RoleType.EMBODIMENT,
)
)
self.sys_prompts = dict()
self.sys_prompts[RoleType.ASSISTANT] = assistant_prompt_template
self.sys_prompts[RoleType.USER] = user_prompt_template
self.sys_prompts[RoleType.CRITIC] = critic_prompt_template
self.sys_prompts[RoleType.EMBODIMENT] = embodiment_prompt_template
self.sys_msg_meta_dict_keys = (
assistant_prompt_template.key_words
| user_prompt_template.key_words
| critic_prompt_template.key_words
| embodiment_prompt_template.key_words
)
if RoleType.DEFAULT not in self.sys_prompts:
self.sys_prompts[RoleType.DEFAULT] = "You are a helpful assistant."
def validate_meta_dict_keys(self, meta_dict: Dict[str, str]) -> None:
r"""Validates the keys of the meta_dict.
Args:
meta_dict (Dict[str, str]): The dictionary to validate.
"""
if not set(meta_dict.keys()).issubset(self.sys_msg_meta_dict_keys):
raise ValueError(
"The keys of the meta_dict should be in "
f"{self.sys_msg_meta_dict_keys}. "
f"Got {set(meta_dict.keys())} instead."
)
def from_dict(
self,
meta_dict: Dict[str, str],
role_tuple: Tuple[str, RoleType] = ("", RoleType.DEFAULT),
) -> BaseMessage:
r"""Generates a system message from a dictionary.
Args:
meta_dict (Dict[str, str]): The dictionary containing the
information to generate the system message.
role_tuple (Tuple[str, RoleType], optional): The tuple containing
the role name and role type. (default: ("", RoleType.DEFAULT))
Returns:
BaseMessage: The generated system message.
"""
self.validate_meta_dict_keys(meta_dict)
role_name, role_type = role_tuple
sys_prompt = self.sys_prompts[role_type]
sys_prompt = sys_prompt.format(**meta_dict)
return BaseMessage(
role_name=role_name,
role_type=role_type,
meta_dict=meta_dict,
content=sys_prompt,
)
def from_dicts(
self,
meta_dicts: List[Dict[str, str]],
role_tuples: List[Tuple[str, RoleType]],
) -> List[BaseMessage]:
r"""Generates a list of system messages from a list of dictionaries.
Args:
meta_dicts (List[Dict[str, str]]): A list of dictionaries
containing the information to generate the system messages.
role_tuples (List[Tuple[str, RoleType]]): A list of tuples
containing the role name and role type for each system message.
Returns:
List[BaseMessage]: A list of generated system messages.
Raises:
ValueError: If the number of meta_dicts and role_tuples are
different.
"""
if len(meta_dicts) != len(role_tuples):
raise ValueError(
"The number of meta_dicts and role_types should be the same."
)
return [
self.from_dict(meta_dict, role_tuple)
for meta_dict, role_tuple in zip(meta_dicts, role_tuples)
]
class RoleNameGenerator:
r"""Role name generator for role-playing workers.
Args:
assistant_role_names_path (str, optional): The path to the file
containing the assistant role names.
(default: :obj:`"data/ai_society/assistant_roles.txt"`)
user_role_names_path (str, optional): The path to the file
containing the user role names.
(default: :obj:`"data/ai_society/user_roles.txt"`)
assistant_role_names (Optional[List[str]], optional): The list of
assistant role names. (default: :obj:`None`)
user_role_names (Optional[List[str]], optional): The list of user role
names. (default: :obj:`None`)
"""
def __init__(
self,
assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
user_role_names_path: str = "data/ai_society/user_roles.txt",
assistant_role_names: Optional[List[str]] = None,
user_role_names: Optional[List[str]] = None,
) -> None:
if assistant_role_names is None:
with open(assistant_role_names_path, "r") as f:
assistant_role_names_: List[str] = f.read().splitlines()
self.assistant_role_names = [
" ".join(name.split(" ")[1:])
for name in assistant_role_names_
]
else:
self.assistant_role_names = assistant_role_names
if user_role_names is None:
with open(user_role_names_path, "r") as f:
user_role_names_: List[str] = f.read().splitlines()
self.user_role_names = [
" ".join(name.split(" ")[1:]) for name in user_role_names_
]
else:
self.user_role_names = user_role_names
def from_role_files(self) -> Generator[Tuple, None, None]:
r"""Generate role names from the file.
Returns:
Generator[Tuple, None, None]: A generator that yields tuples of
assistant role names and user role names.
"""
for assistant_role_name in self.assistant_role_names:
for user_role_name in self.user_role_names:
yield (assistant_role_name, user_role_name)
class AISocietyTaskPromptGenerator:
r"""Task prompt generator for AI society tasks.
Args:
num_tasks (int, optional): The number of tasks to generate.
(default: :obj:`10`)
"""
def __init__(
self,
num_tasks: int = 10,
) -> None:
self.generate_tasks_prompt = (
PromptTemplateGenerator().get_generate_tasks_prompt(
TaskType.AI_SOCIETY
)
)
self.num_tasks = num_tasks
# TODO: Return role names for user and assistant with the generator.
def from_role_files(
self,
assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
user_role_names_path: str = "data/ai_society/user_roles.txt",
) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
r"""Generate tasks from role files.
Args:
assistant_role_names_path (str, optional): The path to the file
containing the assistant role names.
(default: :obj:`"data/ai_society/assistant_roles.txt"`)
user_role_names_path (str, optional): The path to the file
containing the user role names.
(default: :obj:`"data/ai_society/user_roles.txt"`)
Returns:
Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
that yields tuples of task prompts and role names.
"""
roles_generator = RoleNameGenerator(
assistant_role_names_path, user_role_names_path
).from_role_files()
for role_1, role_2 in roles_generator:
generate_tasks_prompt = self.generate_tasks_prompt.format(
assistant_role=role_1,
user_role=role_2,
num_tasks=self.num_tasks,
)
yield (generate_tasks_prompt, (role_1, role_2))
def from_role_generator(
self, role_generator: Generator[Tuple, None, None]
) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
r"""Generate tasks from a role generator.
Args:
role_generator (Generator[Tuple, None, None]): A generator that
yields tuples of role names.
Returns:
Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
that yields tuples of task prompts and role names.
"""
for role_1, role_2 in role_generator:
generate_tasks_prompt = self.generate_tasks_prompt.format(
assistant_role=role_1,
user_role=role_2,
num_tasks=self.num_tasks,
)
yield (generate_tasks_prompt, (role_1, role_2))
class SingleTxtGenerator:
r"""Single text generator for role-playing workers.
Args:
text_file_path (str): The path to the file containing the text data.
"""
def __init__(
self,
text_file_path: str,
) -> None:
with open(text_file_path, "r") as f:
data_list: List[str] = f.read().splitlines()
self.data_list = [
" ".join(name.split(" ")[1:]) for name in data_list
]
def from_role_files(self) -> Generator[str, None, None]:
r"""Generate text from the file.
Returns:
Generator[str, None, None]: A generator that yields the text data.
"""
for data in self.data_list:
yield data
class CodeTaskPromptGenerator:
r"""Code task prompt generator for code tasks.
Args:
num_tasks (int, optional): The number of tasks to generate.
(default: :obj:`50`)
"""
def __init__(
self,
num_tasks: int = 50,
) -> None:
self.generate_tasks_prompt = (
PromptTemplateGenerator().get_generate_tasks_prompt(TaskType.CODE)
)
self.num_tasks = num_tasks
def from_role_files(
self,
languages_path: str = "data/code/languages.txt",
domains_path: str = "data/code/domains.txt",
) -> Generator[Tuple[TextPrompt, str, str], None, None]:
r"""Generate tasks from role files.
Args:
languages_path (str, optional): The path to the file containing
the language names. (default: :obj:`"data/code/languages.txt"`)
domains_path (str, optional): The path to the file containing
the domain names. (default: :obj:`"data/code/domains.txt"`)
Returns:
Generator[Tuple[TextPrompt, str, str], None, None]: A generator
that yields tuples of task prompts, language names, and domain
names.
"""
language_generator = SingleTxtGenerator(
languages_path
).from_role_files()
for language in language_generator:
domains_generator = SingleTxtGenerator(
domains_path
).from_role_files()
for domain in domains_generator:
generated_tasks_prompt = self.generate_tasks_prompt.format(
language=language, domain=domain, num_tasks=self.num_tasks
)
yield generated_tasks_prompt, language, domain
def from_role_generator(
self, role_generator: Generator[Tuple, None, None]
) -> Generator[str, None, None]:
r"""Generate tasks from a role generator.
Args:
role_generator (Generator[Tuple, None, None]): A generator that
yields tuples of role names.
Returns:
Generator[str, None, None]: A generator that yields the task
prompts.
"""
raise NotImplementedError
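
A hedged sketch of the simplest path through `SystemMessageGenerator`: supplying explicit prompts so no meta-dictionary keys are required. The import path and role name are assumptions for illustration.

```python
from camel.generators import SystemMessageGenerator
from camel.types import RoleType

# With explicit prompts, the meta dict can stay empty.
generator = SystemMessageGenerator(
    sys_prompts={RoleType.ASSISTANT: "You are a meticulous research assistant."},
)
message = generator.from_dict({}, role_tuple=("Researcher", RoleType.ASSISTANT))
print(message.content)
```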

Some files were not shown because too many files have changed in this diff.