refactor: Update with camel version 0.2.23

This commit is contained in:
Wendong
2025-03-10 04:00:32 +08:00
parent 7887585e8e
commit 3738fbb0e5
283 changed files with 7201 additions and 42380 deletions

29
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,29 @@
# Pre-commit hooks: ruff lint/format, mypy type-check, and license-header check.
# All hooks skip docs/cookbooks.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: 'v0.7.4'
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix, --show-fixes]
        exclude: ^docs/cookbooks/  # Ignore files under docs/cookbooks
      - id: ruff-format
        exclude: ^docs/cookbooks/  # Ignore files under docs/cookbooks
  - repo: local
    hooks:
      - id: mypy
        name: Check mypy
        entry: mypy --namespace-packages -p owl
        language: python
        types: [python]
        pass_filenames: false
        require_serial: true
        exclude: ^docs/cookbooks/  # Ignore files under docs/cookbooks
  - repo: local
    hooks:
      - id: check-license
        name: Check License
        entry: python licenses/update_license.py . licenses/license_template.txt
        language: system
        types: [python]
        exclude: ^docs/cookbooks/  # Ignore files under docs/cookbooks

View File

@@ -102,36 +102,31 @@ https://private-user-images.githubusercontent.com/55657767/420212194-e813fc05-13
# 🛠️ Installation
## **Clone the Github repository**
```bash
# Clone github repo
git clone https://github.com/camel-ai/owl.git
# Change directory into project directory
cd owl
```
## **Set up Environment**
# Install uv if you don't have it already
pip install uv
Using Conda (recommended):
```bash
conda create -n owl python=3.11
conda activate owl
```
# Create a virtual environment and install dependencies
# We support using Python 3.10, 3.11, 3.12
uv venv .venv --python=3.10
Using venv (alternative):
```bash
python -m venv owl_env
# On Windows
owl_env\Scripts\activate
# On Unix or MacOS
source owl_env/bin/activate
```
# Activate the virtual environment
# For macOS/Linux
source .venv/bin/activate
# For Windows
.venv\Scripts\activate
# Install CAMEL with all dependencies
uv pip install -e .
## **Install Dependencies**
```bash
python -m pip install -r requirements.txt
playwright install
# Exit the virtual environment when done
deactivate
```
## **Setup Environment Variables**
@@ -210,7 +205,7 @@ question = "Task description here."
society = construct_society(question)
answer, chat_history, token_count = run_society(society)
print(f"Answer: {answer}")
print(f"\033[94mAnswer: {answer}\033[0m")
```
For uploading files, simply provide the file path along with your question:
@@ -221,7 +216,7 @@ question = "What is in the given DOCX file? Here is the file path: tmp/example.d
society = construct_society(question)
answer, chat_history, token_count = run_society(society)
print(f"Answer: {answer}")
print(f"\033[94mAnswer: {answer}\033[0m")
```
OWL will then automatically invoke document-related tools to process the file and extract the answer.

View File

@@ -104,31 +104,30 @@ https://private-user-images.githubusercontent.com/55657767/420212194-e813fc05-13
## **克隆 Github 仓库**
```bash
# 克隆 GitHub 仓库
git clone https://github.com/camel-ai/owl.git
# 进入项目目录
cd owl
```
## **设置环境**
# 如果你还没有安装 uv,请先安装
pip install uv
使用 Conda(推荐):
```bash
conda create -n owl python=3.11
conda activate owl
```
# 创建虚拟环境并安装依赖
# 我们支持使用 Python 3.10、3.11、3.12
uv venv .venv --python=3.10
使用 venv(备用):
```bash
python -m venv owl_env
# Windows 系统
owl_env\Scripts\activate
# Unix 或 MacOS 系统
source owl_env/bin/activate
```
# 激活虚拟环境
# 对于 macOS/Linux
source .venv/bin/activate
# 对于 Windows
.venv\Scripts\activate
## **安装依赖**
# 安装 CAMEL 及其所有依赖
uv pip install -e .
```bash
python -m pip install -r requirements.txt
# 完成后退出虚拟环境
deactivate
```
## **设置环境变量**
@@ -201,7 +200,7 @@ question = "Task description here."
society = construct_society(question)
answer, chat_history, token_count = run_society(society)
print(f"Answer: {answer}")
print(f"\033[94mAnswer: {answer}\033[0m")
```
上传文件时,只需提供文件路径和问题:

View File

@@ -39,43 +39,37 @@ def update_license_in_file(
start_line_start_with: str,
end_line_start_with: str,
) -> bool:
with open(
file_path, 'r', encoding='utf-8'
) as f: # for windows compatibility
with open(file_path, "r", encoding="utf-8") as f: # for windows compatibility
content = f.read()
with open(license_template_path, 'r', encoding='utf-8') as f:
with open(license_template_path, "r", encoding="utf-8") as f:
new_license = f.read().strip()
maybe_existing_licenses = re.findall(
r'^#.*?(?=\n)', content, re.MULTILINE | re.DOTALL
r"^#.*?(?=\n)", content, re.MULTILINE | re.DOTALL
)
start_index = fine_license_start_line(
maybe_existing_licenses, start_line_start_with
)
end_index = find_license_end_line(
maybe_existing_licenses, end_line_start_with
)
end_index = find_license_end_line(maybe_existing_licenses, end_line_start_with)
if start_index is not None and end_index is not None:
maybe_existing_licenses = maybe_existing_licenses[
start_index : end_index + 1
]
maybe_existing_licenses = maybe_existing_licenses[start_index : end_index + 1]
else:
maybe_existing_licenses = None
if maybe_existing_licenses:
maybe_old_licenses = '\n'.join(maybe_existing_licenses)
maybe_old_licenses = "\n".join(maybe_existing_licenses)
if maybe_old_licenses.strip() != new_license.strip():
replaced_content = content.replace(maybe_old_licenses, new_license)
with open(file_path, 'w') as f:
with open(file_path, "w") as f:
f.write(replaced_content)
print(f'Replaced license in {file_path}')
print(f"Replaced license in {file_path}")
return True
else:
return False
else:
with open(file_path, 'w') as f:
f.write(new_license + '\n' + content)
print(f'Added license to {file_path}')
with open(file_path, "w") as f:
f.write(new_license + "\n" + content)
print(f"Added license to {file_path}")
return True
@@ -87,16 +81,16 @@ def update_license_in_directory(
) -> None:
# Check if directory exists
if not os.path.isdir(directory_path):
raise NotADirectoryError(f'{directory_path} is not a directory')
raise NotADirectoryError(f"{directory_path} is not a directory")
# Check if license template exists
if not os.path.isfile(license_template_path):
raise FileNotFoundError(f'{license_template_path} not found')
raise FileNotFoundError(f"{license_template_path} not found")
file_count = 0
for py_files in Path(directory_path).rglob("*.py"):
if py_files.name.startswith('.'):
if py_files.name.startswith("."):
continue
if any(part.startswith('.') for part in py_files.parts):
if any(part.startswith(".") for part in py_files.parts):
continue
if update_license_in_file(
py_files,
@@ -106,10 +100,10 @@ def update_license_in_directory(
):
file_count += 1
print(f'License updated in {file_count} files')
print(f"License updated in {file_count} files")
if __name__ == '__main__':
if __name__ == "__main__":
if len(sys.argv) < 3:
print(
"Usage from command line: "

View File

@@ -1,25 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.logger import disable_logging, enable_logging, set_log_level

# Package version string; bumped on each CAMEL release.
__version__ = '0.2.11'

# Public API of the top-level ``camel`` package.
__all__ = [
    '__version__',
    'camel',
    'disable_logging',
    'enable_logging',
    'set_log_level',
]

View File

@@ -1,44 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseAgent
from .chat_agent import ChatAgent
from .critic_agent import CriticAgent
from .embodied_agent import EmbodiedAgent
from .knowledge_graph_agent import KnowledgeGraphAgent
from .role_assignment_agent import RoleAssignmentAgent
from .search_agent import SearchAgent
from .task_agent import (
    TaskCreationAgent,
    TaskPlannerAgent,
    TaskPrioritizationAgent,
    TaskSpecifyAgent,
)
from .tool_agents.base import BaseToolAgent
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent

# Public re-exports of the ``camel.agents`` package.
__all__ = [
    'BaseAgent',
    'ChatAgent',
    'TaskSpecifyAgent',
    'TaskPlannerAgent',
    'TaskCreationAgent',
    'TaskPrioritizationAgent',
    'CriticAgent',
    'BaseToolAgent',
    'HuggingFaceToolAgent',
    'EmbodiedAgent',
    'RoleAssignmentAgent',
    'SearchAgent',
    'KnowledgeGraphAgent',
]

View File

@@ -1,29 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
    r"""Abstract interface that every CAMEL agent implements.

    Concrete agents must provide both :meth:`reset` and :meth:`step`;
    attempting to instantiate this class directly raises ``TypeError``.
    """

    @abstractmethod
    def reset(self, *args: Any, **kwargs: Any) -> Any:
        r"""Return the agent to its initial state.

        Returns:
            Any: Implementation-defined reset result.
        """

    @abstractmethod
    def step(self, *args: Any, **kwargs: Any) -> Any:
        r"""Advance the agent by a single step.

        Returns:
            Any: Implementation-defined step result.
        """

File diff suppressed because it is too large Load Diff

View File

@@ -1,202 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
import warnings
from typing import Any, Dict, Optional, Sequence
from colorama import Fore
from camel.agents.chat_agent import ChatAgent
from camel.memories import AgentMemory
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import get_first_int, print_text_animated
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="CriticAgent")
class CriticAgent(ChatAgent):
    r"""A class for the critic agent that assists in selecting an option.

    Args:
        system_message (BaseMessage): The system message for the critic
            agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        memory (AgentMemory, optional): Agent memory holding the conversation
            history. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`6`)
        retry_attempts (int, optional): The number of retry attempts if the
            critic fails to return a valid option. (default: :obj:`2`)
        verbose (bool, optional): Whether to print the critic's messages.
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        memory: Optional[AgentMemory] = None,
        message_window_size: int = 6,
        retry_attempts: int = 2,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        super().__init__(
            system_message,
            model=model,
            memory=memory,
            message_window_size=message_window_size,
        )
        # Maps the 1-based menu index (as a string, e.g. "1") to option text;
        # populated by flatten_options() and read by get_option().
        self.options_dict: Dict[str, str] = dict()
        self.retry_attempts = retry_attempts
        self.verbose = verbose
        self.logger_color = logger_color

    def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
        r"""Flattens the options to the critic.

        Args:
            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects,
                one per proposal; all are assumed to come from the same role.

        Returns:
            str: A string containing the flattened options to the critic,
                ending with an instruction to pick a numbered choice.
        """
        options = [message.content for message in messages]
        flatten_options = (
            f"> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for index, option in enumerate(options):
            flatten_options += f"Option {index + 1}:\n{option}\n\n"
            # Record the option under its 1-based menu index for later lookup.
            self.options_dict[str(index + 1)] = option
        format = (
            f"Please first enter your choice ([1-{len(self.options_dict)}]) "
            "and then your explanation and comparison: "
        )
        return flatten_options + format

    def get_option(self, input_message: BaseMessage) -> str:
        r"""Gets the option selected by the critic.

        Retries up to ``self.retry_attempts`` times; on exhaustion a random
        option is returned (with a warning) rather than raising.

        Args:
            input_message (BaseMessage): A `BaseMessage` object representing
                the input message (the flattened option menu).

        Returns:
            str: The option selected by the critic.

        Raises:
            RuntimeError: If the critic step returns no messages or
                terminates.
        """
        # TODO: Add support for editing options by the critic.
        msg_content = input_message.content
        i = 0
        while i < self.retry_attempts:
            critic_response = self.step(input_message)

            if critic_response.msgs is None or len(critic_response.msgs) == 0:
                raise RuntimeError("Got None critic messages.")
            if critic_response.terminated:
                raise RuntimeError("Critic step failed.")

            critic_msg = critic_response.msg
            if self.verbose:
                print_text_animated(
                    self.logger_color + "\n> Critic response: "
                    f"\x1b[3m{critic_msg.content}\x1b[0m\n"
                )
            choice = self.parse_critic(critic_msg)

            if choice in self.options_dict:
                return self.options_dict[choice]
            else:
                # Re-prompt with the original menu, prefixed by an error note.
                input_message = BaseMessage(
                    role_name=input_message.role_name,
                    role_type=input_message.role_type,
                    meta_dict=input_message.meta_dict,
                    content="> Invalid choice. Please choose again.\n"
                    + msg_content,
                )
            i += 1
        # All retries exhausted: degrade gracefully to a random option.
        warnings.warn(
            "Critic failed to get a valid option. "
            f"After {self.retry_attempts} attempts. "
            "Returning a random option."
        )
        return random.choice(list(self.options_dict.values()))

    def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
        r"""Parses the critic's message and extracts the choice.

        Args:
            critic_msg (BaseMessage): A `BaseMessage` object representing the
                critic's response.

        Returns:
            Optional[str]: The critic's choice as a string, or None if the
                message could not be parsed.
        """
        # The first integer found in the message is taken as the choice.
        choice = str(get_first_int(critic_msg.content))
        return choice

    def reduce_step(
        self,
        input_messages: Sequence[BaseMessage],
    ) -> ChatAgentResponse:
        r"""Performs one step of the conversation by flattening options to the
        critic, getting the option, and parsing the choice.

        Args:
            input_messages (Sequence[BaseMessage]): A list of BaseMessage
                objects, one per candidate option.

        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object includes the
                critic's choice.
        """
        # Template message carrying the role metadata of the first proposal;
        # its content is replaced below via create_new_instance().
        meta_chat_message = BaseMessage(
            role_name=input_messages[0].role_name,
            role_type=input_messages[0].role_type,
            meta_dict=input_messages[0].meta_dict,
            content="",
        )

        flatten_options = self.flatten_options(input_messages)
        if self.verbose:
            print_text_animated(
                self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
            )
        input_msg = meta_chat_message.create_new_instance(flatten_options)

        option = self.get_option(input_msg)
        output_msg = meta_chat_message.create_new_instance(option)

        # TODO: The return `info` can be improved.
        return ChatAgentResponse(
            msgs=[output_msg],
            terminated=False,
            info={},
        )

View File

@@ -1,303 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, List, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.logger import get_logger
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
logger = get_logger(__name__)
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="DeductiveReasonerAgent")
class DeductiveReasonerAgent(ChatAgent):
    r"""An agent responsible for deductive reasoning. Model of deductive
    reasoning:
        - L: A ⊕ C -> q * B
        - A represents the known starting state.
        - B represents the known target state.
        - C represents the conditions required to transition from A to B.
        - Q represents the quality or effectiveness of the transition from
        A to B.
        - L represents the path or process from A to B.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Insight Agent",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def deduce_conditions_and_quality(
        self,
        starting_state: str,
        target_state: str,
        role_descriptions_dict: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Union[List[str], Dict[str, str]]]:
        r"""Derives the conditions and quality from the starting state and the
        target state based on the model of the deductive reasoning and the
        knowledge base. It can optionally consider the roles involved in the
        scenario, which allows tailoring the output more closely to the AI
        agent's environment.

        Args:
            starting_state (str): The initial or starting state from which
                conditions are deduced.
            target_state (str): The target state of the task.
            role_descriptions_dict (Optional[Dict[str, str]], optional): A
                dictionary describing the roles involved in the scenario. This
                is optional and can be used to provide a context for the
                CAMEL's role-playing, enabling the generation of more relevant
                and tailored conditions and quality assessments. This could be
                generated using a `RoleAssignmentAgent()` or defined manually
                by the user. (default: :obj:`None`)

        Returns:
            Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
                extracted data from the message. The dictionary contains three
                keys:
                - 'conditions': A list where each key is a condition ID and
                    each value is the corresponding condition text.
                - 'labels': A list of label strings extracted from the message.
                - 'quality': A string of quality assessment strings extracted
                    from the message.

        Raises:
            RuntimeError: If the underlying chat step terminates abnormally.
        """
        # Clear any previous conversation state before a fresh deduction.
        self.reset()

        # NOTE: The regex extraction below depends on the exact wording of the
        # ANSWER TEMPLATE in this prompt — keep both in sync if edited.
        deduce_prompt = """You are a deductive reasoner. You are tasked to
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
CONTENT to help you complete the TASK.
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
fill in the BLANKs, and DO NOT alter or modify any other part of the template

===== MODELING OF DEDUCTIVE REASONING =====
You are tasked with understanding a mathematical model based on the components
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
- $A$ represents the known starting state.
- $B$ represents the known target state.
- $C$ represents the conditions required to transition from $A$ to $B$.
- $Q$ represents the quality or effectiveness of the transition from $A$ to
$B$.
- $L$ represents the path or process from $A$ to $B$.

===== THOUGHT OF DEDUCTIVE REASONING =====
1. Define the Parameters of A and B:
- Characterization: Before delving into transitions, thoroughly understand
the nature and boundaries of both $A$ and $B$. This includes the type,
properties, constraints, and possible interactions between the two.
- Contrast and Compare: Highlight the similarities and differences between
$A$ and $B$. This comparative analysis will give an insight into what
needs changing and what remains constant.
2. Historical & Empirical Analysis:
- Previous Transitions according to the Knowledge Base of GPT: (if
applicable) Extract conditions and patterns from the historical instances
where a similar transition from a state comparable to $A$ moved towards
$B$.
- Scientific Principles: (if applicable) Consider the underlying
scientific principles governing or related to the states and their
transition. For example, if $A$ and $B$ are physical states, laws of
physics might apply.
3. Logical Deduction of Conditions ($C$):
- Direct Path Analysis: What are the immediate and direct conditions
required to move from $A$ to $B$?
- Intermediate States: Are there states between $A$ and $B$ that must be
traversed or can be used to make the transition smoother or more
efficient? If yes, what is the content?
- Constraints & Limitations: Identify potential barriers or restrictions
in moving from $A$ to $B$. These can be external (e.g., environmental
factors) or internal (properties of $A$ or $B$).
- Resource and Information Analysis: What resources and information are
required for the transition? This could be time, entity, factor, code
language, software platform, unknowns, etc.
- External Influences: Consider socio-economic, political, or
environmental factors (if applicable) that could influence the transition
conditions.
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
no matter how unconventional they might seem. Utilize analogies,
metaphors, or brainstorming techniques to envision possible conditions or
paths from $A$ to $B$.
- The conditions $C$ should be multiple but in one sentence. And each
condition should be concerned with one aspect/entity.
4. Entity/Label Recognition of Conditions ($C$):
- Identify and categorize entities of Conditions ($C$) such as the names,
locations, dates, specific technical terms or contextual parameters that
might be associated with events, innovations post-2022.
- The output of the entities/labels will be used as tags or labels for
semantic similarity searches. The entities/labels may be the words, or
phrases, each of them should contain valuable, high information entropy
information, and should be independent.
- Ensure that the identified entities are formatted in a manner suitable
for database indexing and retrieval. Organize the entities into
categories, and combine the category with its instance into a continuous
phrase, without using colons or other separators.
- Format these entities for database indexing: output the category rather
than its instance/content into a continuous phrase. For example, instead
of "Jan. 02", identify it as "Event time".
5. Quality Assessment ($Q$):
- Efficiency: How efficient is the transition from $A$ to $B$, which
measures the resources used versus the desired outcome?
- Effectiveness: Did the transition achieve the desired outcome or was the
target state achieved as intended?
- Safety & Risks: Assess any risks associated with the transition and the
measures to mitigate them.
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
and adjust the quality of transition, making it more adaptive.
6. Iterative Evaluation:
- Test & Refine: Based on the initially deduced conditions and assessed
quality, iterate the process to refine and optimize the transition. This
might involve tweaking conditions, employing different paths, or changing
resources.
- Feedback Integration: Use feedback to make improvements and increase the
quality of the transition.
7. Real-world scenarios often present challenges that may not be captured by
models and frameworks. While using the model, maintain an adaptive mindset:
- Scenario Exploration: Continuously imagine various possible scenarios,
both positive and negative, to prepare for unexpected events.
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
process ($L$) if unforeseen challenges arise.
- Feedback Integration: Rapidly integrate feedback from actual
implementations to adjust the model's application, ensuring relevancy and
effectiveness.

===== TASK =====
Given the starting state $A$ and the target state $B$, assuming that a path
$L$ always exists between $A$ and $B$, how can one deduce or identify the
necessary conditions $C$ and the quality $Q$ of the transition?

===== STARTING STATE $A$ =====
{starting_state}

===== TARGET STATE $B$ =====
{target_state}

{role_with_description_prompt}
===== ANSWER TEMPLATE =====
- Characterization and comparison of $A$ and $B$:\n<BLANK>
- Historical & Empirical Analysis:\n<BLANK>/None
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
    condition <NUM>:
        <BLANK>.
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
square brackets)
- Quality Assessment ($Q$) (do not use symbols):
    <BLANK>.
- Iterative Evaluation:\n<BLANK>/None"""

        # Optionally inject the role descriptions section into the prompt.
        if role_descriptions_dict is not None:
            role_names = role_descriptions_dict.keys()
            role_with_description_prompt = (
                "===== ROLES WITH DESCRIPTIONS =====\n"
                + "\n".join(
                    f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
                    for role_name in role_names
                )
                + "\n\n"
            )
        else:
            role_with_description_prompt = ""
        deduce_prompt = TextPrompt(deduce_prompt)

        deduce = deduce_prompt.format(
            starting_state=starting_state,
            target_state=target_state,
            role_with_description_prompt=role_with_description_prompt,
        )

        conditions_and_quality_generation_msg = BaseMessage.make_user_message(
            role_name="Deductive Reasoner", content=deduce
        )

        response = self.step(
            input_message=conditions_and_quality_generation_msg
        )

        if response.terminated:
            raise RuntimeError(
                "Deduction failed. Error:\n" + f"{response.info}"
            )
        msg: BaseMessage = response.msg
        logger.info(f"Message content:\n{msg.content}")

        # Extract the conditions from the message: one dict entry per
        # "condition N:" section of the answer template.
        conditions_dict = {
            f"condition {i}": cdt.replace("<", "")
            .replace(">", "")
            .strip()
            .strip('\n')
            for i, cdt in re.findall(
                r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
                msg.content,
                re.DOTALL,
            )
        }

        # Extract the labels from the message (the bracketed list following
        # "Entity/Label Recognition of Conditions:"). NOTE(review): an answer
        # missing that section would raise IndexError here — TODO confirm
        # whether that is the intended failure mode.
        labels = [
            label.strip().strip('\n').strip("\"'")
            for label in re.findall(
                r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
                msg.content,
                re.DOTALL,
            )[0].split(",")
        ]

        # Extract the quality from the message (first match only).
        quality = next(
            q.strip().strip('\n')
            for q in re.findall(
                r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
                r"\n(.+?)- Iterative",
                msg.content,
                re.DOTALL,
            )
        )

        # Convert them into JSON format
        conditions_and_quality_json: Dict[
            str, Union[List[str], Dict[str, str]]
        ] = {}
        conditions_and_quality_json["conditions"] = conditions_dict
        conditions_and_quality_json["labels"] = labels
        conditions_and_quality_json["evaluate_quality"] = quality

        return conditions_and_quality_json

View File

@@ -1,201 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional
from colorama import Fore
from camel.agents.chat_agent import ChatAgent
from camel.agents.tool_agents.base import BaseToolAgent
from camel.interpreters import (
BaseInterpreter,
InternalPythonInterpreter,
SubprocessInterpreter,
)
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import print_text_animated
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="EmbodiedAgent")
class EmbodiedAgent(ChatAgent):
r"""Class for managing conversations of CAMEL Embodied Agents.
Args:
system_message (BaseMessage): The system message for the chat agent.
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
message_window_size (int, optional): The maximum number of previous
messages to include in the context window. If `None`, no windowing
is performed. (default: :obj:`None`)
tool_agents (List[BaseToolAgent], optional): The tools agents to use in
the embodied agent. (default: :obj:`None`)
code_interpreter (BaseInterpreter, optional): The code interpreter to
execute codes. If `code_interpreter` and `tool_agent` are both
`None`, default to `SubProcessInterpreter`. If `code_interpreter`
is `None` and `tool_agents` is not `None`, default to
`InternalPythonInterpreter`. (default: :obj:`None`)
verbose (bool, optional): Whether to print the critic's messages.
logger_color (Any): The color of the logger displayed to the user.
(default: :obj:`Fore.MAGENTA`)
"""
def __init__(
self,
system_message: BaseMessage,
model: Optional[BaseModelBackend] = None,
message_window_size: Optional[int] = None,
tool_agents: Optional[List[BaseToolAgent]] = None,
code_interpreter: Optional[BaseInterpreter] = None,
verbose: bool = False,
logger_color: Any = Fore.MAGENTA,
) -> None:
self.tool_agents = tool_agents
self.code_interpreter: BaseInterpreter
if code_interpreter is not None:
self.code_interpreter = code_interpreter
elif self.tool_agents:
self.code_interpreter = InternalPythonInterpreter()
else:
self.code_interpreter = SubprocessInterpreter()
if self.tool_agents:
system_message = self._set_tool_agents(system_message)
self.verbose = verbose
self.logger_color = logger_color
super().__init__(
system_message=system_message,
model=model,
message_window_size=message_window_size,
)
def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
action_space_prompt = self._get_tool_agents_prompt()
result_message = system_message.create_new_instance(
content=system_message.content.format(
action_space=action_space_prompt
)
)
if self.tool_agents is not None:
self.code_interpreter.update_action_space(
{tool.name: tool for tool in self.tool_agents}
)
return result_message
def _get_tool_agents_prompt(self) -> str:
r"""Returns the action space prompt.
Returns:
str: The action space prompt.
"""
if self.tool_agents is not None:
return "\n".join(
[
f"*** {tool.name} ***:\n {tool.description}"
for tool in self.tool_agents
]
)
else:
return ""
def get_tool_agent_names(self) -> List[str]:
    r"""Returns the names of tool agents.

    Returns:
        List[str]: The tool agents' names; empty when none are configured.
    """
    if self.tool_agents is None:
        return []
    return [agent.name for agent in self.tool_agents]
# ruff: noqa: E501
def step(self, input_message: BaseMessage) -> ChatAgentResponse:  # type: ignore[override]
    r"""Performs a step in the conversation.

    Args:
        input_message (BaseMessage): The input message.

    Returns:
        ChatAgentResponse: A struct containing the output messages,
            a boolean indicating whether the chat session has terminated,
            and information about the chat session.

    Raises:
        RuntimeError: If the underlying chat step produced no messages or
            terminated unexpectedly.
    """
    response = super().step(input_message)

    if response.msgs is None or len(response.msgs) == 0:
        raise RuntimeError("Got None output messages.")
    if response.terminated:
        raise RuntimeError(f"{self.__class__.__name__} step failed.")

    # NOTE: Only single output messages are supported
    explanations, codes = response.msg.extract_text_and_code_prompts()

    if self.verbose:
        for explanation, code in zip(explanations, codes):
            print_text_animated(
                self.logger_color + f"> Explanation:\n{explanation}"
            )
            print_text_animated(self.logger_color + f"> Code:\n{code}")

        # There may be one trailing explanation after the last code block.
        if len(explanations) > len(codes):
            print_text_animated(
                self.logger_color + f"> Explanation:\n{explanations[-1]}"
            )

    content = response.msg.content

    # FIX: `codes` is always a list, so the previous `codes is not None`
    # check was always true and clobbered the model's response with an
    # empty "Executed Results" header even when no code was generated.
    # A truthiness check only replaces the content when there is code
    # to execute.
    if codes:
        try:
            content = "\n> Executed Results:\n"
            for block_idx, code in enumerate(codes):
                executed_output = self.code_interpreter.run(
                    code, code.code_type
                )
                content += (
                    f"Executing code block {block_idx}: {{\n"
                    + executed_output
                    + "}\n"
                )
        except InterruptedError as e:
            content = (
                f"\n> Running code fail: {e}\n"
                "Please regenerate the code."
            )

    # TODO: Handle errors
    content = input_message.content + f"\n> Embodied Actions:\n{content}"
    message = BaseMessage(
        input_message.role_name,
        input_message.role_type,
        input_message.meta_dict,
        content,
    )
    return ChatAgentResponse(
        msgs=[message],
        terminated=response.terminated,
        info=response.info,
    )

View File

@@ -1,259 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from unstructured.documents.elements import Element
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.storages.graph_storages.graph_element import (
GraphElement,
Node,
Relationship,
)
from camel.types import RoleType
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
# Instruction prompt for the knowledge-graph extraction agent. The single
# `{task}` placeholder is filled with the stringified input element.
# (Grammar of the original instructions fixed: "structures them" ->
# "structuring them", "you needs" -> "you need", "Please extracts" ->
# "Please extract".)
text_prompt = """
You are tasked with extracting nodes and relationships from given content and
structuring them into Node and Relationship objects. Here's the outline of
what you need to do:

Content Extraction:
You should be able to process input content and identify entities mentioned
within it.
Entities can be any noun phrases or concepts that represent distinct entities
in the context of the given content.

Node Extraction:
For each identified entity, you should create a Node object.
Each Node object should have a unique identifier (id) and a type (type).
Additional properties associated with the node can also be extracted and
stored.

Relationship Extraction:
You should identify relationships between entities mentioned in the content.
For each relationship, create a Relationship object.
A Relationship object should have a subject (subj) and an object (obj) which
are Node objects representing the entities involved in the relationship.
Each relationship should also have a type (type), and additional properties if
applicable.

Output Formatting:
The extracted nodes and relationships should be formatted as instances of the
provided Node and Relationship classes.
Ensure that the extracted data adheres to the structure defined by the classes.
Output the structured data in a format that can be easily validated against
the provided code.

Instructions for you:
Read the provided content thoroughly.
Identify distinct entities mentioned in the content and categorize them as
nodes.
Determine relationships between these entities and represent them as directed
relationships.
Provide the extracted nodes and relationships in the specified format below.

Example for you:

Example Content:
"John works at XYZ Corporation. He is a software engineer. The company is
located in New York City."

Expected Output:

Nodes:

Node(id='John', type='Person')
Node(id='XYZ Corporation', type='Organization')
Node(id='New York City', type='Location')

Relationships:

Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
Corporation', type='Organization'), type='WorksAt')
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
type='Location'), type='ResidesIn')

===== TASK =====
Please extract nodes and relationships from the given content and structure
them into Node and Relationship objects.

{task}
"""
@track_agent(name="KnowledgeGraphAgent")
class KnowledgeGraphAgent(ChatAgent):
    r"""An agent that can extract node and relationship information for
    different entities from given `Element` content.

    Attributes:
        task_prompt (TextPrompt): A prompt for the agent to extract node and
            relationship information for different entities.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        r"""Initialize the `KnowledgeGraphAgent`.

        Args:
            model (BaseModelBackend, optional): The model backend to use for
                generating responses. (default: :obj:`OpenAIModel` with
                `GPT_4O_MINI`)
        """
        system_message = BaseMessage(
            role_name="Graphify",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="Your mission is to transform unstructured content "
            "into structured graph data. Extract nodes and relationships with "
            "precision, and let the connections unfold. Your graphs will "
            "illuminate the hidden connections within the chaos of "
            "information.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        element: "Element",
        parse_graph_elements: bool = False,
    ) -> Union[str, GraphElement]:
        r"""Run the agent to extract node and relationship information.

        Args:
            element (Element): The input element.
            parse_graph_elements (bool, optional): Whether to parse into
                `GraphElement`. Defaults to `False`.

        Returns:
            Union[str, GraphElement]: The extracted node and relationship
                information. If `parse_graph_elements` is `True` then return
                `GraphElement`, else return `str`.
        """
        self.reset()
        # Stashed so `_parse_graph_elements` can record the source element.
        self.element = element

        knowledge_graph_prompt = TextPrompt(text_prompt)
        knowledge_graph_generation = knowledge_graph_prompt.format(
            task=str(element)
        )

        knowledge_graph_generation_msg = BaseMessage.make_user_message(
            role_name="Graphify", content=knowledge_graph_generation
        )

        response = self.step(input_message=knowledge_graph_generation_msg)

        content = response.msg.content

        if parse_graph_elements:
            content = self._parse_graph_elements(content)

        return content

    def _validate_node(self, node: Node) -> bool:
        r"""Validate if the object is a valid Node.

        Args:
            node (Node): Object to be validated.

        Returns:
            bool: True if the object is a valid Node, False otherwise.
        """
        return (
            isinstance(node, Node)
            and isinstance(node.id, (str, int))
            and isinstance(node.type, str)
        )

    def _validate_relationship(self, relationship: Relationship) -> bool:
        r"""Validate if the object is a valid Relationship.

        Args:
            relationship (Relationship): Object to be validated.

        Returns:
            bool: True if the object is a valid Relationship, False otherwise.
        """
        return (
            isinstance(relationship, Relationship)
            and self._validate_node(relationship.subj)
            and self._validate_node(relationship.obj)
            and isinstance(relationship.type, str)
        )

    def _parse_graph_elements(self, input_string: str) -> GraphElement:
        r"""Parses graph elements from given content.

        Args:
            input_string (str): The input content.

        Returns:
            GraphElement: The parsed graph elements.
        """
        import re

        # Regular expressions to extract nodes and relationships from the
        # agent's textual output (mirrors the format shown in `text_prompt`).
        node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
        rel_pattern = (
            r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
            r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
        )

        nodes = {}
        relationships = []

        # Extract nodes. Locals are named `node_id`/`node_type` to avoid
        # shadowing the builtins `id` and `type` (the original did).
        for match in re.finditer(node_pattern, input_string):
            node_id, node_type = match.groups()
            properties = {'source': 'agent_created'}
            if node_id not in nodes:
                node = Node(id=node_id, type=node_type, properties=properties)
                if self._validate_node(node):
                    nodes[node_id] = node

        # Extract relationships; keep only those whose both endpoints were
        # successfully extracted as valid nodes above.
        for match in re.finditer(rel_pattern, input_string):
            subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
            properties = {'source': 'agent_created'}
            if subj_id in nodes and obj_id in nodes:
                subj = nodes[subj_id]
                obj = nodes[obj_id]
                relationship = Relationship(
                    subj=subj, obj=obj, type=rel_type, properties=properties
                )
                if self._validate_relationship(relationship):
                    relationships.append(relationship)

        return GraphElement(
            nodes=list(nodes.values()),
            relationships=relationships,
            source=self.element,
        )

View File

@@ -1,141 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="RoleAssignmentAgent")
class RoleAssignmentAgent(ChatAgent):
    r"""An agent that generates role names based on the task prompt.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)

    Attributes:
        role_assignment_prompt (TextPrompt): A prompt for the agent to
            generate role names.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Role Assigner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        num_roles: int = 2,
    ) -> Dict[str, str]:
        r"""Generate role names based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt
                for the task based on which the roles are to be generated.
            num_roles (int, optional): The number of roles to generate.
                (default: :obj:`2`)

        Returns:
            Dict[str, str]: A dictionary mapping role names to their
                descriptions.

        Raises:
            RuntimeError: If fewer/more than `num_roles` roles could be
                parsed from the model output, or the chat step terminated.
        """
        self.reset()
        expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
            f"Domain expert {i + 1}: <BLANK>\n"
            f"Associated competencies, characteristics, duties "
            f"and workflows: <BLANK>. End."
            for i in range(num_roles or 0)
        )
        role_assignment_generation_prompt = TextPrompt(
            "You are a role assignment agent, and you're in charge of "
            + "recruiting {num_roles} experts for the following task."
            + "\n==== TASK =====\n {task}\n\n"
            + "Identify the domain experts you'd recruit and detail their "
            + "associated competencies, characteristics, duties and workflows "
            + "to complete the task.\n "
            + "Your answer MUST adhere to the format of ANSWER PROMPT, and "
            + "ONLY answer the BLANKs.\n"
            + expert_prompt
        )
        role_assignment_generation = role_assignment_generation_prompt.format(
            num_roles=num_roles, task=task_prompt
        )

        role_assignment_generation_msg = BaseMessage.make_user_message(
            role_name="Role Assigner", content=role_assignment_generation
        )

        response = self.step(input_message=role_assignment_generation_msg)

        msg = response.msg  # type: BaseMessage
        terminated = response.terminated

        # Distribute the output completions into role names and descriptions.
        # FIX: `\d+` (the original used `\d`) so that ten or more roles
        # parse correctly; `End\.` escapes the literal period.
        role_names = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Domain expert \d+: (.+?)\nAssociated competencies,",
                msg.content,
                re.DOTALL,
            )
        ]
        role_descriptions = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Associated competencies, characteristics, "
                r"duties and workflows: (.+?) End\.",
                msg.content,
                re.DOTALL,
            )
        ]

        if len(role_names) != num_roles or len(role_descriptions) != num_roles:
            raise RuntimeError(
                "Got None or insufficient information of roles."
            )
        if terminated:
            raise RuntimeError("Role assignment failed.")

        role_descriptions_dict = {
            role_name: description
            for role_name, description in zip(role_names, role_descriptions)
        }

        return role_descriptions_dict

View File

@@ -1,133 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
from camel.utils import create_chunks
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="SearchAgent")
class SearchAgent(ChatAgent):
    r"""An agent that summarizes text based on a query and evaluates the
    relevance of an answer.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        super().__init__(
            BaseMessage(
                role_name="Assistant",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You are a helpful assistant.",
            ),
            model=model,
        )

    def summarize_text(self, text: str, query: str) -> str:
        r"""Summarize the information from the text, based on the query.

        Args:
            text (str): Text to summarize.
            query (str): What information you want.

        Returns:
            str: Strings with information.
        """
        self.reset()

        summary_prompt = TextPrompt(
            '''Gather information from this text that relative to the
question, but do not directly answer the question.\nquestion:
{query}\ntext '''
        ).format(query=query)

        # Split the input into chunks of at most `max_len` characters and
        # summarize each one independently.
        max_len = 3000
        chunk_summaries = []
        for idx, chunk in enumerate(create_chunks(text, max_len), start=1):
            chunk_msg = BaseMessage.make_user_message(
                role_name="User",
                content=summary_prompt + str(idx) + ": " + chunk,
            )
            chunk_summaries.append(self.step(chunk_msg).msg.content)
        results = "".join(summary + "\n" for summary in chunk_summaries)

        # Merge the per-chunk summaries into one final answer.
        final_prompt = TextPrompt(
            '''Here are some summarized texts which split from one text. Using
the information to answer the question. If can't find the answer,
you must answer "I can not find the answer to the query" and
explain why.\n Query:\n{query}.\n\nText:\n'''
        ).format(query=query)
        final_msg = BaseMessage.make_user_message(
            role_name="User",
            content=final_prompt + results,
        )
        return self.step(final_msg).msg.content

    def continue_search(self, query: str, answer: str) -> bool:
        r"""Ask whether to continue search or not based on the provided answer.

        Args:
            query (str): The question.
            answer (str): The answer to the question.

        Returns:
            bool: `True` if the user want to continue search, `False`
                otherwise.
        """
        decision_prompt = TextPrompt(
            "Do you think the ANSWER can answer the QUERY? "
            "Use only 'yes' or 'no' to answer.\n"
            "===== QUERY =====\n{query}\n\n"
            "===== ANSWER =====\n{answer}"
        ).format(query=query, answer=answer)
        verdict_msg = BaseMessage.make_user_message(
            role_name="User",
            content=decision_prompt,
        )
        verdict = self.step(verdict_msg).msg.content
        # A "yes" verdict means the answer suffices, so stop searching.
        return "yes" not in str(verdict).lower()

View File

@@ -1,410 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType
from camel.utils import get_task_list
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="TaskSpecifyAgent")
class TaskSpecifyAgent(ChatAgent):
    r"""An agent that specifies a given task prompt by prompting the user to
    provide more details.

    Attributes:
        DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
        task_specify_prompt (TextPrompt): The prompt for specifying the task.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        task_type (TaskType, optional): The type of task for which to generate
            a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
        task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
            specifying the task. (default: :obj:`None`)
        word_limit (int, optional): The word limit for the task prompt.
            (default: :obj:`50`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    DEFAULT_WORD_LIMIT = 50

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        task_type: TaskType = TaskType.AI_SOCIETY,
        task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
        word_limit: int = DEFAULT_WORD_LIMIT,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_specify_prompt: Union[str, TextPrompt]
        if task_specify_prompt is not None:
            # An explicit prompt is used as-is.
            self.task_specify_prompt = TextPrompt(task_specify_prompt)
        else:
            # Otherwise derive the template from the task type and bind
            # the word limit now.
            template = PromptTemplateGenerator().get_task_specify_prompt(
                task_type
            )
            self.task_specify_prompt = template.format(word_limit=word_limit)

        system_message = BaseMessage(
            role_name="Task Specifier",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You can make a task more specific.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        meta_dict: Optional[Dict[str, Any]] = None,
    ) -> TextPrompt:
        r"""Specify the given task prompt by providing more details.

        Args:
            task_prompt (Union[str, TextPrompt]): The original task
                prompt.
            meta_dict (Dict[str, Any], optional): A dictionary containing
                additional information to include in the prompt.
                (default: :obj:`None`)

        Returns:
            TextPrompt: The specified task prompt.
        """
        self.reset()
        specify_prompt = self.task_specify_prompt.format(task=task_prompt)
        if meta_dict is not None:
            specify_prompt = specify_prompt.format(**meta_dict)

        specifier_response = self.step(
            BaseMessage.make_user_message(
                role_name="Task Specifier", content=specify_prompt
            )
        )

        if specifier_response.terminated:
            raise RuntimeError("Task specification failed.")
        if len(specifier_response.msgs) == 0:
            raise RuntimeError("Got no specification message.")

        return TextPrompt(specifier_response.msgs[0].content)
@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
    r"""An agent that helps divide a task into subtasks based on the input
    task prompt.

    Attributes:
        task_planner_prompt (TextPrompt): A prompt for the agent to divide
            the task into subtasks.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_planner_prompt = TextPrompt(
            "Divide this task into subtasks: {task}. Be concise."
        )
        super().__init__(
            BaseMessage(
                role_name="Task Planner",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You are a helpful task planner.",
            ),
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
    ) -> TextPrompt:
        r"""Generate subtasks based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task to
                be divided into subtasks.

        Returns:
            TextPrompt: A prompt for the subtasks generated by the agent.
        """
        # TODO: Maybe include roles information.
        self.reset()
        planner_msg = BaseMessage.make_user_message(
            role_name="Task Planner",
            content=self.task_planner_prompt.format(task=task_prompt),
        )
        task_response = self.step(planner_msg)

        if task_response.terminated:
            raise RuntimeError("Task planning failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task planning message.")

        return TextPrompt(task_response.msgs[0].content)
@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
    r"""An agent that helps create new tasks based on the objective
    and last completed task. Compared to :obj:`TaskPlannerAgent`,
    it's still a task planner, but it has more context information
    like last task and incomplete task list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_creation_prompt (TextPrompt): A prompt for the agent to
            create new tasks.

    Args:
        role_name (str): The role name of the Agent to create the task.
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        max_task_num (int, optional): The maximum number of planned
            tasks in one round. (default: :obj:3)
    """

    def __init__(
        self,
        role_name: str,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
        max_task_num: Optional[int] = 3,
    ) -> None:
        prompt_template = TextPrompt(
            """Create new a task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The new created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:
#. First Task
#. Second Task
#. Third Task
You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
        )
        # Bind the static fields now; {task_list} is filled on each run.
        self.task_creation_prompt = prompt_template.format(
            objective=objective, role_name=role_name, max_task_num=max_task_num
        )
        self.objective = objective

        super().__init__(
            BaseMessage(
                role_name="Task Creator",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You are a helpful task creator.",
            ),
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Generate subtasks based on the previous task results and
        incomplete task list.

        Args:
            task_list (List[str]): The completed or in-progress
                tasks which should not overlap with new created tasks.

        Returns:
            List[str]: The new task list generated by the Agent.
        """
        # An empty task list is rendered as an empty string in the prompt.
        task_creation_prompt = self.task_creation_prompt.format(
            task_list=task_list if task_list else ""
        )
        task_msg = BaseMessage.make_user_message(
            role_name="Task Creator", content=task_creation_prompt
        )
        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task creation failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task creation message.")

        return get_task_list(task_response.msgs[0].content)
@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
    r"""An agent that helps re-prioritize the task list and
    returns numbered prioritized list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_prioritization_prompt (TextPrompt): A prompt for the agent to
            prioritize tasks.

    Args:
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        prompt_template = TextPrompt(
            """Prioritize the following tasks : {task_list}.
Consider the ultimate objective of you: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:
#. First task
#. Second task
The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
        )
        # Bind the fixed objective now; {task_list} is filled on each run.
        self.task_prioritization_prompt = prompt_template.format(
            objective=objective
        )
        self.objective = objective

        super().__init__(
            BaseMessage(
                role_name="Task Prioritizer",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You are a helpful task prioritizer.",
            ),
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Prioritize the task list given the agent objective.

        Args:
            task_list (List[str]): The unprioritized tasks of agent.

        Returns:
            List[str]: The new prioritized task list generated by the Agent.
        """
        prioritization_msg = BaseMessage.make_user_message(
            role_name="Task Prioritizer",
            content=self.task_prioritization_prompt.format(
                task_list=task_list
            ),
        )
        task_response = self.step(prioritization_msg)

        if task_response.terminated:
            raise RuntimeError("Task prioritization failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task prioritization message.")

        return get_task_list(task_response.msgs[0].content)

View File

@@ -1,20 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseToolAgent
from .hugging_face_tool_agent import HuggingFaceToolAgent
__all__ = [
'BaseToolAgent',
'HuggingFaceToolAgent',
]

View File

@@ -1,39 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.agents import BaseAgent
class BaseToolAgent(BaseAgent):
    r"""Creates a :obj:`BaseToolAgent` object with the specified name and
    description.

    Args:
        name (str): The name of the tool agent.
        description (str): The description of the tool agent.
    """

    def __init__(self, name: str, description: str) -> None:
        # Identifying metadata; the name keys the interpreter action space
        # and the description is surfaced in action-space prompts.
        self.name = name
        self.description = description

    def reset(self) -> None:
        r"""Resets the agent to its initial state."""

    def step(self) -> None:
        r"""Performs a single step of the agent."""

    def __str__(self) -> str:
        # Render as "<name>: <description>".
        return "{}: {}".format(self.name, self.description)

View File

@@ -1,206 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Optional
from camel.agents.tool_agents.base import BaseToolAgent
# flake8: noqa :E501
class HuggingFaceToolAgent(BaseToolAgent):
    r"""Tool agent for calling HuggingFace models. This agent is a wrapper
    around agents from the `transformers` library. For more information
    about the available models, please see the `transformers` documentation
    at https://huggingface.co/docs/transformers/transformers_agents.

    Args:
        name (str): The name of the agent.
        *args (Any): Additional positional arguments to pass to the underlying
            Agent class.
        remote (bool, optional): Flag indicating whether to run the agent
            remotely. (default: :obj:`True`)
        **kwargs (Any): Additional keyword arguments to pass to the underlying
            Agent class.
    """

    def __init__(
        self,
        name: str,
        *args: Any,
        remote: bool = True,
        **kwargs: Any,
    ) -> None:
        # The transformers stack is imported lazily so the heavy dependency
        # is only required when this agent is actually constructed.
        try:
            # TODO: Support other tool agents
            import transformers
            from packaging import version
            # The tools API used below only exists from transformers 4.31.0.
            if version.parse(transformers.__version__) < version.parse(
                "4.31.0"
            ):
                raise ValueError(
                    "The version of \"transformers\" package should >= 4.31.0"
                )
            from transformers.tools import OpenAiAgent
            from transformers.tools.agent_types import AgentImage
        except (ImportError, ValueError):
            # NOTE(review): this collapses both "package missing" and
            # "version too old" into one message and drops the original
            # exception (no `from exc` chaining) — consider chaining for
            # easier debugging.
            raise ValueError(
                "Could not import transformers tool agents. "
                "Please setup the environment with "
                "pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
            )
        # Remembered so step()/chat() can detect image outputs and unwrap
        # them via to_raw().
        self.agent_image_type = AgentImage
        self.agent = OpenAiAgent(*args, **kwargs)
        # Capability description embedded with the agent's own name so the
        # usage examples below are directly copy-pastable by callers.
        description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
- Text question answering: given a long text and a question, answer the question in the text
- Unconditional image captioning: Caption the image!
- Image question answering: given an image, answer a question on this image
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
- Text to speech: convert text to speech
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
- Text summarization: summarize a long text in one or a few sentences
- Translation: translate the text into a given language
- Text downloading: to download a text from a web URL
- Text to image: generate an image according to a prompt, leveraging stable diffusion
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- Text to video: generate a small video according to a prompt
Here are some python code examples of what you can do with this agent:
Single execution (step) mode, the single execution method is when using the step() method of the agent:
```
# Text to image
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")
# Text to image -> Image transformation
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
sea_add_island_image.save("./sea_add_island_image.png")
# If you'd like to keep a state across executions or to pass non-text objects to the agent,
# you can do so by specifying variables that you would like the agent to use. For example,
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
picture = {name}.step("Generate a picture of rivers and lakes.")
picture.save("./picture.png")
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
updated_picture.save("./updated_picture.png")
capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
capybara_sea_image.save("./capybara_sea_image.png")
# Document question answering
answer = {name}.step(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
print(answer)
# Text to image
boat_image = {name}.step("Generate an image of a boat in the water")
boat_image.save("./boat_image.png")
# Unconditional image captioning
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
print(boat_image_caption)
# Text to image -> Unconditional image captioning -> Text to speech
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")
# Text downloading
document = {name}.step("Download the text from http://hf.co")
print(document)
# Text summarization
summary = {name}.step("Summarize the following text: `document`", document=document)
print(summary)
# Text downloading -> Text summarization -> Text to speech
audio = {name}.step("Read out loud the summary of http://hf.co")
```
Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
```
# Clean the chat history
{name}.reset()
# Text to image
capybara_image = {name}.chat("Show me an an image of a capybara")
capybara_image.save("./capybara_image.png")
# Image transformation
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
transformed_capybara_image.save("./transformed_capybara_image.png")
# Image segmentation
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
```
"""
        super(HuggingFaceToolAgent, self).__init__(name, description)
        self.remote = remote

    def reset(self) -> None:
        r"""Resets the chat history of the agent."""
        self.agent.prepare_for_new_chat()

    def step(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in single execution mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        # Per-call remote flag falls back to the instance-level default.
        if remote is None:
            remote = self.remote
        agent_output = self.agent.run(*args, remote=remote, **kwargs)
        # Image outputs are unwrapped from AgentImage via to_raw() before
        # being returned to the caller.
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

    def chat(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in a chat conversation mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        # Per-call remote flag falls back to the instance-level default.
        if remote is None:
            remote = self.remote
        agent_output = self.agent.chat(*args, remote=remote, **kwargs)
        # Same AgentImage unwrapping as in step().
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

View File

@@ -1,17 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseBenchmark
__all__ = ["BaseBenchmark"]

View File

@@ -1,152 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
from camel.agents import ChatAgent
logger = logging.getLogger(__name__)
class BaseBenchmark(ABC):
    r"""Base class for benchmarks.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (str): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. (default: :obj:`1`)
    """

    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory. Created if missing.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. (default: :obj:`1`)

        Raises:
            NotADirectoryError: If ``data_dir`` exists but is not a
                directory.
        """
        self.name = name
        self.data_dir = Path(data_dir)
        self.processes = processes
        self.save_to = save_to
        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )
        # Maps split name ("train"/"valid"/"test") to a list of examples.
        self._data: Dict[str, List[Dict[str, Any]]] = dict()
        # Result records accumulated by run() implementations.
        self._results: List[Dict[str, Any]] = []

    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    def _get_split(self, split: str) -> List[Dict[str, Any]]:
        r"""Return one data split, lazily loading the data on first access.

        Shared by the ``train``/``valid``/``test`` properties so the
        lazy-load logic lives in one place.

        Args:
            split (str): One of ``"train"``, ``"valid"``, or ``"test"``.

        Returns:
            List[Dict[str, Any]]: The requested split.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data[split]

    @property
    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data.

        Returns:
            List[Dict[str, Any]]: The training data.
        """
        return self._get_split("train")

    @property
    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data.

        Returns:
            List[Dict[str, Any]]: The validation data.
        """
        return self._get_split("valid")

    @property
    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data.

        Returns:
            List[Dict[str, Any]]: The test data.
        """
        return self._get_split("test")

    @abstractmethod
    def run(
        self,
        # Forward reference keeps the annotation lazy, so subclass modules
        # that avoid importing ChatAgent at class-creation time still work.
        agent: "ChatAgent",
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @property
    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results

View File

@@ -1,34 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .discord_app import DiscordApp
from .slack.models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackAuthProfile,
SlackEventBody,
SlackEventProfile,
)
from .slack.slack_app import SlackApp
from .telegram_bot import TelegramBot
__all__ = [
'DiscordApp',
'SlackApp',
'SlackAppMentionEventBody',
'SlackAppMentionEventProfile',
'SlackAuthProfile',
'SlackEventBody',
'SlackEventProfile',
'TelegramBot',
]

View File

@@ -1,138 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, List, Optional
from camel.utils import dependencies_required
if TYPE_CHECKING:
from discord import Message
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DiscordApp:
    r"""A Discord bot wrapper built on the `discord.py` library.

    The bot reacts only to messages that mention it and, when a channel
    allow-list is configured, only inside those channels.

    Attributes:
        channel_ids (Optional[List[int]]): Allowed channel IDs; when set,
            messages from any other channel are ignored.
        token (Optional[str]): The Discord bot token used for authentication.
    """

    @dependencies_required('discord')
    def __init__(
        self,
        channel_ids: Optional[List[int]] = None,
        token: Optional[str] = None,
    ) -> None:
        r"""Set up the Discord client and register the event handlers.

        Args:
            channel_ids (Optional[List[int]]): A list of allowed channel IDs.
                When given, the bot only responds in these channels.
            token (Optional[str]): The Discord bot token. Falls back to the
                `DISCORD_TOKEN` environment variable when omitted.

        Raises:
            ValueError: If no token is supplied and `DISCORD_TOKEN` is not
                set in the environment.
        """
        self.token = token or os.getenv('DISCORD_TOKEN')
        self.channel_ids = channel_ids
        if not self.token:
            raise ValueError(
                "`DISCORD_TOKEN` not found in environment variables. Get it"
                " here: `https://discord.com/developers/applications`."
            )

        import discord

        bot_intents = discord.Intents.default()
        # Reading message text requires the message-content intent.
        bot_intents.message_content = True
        self._client = discord.Client(intents=bot_intents)

        # Wire the coroutine handlers into the client's event dispatcher.
        self._client.event(self.on_ready)
        self._client.event(self.on_message)

    async def start(self):
        r"""Log in and start the bot asynchronously.

        Awaitable variant of :meth:`run` for use inside an existing event
        loop.
        """
        await self._client.start(self.token)

    def run(self) -> None:
        r"""Start the bot synchronously; blocks until the client stops."""
        self._client.run(self.token)  # type: ignore[arg-type]

    async def on_ready(self) -> None:
        r"""Log the bot's identity once the Discord connection is ready."""
        logger.info(f'We have logged in as {self._client.user}')

    async def on_message(self, message: 'Message') -> None:
        r"""Process an incoming Discord message.

        Ignores the bot's own messages, messages outside the configured
        channel allow-list, and messages that do not mention the bot;
        everything else is logged.

        Args:
            message (discord.Message): The message object received from
                Discord.
        """
        # Never react to our own traffic.
        if message.author == self._client.user:
            return

        # Enforce the channel allow-list when one was configured.
        in_scope = (
            not self.channel_ids or message.channel.id in self.channel_ids
        )
        if not in_scope:
            return

        # Only messages that explicitly mention the bot are handled.
        mentioned = self._client.user and self._client.user.mentioned_in(
            message
        )
        if not mentioned:
            return

        logger.info(f"Received message: {message.content}")

    @property
    def client(self):
        return self._client

View File

@@ -1,30 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackAuthProfile,
SlackEventBody,
SlackEventProfile,
)
from .slack_app import SlackApp
__all__ = [
'SlackApp',
'SlackAppMentionEventBody',
'SlackAppMentionEventProfile',
'SlackAuthProfile',
'SlackEventBody',
'SlackEventProfile',
]

View File

@@ -1,158 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional
from pydantic import BaseModel
class SlackAuthProfile(BaseModel):
    r"""Represents the authorization profile within a Slack event.

    Events will contain a single, compact authorizations field that shows one
    installation of your app that the event is visible to.
    In other words, lists of authorizations will be truncated to one element.

    If there's more than one installing party that your app is keeping track
    of, it's best not to rely on the single party listed in authorizations to
    be any particular one.

    To get a full list of who can see events, call the apps.event.
    authorizations.list method after obtaining an app-level token. Read more on
    the changes here; they have taken effect for existing apps as of
    February 24, 2021.

    References:
        - https://api.slack.com/apis/events-api#authorizations
        - https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
    """

    # Fields without an explicit default are required by pydantic validation.
    enterprise_id: Optional[str] = None
    """The ID of the enterprise associated with the authorization."""
    team_id: str
    """The ID of the team associated with the authorization."""
    user_id: str
    """The ID of the user associated with the authorization."""
    is_bot: bool
    """Whether the authorized user is a bot."""
    is_enterprise_install: bool
    """Whether the authorization is for an enterprise installation."""
class SlackEventProfile(BaseModel):
    r"""Represents the detailed profile of a Slack event, including user,
    message, and context data.
    """

    user: str
    """The ID of the user associated with the event."""
    type: str
    """The type of the event (e.g., 'message')."""
    ts: str
    """A timestamp representing when the event was triggered."""
    thread_ts: Optional[str] = None
    """The timestamp of the parent message in a thread."""
    client_msg_id: str
    """A unique ID generated by the client for the message (if available)."""
    text: str
    """The message content text."""
    team: str
    """The ID of the team that the event is associated with."""
    blocks: list
    """The list of message blocks, providing structured information."""
    channel: str
    """The ID of the Slack channel where the event happened."""
    event_ts: str
    """The event-specific timestamp when it occurred."""
    # NOTE(review): Optional but with no default, so this stays a required
    # field (under pydantic v2 an explicit None must be supplied). Confirm
    # this is intended; SlackAppMentionEventProfile overrides it with a
    # None default below.
    channel_type: Optional[str]
    """The type of Slack channel (e.g., 'channel', 'im')."""
class SlackEventBody(BaseModel):
    r"""Represents the entire body of a Slack event, including the event
    profile, authorization, and context.
    """

    token: str
    """The token to verify the source of the event."""
    team_id: str
    """The ID of the team where the event is happening."""
    # NOTE(review): Optional but with no default, so this stays a required
    # field (under pydantic v2 an explicit None must be supplied). Confirm
    # intended; SlackAppMentionEventBody overrides it with a None default.
    context_team_id: Optional[str]
    """The team ID for the shared channel context, if applicable."""
    context_enterprise_id: Optional[str] = None
    """The enterprise ID for the shared channel context, if applicable."""
    api_app_id: str
    """The unique identifier for the Slack app that received the event."""
    event: SlackEventProfile
    """A detailed profile of the event"""
    type: str
    """The overall type of event received (e.g., 'event_callback')."""
    event_id: str
    """A unique identifier assigned to this event by Slack."""
    event_time: int
    """The timestamp (in seconds) representing when the event was triggered."""
    authorizations: Optional[list[SlackAuthProfile]] = None
    """An optional list of authorizations that describe which installation can
    see the event."""
    is_ext_shared_channel: bool
    """Indicates if the event is part of a shared channel between different
    organizations."""
    event_context: str
    """A unique string representing the context of the event."""
class SlackAppMentionEventProfile(SlackEventProfile):
    r"""Represents the detailed profile of a Slack event where the app was
    mentioned in a message.
    """

    # Overrides the parent's required field: app_mention payloads carry no
    # channel_type, so it defaults to None here.
    channel_type: Optional[str] = None
    """The type of Slack channel. it's None for app mentions."""
class SlackAppMentionEventBody(SlackEventBody):
    r"""Represents the entire body of a Slack event where the app was mentioned
    in a message.
    """

    # Overrides the parent's field with a None default: app_mention payloads
    # omit context_team_id.
    # NOTE(review): the descriptive string below duplicates the `event`
    # field's text ("A detailed profile of the event") and does not describe
    # context_team_id — likely a copy-paste slip; confirm and fix separately
    # (the string literal is deliberately left untouched here).
    context_team_id: Optional[str] = None
    """A detailed profile of the event. it's None for app mentions."""
    event: SlackAppMentionEventProfile
    """A detailed profile of the event"""

View File

@@ -1,255 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, Optional
from slack_sdk.oauth.installation_store.async_installation_store import (
AsyncInstallationStore,
)
from starlette import requests, responses
from camel.bots.slack.models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackEventBody,
SlackEventProfile,
)
from camel.utils import dependencies_required
if TYPE_CHECKING:
from slack_bolt.context.async_context import AsyncBoltContext
from slack_bolt.context.say.async_say import AsyncSay
from slack_sdk.web.async_client import AsyncWebClient
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SlackApp:
    r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.

    This class is responsible for initializing and managing the Slack
    application by setting up event handlers, running the app server, and
    handling events such as messages and mentions from Slack.

    Args:
        token (Optional[str]): Slack API token for authentication.
        scopes (Optional[str]): Slack app scopes for permissions.
        signing_secret (Optional[str]): Signing secret for verifying Slack
            requests.
        client_id (Optional[str]): Slack app client ID.
        client_secret (Optional[str]): Slack app client secret.
        redirect_uri_path (str): The URI path for OAuth redirect, defaults to
            "/slack/oauth_redirect".
        installation_store (Optional[AsyncInstallationStore]): The installation
            store for handling OAuth installations.
    """

    @dependencies_required('slack_bolt')
    def __init__(
        self,
        token: Optional[str] = None,
        scopes: Optional[str] = None,
        signing_secret: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri_path: str = "/slack/oauth_redirect",
        installation_store: Optional[AsyncInstallationStore] = None,
    ) -> None:
        r"""Initializes the SlackApp instance by setting up the Slack Bolt app
        and configuring event handlers and OAuth settings.

        Args:
            token (Optional[str]): The Slack API token.
            scopes (Optional[str]): The scopes for Slack app permissions.
            signing_secret (Optional[str]): The signing secret for verifying
                requests.
            client_id (Optional[str]): The Slack app client ID.
            client_secret (Optional[str]): The Slack app client secret.
            redirect_uri_path (str): The URI path for handling OAuth redirects
                (default is "/slack/oauth_redirect").
            installation_store (Optional[AsyncInstallationStore]): An optional
                installation store for OAuth installations.

        Raises:
            ValueError: If token, scopes, or signing secret are missing from
                both the arguments and the environment.
        """
        # slack_bolt is imported lazily so the dependency is only needed
        # when a SlackApp is actually constructed.
        from slack_bolt.adapter.starlette.async_handler import (
            AsyncSlackRequestHandler,
        )
        from slack_bolt.app.async_app import AsyncApp
        from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings

        # Each credential falls back to its environment variable when the
        # corresponding argument is not provided.
        self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
        self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
        self.signing_secret: Optional[str] = signing_secret or os.getenv(
            "SLACK_SIGNING_SECRET"
        )
        self.client_id: Optional[str] = client_id or os.getenv(
            "SLACK_CLIENT_ID"
        )
        self.client_secret: Optional[str] = client_secret or os.getenv(
            "SLACK_CLIENT_SECRET"
        )
        # token, scopes, and signing secret are mandatory; client ID/secret
        # are only needed for the OAuth flow below.
        if not all([self.token, self.scopes, self.signing_secret]):
            raise ValueError(
                "`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
                "environment variables must be set. Get it here: "
                "`https://api.slack.com/apps`."
            )
        # Setup OAuth settings if client ID and secret are provided
        if self.client_id and self.client_secret:
            self._app = AsyncApp(
                oauth_settings=AsyncOAuthSettings(
                    client_id=self.client_id,
                    client_secret=self.client_secret,
                    scopes=self.scopes,
                    redirect_uri_path=redirect_uri_path,
                ),
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )
        else:
            # Initialize Slack Bolt AsyncApp with settings
            self._app = AsyncApp(
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )
        # Starlette adapter that translates HTTP requests into Bolt events.
        self._handler = AsyncSlackRequestHandler(self._app)
        self.setup_handlers()

    def setup_handlers(self) -> None:
        r"""Sets up the event handlers for Slack events, such as `app_mention`
        and `message`.

        This method registers the `app_mention` and `on_message` event handlers
        with the Slack Bolt app to respond to Slack events.
        """
        self._app.event("app_mention")(self.app_mention)
        self._app.event("message")(self.on_message)

    def run(
        self,
        port: int = 3000,
        path: str = "/slack/events",
        host: Optional[str] = None,
    ) -> None:
        r"""Starts the Slack Bolt app server to listen for incoming Slack
        events.

        Args:
            port (int): The port on which the server should run (default is
                3000).
            path (str): The endpoint path for receiving Slack events (default
                is "/slack/events").
            host (Optional[str]): The hostname to bind the server (default is
                None).
        """
        self._app.start(port=port, path=path, host=host)

    async def handle_request(
        self, request: requests.Request
    ) -> responses.Response:
        r"""Handles incoming requests from Slack through the request handler.

        Args:
            request (Request): A Starlette request object representing the
                incoming request.

        Returns:
            The response generated by the Slack Bolt handler.
        """
        return await self._handler.handle(request)

    async def app_mention(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `app_mention` events.

        This method is triggered when someone mentions the app in Slack.
        Currently it only validates the payload into pydantic models and
        logs the pieces; no reply is sent.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the app mention.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        # Validate the raw dicts into typed models (raises on bad payloads).
        event_profile = SlackAppMentionEventProfile(**event)
        event_body = SlackAppMentionEventBody(**body)
        logger.info(f"app_mention, context: {context}")
        logger.info(f"app_mention, client: {client}")
        logger.info(f"app_mention, event_profile: {event_profile}")
        logger.info(f"app_mention, event_body: {event_body}")
        logger.info(f"app_mention, say: {say}")

    async def on_message(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `message` events.

        This method is triggered when the app receives a message in Slack.
        Like `app_mention`, it validates and logs the payload without
        sending a reply.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the message.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        # Acknowledge the incoming event before doing any other work.
        await context.ack()
        # Validate the raw dicts into typed models (raises on bad payloads).
        event_profile = SlackEventProfile(**event)
        event_body = SlackEventBody(**body)
        logger.info(f"on_message, context: {context}")
        logger.info(f"on_message, client: {client}")
        logger.info(f"on_message, event_profile: {event_profile}")
        logger.info(f"on_message, event_body: {event_body}")
        logger.info(f"on_message, say: {say}")
        logger.info(f"Received message: {event_profile.text}")

    def mention_me(
        self, context: "AsyncBoltContext", body: SlackEventBody
    ) -> bool:
        r"""Check if the bot is mentioned in the message.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            body (SlackEventBody): The body of the Slack event.

        Returns:
            bool: True if the bot is mentioned in the message, False otherwise.
        """
        # Slack encodes mentions in message text as "<@BOT_USER_ID>".
        message = body.event.text
        bot_user_id = context.bot_user_id
        mention = f"<@{bot_user_id}>"
        return mention in message

View File

@@ -1,82 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import TYPE_CHECKING, Optional
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.utils import dependencies_required
# Conditionally import telebot types only for type checking
if TYPE_CHECKING:
from telebot.types import ( # type: ignore[import-untyped]
Message,
)
class TelegramBot:
    r"""Represents a Telegram bot that is powered by an agent.

    Attributes:
        chat_agent (ChatAgent): Chat agent that will power the bot.
        telegram_token (str, optional): The bot token.
    """

    @dependencies_required('telebot')
    def __init__(
        self,
        chat_agent: ChatAgent,
        telegram_token: Optional[str] = None,
    ) -> None:
        self.chat_agent = chat_agent

        # Prefer an explicitly supplied token; otherwise fall back to the
        # environment and fail loudly when neither is available.
        if telegram_token:
            self.token = telegram_token
        else:
            self.token = os.getenv('TELEGRAM_TOKEN')
            if not self.token:
                raise ValueError(
                    "`TELEGRAM_TOKEN` not found in environment variables. "
                    "Get it from t.me/BotFather."
                )

        # Imported lazily so the dependency is only needed when the bot is
        # actually constructed.
        import telebot  # type: ignore[import-untyped]

        self.bot = telebot.TeleBot(token=self.token)

        # Route every incoming message to ``on_message``.
        self.bot.message_handler(func=lambda message: True)(self.on_message)

    def run(self) -> None:
        r"""Start the Telegram bot and block, polling for updates."""
        print("Telegram bot is running...")
        self.bot.infinity_polling()

    def on_message(self, message: 'Message') -> None:
        r"""Handles incoming messages from the user.

        Args:
            message (types.Message): The incoming message object.
        """
        # Each incoming message starts a fresh conversation for the agent.
        self.chat_agent.reset()

        text = message.text
        if not text:
            # Ignore non-text updates (photos, stickers, etc.).
            return

        user_msg = BaseMessage.make_user_message(
            role_name="User", content=text
        )
        reply_text = self.chat_agent.step(user_msg).msg.content
        self.bot.reply_to(message, reply_text)

View File

@@ -1,76 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .cohere_config import COHERE_API_PARAMS, CohereConfig
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
from .qwen_config import QWEN_API_PARAMS, QwenConfig
from .reka_config import REKA_API_PARAMS, RekaConfig
from .samba_config import (
SAMBA_CLOUD_API_PARAMS,
SAMBA_VERSE_API_PARAMS,
SambaCloudAPIConfig,
SambaVerseAPIConfig,
)
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
from .yi_config import YI_API_PARAMS, YiConfig
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig
# Public API of the ``camel.configs`` package: each provider's config class is
# exported alongside the set of API parameter names that provider accepts.
__all__ = [
    'BaseConfig',
    'ChatGPTConfig',
    'OPENAI_API_PARAMS',
    'AnthropicConfig',
    'ANTHROPIC_API_PARAMS',
    'GROQ_API_PARAMS',
    'GroqConfig',
    'LiteLLMConfig',
    'LITELLM_API_PARAMS',
    'NvidiaConfig',
    'NVIDIA_API_PARAMS',
    'OllamaConfig',
    'OLLAMA_API_PARAMS',
    'ZhipuAIConfig',
    'ZHIPUAI_API_PARAMS',
    'GeminiConfig',
    'Gemini_API_PARAMS',
    'VLLMConfig',
    'VLLM_API_PARAMS',
    'MistralConfig',
    'MISTRAL_API_PARAMS',
    'RekaConfig',
    'REKA_API_PARAMS',
    'SambaVerseAPIConfig',
    'SAMBA_VERSE_API_PARAMS',
    'SambaCloudAPIConfig',
    'SAMBA_CLOUD_API_PARAMS',
    'TogetherAIConfig',
    'TOGETHERAI_API_PARAMS',
    'CohereConfig',
    'COHERE_API_PARAMS',
    'YiConfig',
    'YI_API_PARAMS',
    'QwenConfig',
    'QWEN_API_PARAMS',
    'DeepSeekConfig',
    'DEEPSEEK_API_PARAMS',
]

View File

@@ -1,69 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class AnthropicConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Anthropic API.

    See: https://docs.anthropic.com/claude/reference/complete_post

    Args:
        max_tokens (int, optional): The maximum number of tokens to
            generate before stopping. Note that Anthropic models may stop
            before reaching this maximum. This parameter only specifies the
            absolute maximum number of tokens to generate.
            (default: :obj:`256`)
        stop_sequences (List[str], optional): Sequences that will cause the
            model to stop generating completion text. Anthropic models stop
            on "\n\nHuman:", and may include additional built-in stop sequences
            in the future. By providing the stop_sequences parameter, you may
            include additional strings that will cause the model to stop
            generating. (default: :obj:`NOT_GIVEN`)
        temperature (float, optional): Amount of randomness injected into the
            response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
            for analytical / multiple choice, and closer to 1 for creative
            and generative tasks.
            (default: :obj:`1`)
        top_p (float, optional): Use nucleus sampling. In nucleus sampling, we
            compute the cumulative distribution over all the options for each
            subsequent token in decreasing probability order and cut it off
            once it reaches a particular probability specified by `top_p`.
            You should either alter `temperature` or `top_p`,
            but not both.
            (default: :obj:`NOT_GIVEN`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low probability
            responses.
            (default: :obj:`NOT_GIVEN`)
        metadata: An object describing metadata about the request.
            (default: :obj:`NOT_GIVEN`)
        stream (bool, optional): Whether to incrementally stream the response
            using server-sent events. (default: :obj:`False`)
    """

    max_tokens: int = 256
    stop_sequences: Union[List[str], NotGiven] = NOT_GIVEN
    temperature: float = 1
    top_p: Union[float, NotGiven] = NOT_GIVEN
    top_k: Union[int, NotGiven] = NOT_GIVEN
    metadata: NotGiven = NOT_GIVEN
    stream: bool = False


# The set of parameter names the Anthropic config accepts.
ANTHROPIC_API_PARAMS = {param for param in AnthropicConfig.model_fields.keys()}

View File

@@ -1,89 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from abc import ABC
from typing import Any, List, Optional
from pydantic import BaseModel, ConfigDict, field_validator
class BaseConfig(ABC, BaseModel):
    r"""Base configuration class for all models.

    This class provides a common interface for all models, ensuring that all
    models have a consistent set of attributes and methods.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        # UserWarning: conflict with protected namespace "model_"
        protected_namespaces=(),
    )

    # A list of tools the model may call. Currently, only functions are
    # supported as a tool. Use this to provide a list of functions the
    # model may generate JSON inputs for. A max of 128 functions are
    # supported.
    tools: Optional[List[Any]] = None

    @field_validator("tools", mode="before")
    @classmethod
    def fields_type_checking(cls, tools):
        r"""Validate the type of tools in the configuration.

        This method ensures that the tools provided in the configuration are
        instances of `FunctionTool`. If any tool is not an instance of
        `FunctionTool`, it raises a ValueError.
        """
        if tools is None:
            return tools

        # Imported here to avoid a circular import at module load time.
        from camel.toolkits import FunctionTool

        for candidate in tools:
            if isinstance(candidate, FunctionTool):
                continue
            raise ValueError(
                f"The tool {candidate} should "
                "be an instance of `FunctionTool`."
            )
        return tools

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
            configuration.
        """
        dumped = self.model_dump()

        # "tools" is replaced by a list of OpenAI tool schemas when tools are
        # configured, and by None otherwise.
        schemas = None
        if self.tools:
            from camel.toolkits import FunctionTool

            schemas = []
            for candidate in self.tools:
                if not isinstance(candidate, FunctionTool):
                    raise ValueError(
                        f"The tool {candidate} should "
                        "be an instance of `FunctionTool`."
                    )
                schemas.append(candidate.get_openai_tool_schema())
        dumped["tools"] = schemas
        return dumped

View File

@@ -1,76 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional
from camel.configs.base_config import BaseConfig
class CohereConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cohere API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        documents (list, optional): A list of relevant documents that the
            model can cite to generate a more accurate reply. Each document is
            either a string or document object with content and metadata.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens the model
            will generate as part of the response. (default: :obj:`None`)
        stop_sequences (List(str), optional): A list of up to 5 strings that
            the model will use to stop generation. If the model generates a
            string that matches any of the strings in the list, it will stop
            generating tokens and return the generated text up to that point
            not including the stop sequence. (default: :obj:`None`)
        seed (int, optional): If specified, the backend will make a best
            effort to sample tokens deterministically, such that repeated
            requests with the same seed and parameters should return the same
            result. However, determinism cannot be totally guaranteed.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. The
            higher the value, the stronger a penalty is applied to previously
            present tokens, proportional to how many times they have already
            appeared in the prompt or prior generation. (default: :obj:`0.0`)
        presence_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. Similar
            to `frequency_penalty`, except that this penalty is applied
            equally to all tokens that have already appeared, regardless of
            their exact frequencies. (default: :obj:`0.0`)
        k (int, optional): Ensures only the top k most likely tokens are
            considered for generation at each step. Min value of `0`, max
            value of `500`. (default: :obj:`0`)
        p (float, optional): Ensures that only the most likely tokens, with
            total probability mass of `p`, are considered for generation at
            each step. If both k and p are enabled, `p` acts after `k`. Min
            value of `0.01`, max value of `0.99`. (default: :obj:`0.75`)
    """

    temperature: Optional[float] = 0.2
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    k: Optional[int] = 0
    p: Optional[float] = 0.75


# Read ``model_fields`` from the class rather than a throwaway instance:
# this matches every other config module in the package and avoids the
# deprecated instance-level ``model_fields`` access in recent pydantic.
COHERE_API_PARAMS = {param for param in CohereConfig.model_fields.keys()}

View File

@@ -1,134 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class DeepSeekConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    DeepSeek API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`1.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        stream (bool, optional): If set, partial message deltas will be sent.
            Tokens will be sent as data-only server-sent events (SSE) as
            they become available, with the stream terminated by a
            data: [DONE] message. (default: :obj:`False`)
        stop (Union[str, list[str]], optional): Up to 16 sequences where
            the API will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens that can
            be generated in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on whether they
            appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on their existing
            frequency in the text so far, decreasing the model's likelihood
            to repeat the same line verbatim. (default: :obj:`0`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use
            this to provide a list of functions the model may generate JSON
            inputs for. A max of 128 functions are supported.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. "none" means the model
            will not call any tool and instead generates a message. "auto"
            means the model can pick between generating a message or calling
            one or more tools. "required" means the model must call one or
            more tools. Specifying a particular tool via
            {"type": "function", "function": {"name": "my_function"}} forces
            the model to call that tool. "none" is the default when no tools
            are present. "auto" is the default if tools are present.
            (default: :obj:`"auto"`)
        logprobs (bool, optional): Whether to return log probabilities of
            the output tokens or not. If true, returns the log probabilities
            of each output token returned in the content of message.
            (default: :obj:`False`)
        top_logprobs (int, optional): An integer between 0 and 20 specifying
            the number of most likely tokens to return at each token
            position, each with an associated log probability. logprobs
            must be set to true if this parameter is used.
            (default: :obj:`None`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`. (default:
            :obj:`True`)
    """

    temperature: float = 0.2  # deepseek default: 1.0
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    tool_choice: Optional[Union[dict[str, str], str]] = None
    logprobs: bool = False
    top_logprobs: Optional[int] = None

    def __init__(self, include_usage: bool = True, **kwargs):
        super().__init__(**kwargs)
        # Only set stream_options when stream is True
        # Otherwise, it will raise error when calling the API
        # NOTE(review): ``stream_options`` is not a declared field, and
        # BaseConfig's model_config sets ``frozen=True`` / ``extra="forbid"``;
        # pydantic would normally reject this assignment — confirm whether
        # this branch can actually execute successfully.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
            configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        # NOTE(review): the schemas collected in ``tools_schema`` above are
        # discarded — "tools" is forced to NOT_GIVEN so it is dropped from
        # the request. If that is intentional, the validation loop is dead
        # work; otherwise this likely should assign ``tools_schema``. Confirm.
        config_dict["tools"] = NOT_GIVEN
        return config_dict


# The set of parameter names the DeepSeek config accepts.
DEEPSEEK_API_PARAMS = {param for param in DeepSeekConfig.model_fields.keys()}

View File

@@ -1,114 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class GeminiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Gemini API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str]] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
            configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        # NOTE(review): ``tools_schema`` is built above but discarded —
        # "tools" is forced to NOT_GIVEN so it is dropped from the request.
        # If that is intentional, the validation loop is dead work; otherwise
        # this likely should assign ``tools_schema``. Confirm.
        config_dict["tools"] = NOT_GIVEN
        return config_dict


# The set of parameter names the Gemini config accepts.
Gemini_API_PARAMS = {param for param in GeminiConfig.model_fields.keys()}

View File

@@ -1,104 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class GroqConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.

    Reference: https://console.groq.com/docs/openai

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    user: str = ""
    # NOTE(review): the docstring says "none"/"auto" depends on whether tools
    # are present, but the field hard-codes "auto" unconditionally — confirm.
    tool_choice: Optional[Union[dict[str, str], str]] = "auto"


# The set of parameter names the Groq config accepts.
GROQ_API_PARAMS = {param for param in GroqConfig.model_fields.keys()}

View File

@@ -1,97 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional, Union
from camel.configs.base_config import BaseConfig
class LiteLLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    LiteLLM API.

    Args:
        timeout (Optional[Union[float, str]], optional): Request timeout.
            (default: None)
        temperature (Optional[float], optional): Temperature parameter for
            controlling randomness. (default: None)
        top_p (Optional[float], optional): Top-p parameter for nucleus
            sampling. (default: None)
        n (Optional[int], optional): Number of completions to generate.
            (default: None)
        stream (Optional[bool], optional): Whether to return a streaming
            response. (default: None)
        stream_options (Optional[dict], optional): Options for the streaming
            response. (default: None)
        stop (Optional[Union[str, List[str]]], optional): Sequences where the
            API will stop generating further tokens. (default: None)
        max_tokens (Optional[int], optional): Maximum number of tokens to
            generate. (default: None)
        presence_penalty (Optional[float], optional): Penalize new tokens
            based on their existence in the text so far. (default: None)
        frequency_penalty (Optional[float], optional): Penalize new tokens
            based on their frequency in the text so far. (default: None)
        logit_bias (Optional[dict], optional): Modify the probability of
            specific tokens appearing in the completion. (default: None)
        user (Optional[str], optional): A unique identifier representing the
            end-user. (default: None)
        response_format (Optional[dict], optional): Response format
            parameters. (default: None)
        seed (Optional[int], optional): Random seed. (default: None)
        tools (Optional[List], optional): List of tools. (default: None)
        tool_choice (Optional[Union[str, dict]], optional): Tool choice
            parameters. (default: None)
        logprobs (Optional[bool], optional): Whether to return log
            probabilities of the output tokens. (default: None)
        top_logprobs (Optional[int], optional): Number of most likely tokens
            to return at each token position. (default: None)
        deployment_id (Optional[str], optional): Deployment ID. (default: None)
        extra_headers (Optional[dict], optional): Additional headers for the
            request. (default: None)
        api_version (Optional[str], optional): API version. (default: None)
        mock_response (Optional[str], optional): Mock completion response for
            testing or debugging. (default: None)
        custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
            provider. (default: None)
        max_retries (Optional[int], optional): Maximum number of retries.
            (default: None)
    """

    # ``tools`` is documented above but not declared here: the field is
    # inherited from ``BaseConfig``.
    timeout: Optional[Union[float, str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stream_options: Optional[dict] = None
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    user: Optional[str] = None
    response_format: Optional[dict] = None
    seed: Optional[int] = None
    tool_choice: Optional[Union[str, dict]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    deployment_id: Optional[str] = None
    extra_headers: Optional[dict] = None
    api_version: Optional[str] = None
    mock_response: Optional[str] = None
    custom_llm_provider: Optional[str] = None
    max_retries: Optional[int] = None


# The set of parameter names the LiteLLM config accepts.
LITELLM_API_PARAMS = {param for param in LiteLLMConfig.model_fields.keys()}

View File

@@ -1,79 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Dict, Optional, Union
from pydantic import field_validator
from camel.configs.base_config import BaseConfig
class MistralConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Mistral API.

    reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195

    #TODO: Support stream mode

    Args:
        temperature (Optional[float], optional): the temperature to use for
            sampling, e.g. 0.5. Defaults to None.
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. Defaults to None.
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. Defaults to None.
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list.
        random_seed (Optional[int], optional): the random seed to use for
            sampling, e.g. 42. Defaults to None.
        safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
            Defaults to False.
        response_format (Union[Dict[str, str], ResponseFormat): format of the
            response.
        tool_choice (str, optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"any"` means the
            model must call one or more tools. :obj:`"auto"` is the default
            value.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    random_seed: Optional[int] = None
    safe_prompt: bool = False
    response_format: Optional[Union[Dict[str, str], Any]] = None
    tool_choice: Optional[str] = "auto"

    @field_validator("response_format", mode="before")
    @classmethod
    def fields_type_checking(cls, response_format):
        r"""Validate that ``response_format`` is either a dict or a
        ``mistralai.models.ResponseFormat`` instance.

        Raises:
            ValueError: If ``response_format`` is neither a dict nor a
                ``ResponseFormat`` instance.
        """
        if response_format and not isinstance(response_format, dict):
            # Imported lazily so the mistralai package is only required when
            # a non-dict response_format is actually supplied.
            from mistralai.models import ResponseFormat

            if not isinstance(response_format, ResponseFormat):
                # Fix: the original message said "The tool ..." — a
                # copy-paste from a tools validator. This validates
                # response_format, so say so.
                raise ValueError(
                    f"The response_format {response_format} should be an "
                    "instance of `mistralai.models.ResponseFormat`."
                )
        return response_format
# Names of all Mistral request parameters accepted by MistralConfig.
# Fix: read ``model_fields`` from the class rather than a throwaway
# instance — instance access is deprecated in pydantic >= 2.11, and the
# sibling constants (OPENAI_API_PARAMS etc.) already use class access.
MISTRAL_API_PARAMS = set(MistralConfig.model_fields)

View File

@@ -1,70 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class NvidiaConfig(BaseConfig):
    r"""Configuration class for NVIDIA API models.

    Collects the generation parameters accepted by NVIDIA's language
    models: sampling controls, repetition penalties, token limits, tool
    settings and stop sequences.

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls randomness in the response.
            Higher values make output more random, lower values make it more
            deterministic. Range: [0.0, 2.0]. (default: :obj:`0.7`)
        top_p (float, optional): Controls diversity via nucleus sampling.
            Range: [0.0, 1.0]. (default: :obj:`0.95`)
        presence_penalty (float, optional): Penalizes new tokens based on
            whether they appear in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        frequency_penalty (float, optional): Penalizes new tokens based on
            their frequency in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        max_tokens (Union[int, NotGiven], optional): Maximum number of tokens
            to generate. If not provided, the model uses its default maximum.
            (default: :obj:`NOT_GIVEN`)
        seed (Optional[int], optional): Random seed for deterministic
            sampling. (default: :obj:`None`)
        tools (Optional[List[Dict]], optional): List of tools available to
            the model, such as a text editor, a calculator, or a search
            engine. (default: :obj:`None`)
        tool_choice (Optional[str], optional): Tool choice configuration.
            (default: :obj:`None`)
        stop (Optional[List[str]], optional): List of stop sequences.
            (default: :obj:`None`)
    """

    # Plain defaults are equivalent to ``Field(default=...)`` in pydantic v2.
    stream: bool = False
    temperature: float = 0.7
    top_p: float = 0.95
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    seed: Optional[int] = None
    tool_choice: Optional[str] = None
    stop: Optional[List[str]] = None
# Names of all NVIDIA request parameters accepted by NvidiaConfig.
NVIDIA_API_PARAMS = set(NvidiaConfig.model_fields)

View File

@@ -1,82 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Sequence, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class OllamaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility
    Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
    """
    # Sampling controls (see Args above for value ranges).
    temperature: float = 0.2
    top_p: float = 1.0
    # Streaming and termination controls. NOT_GIVEN fields are omitted from
    # the request when not explicitly set by the caller.
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    # Repetition penalties and output-format selection.
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
# Names of all Ollama request parameters accepted by OllamaConfig.
OLLAMA_API_PARAMS = set(OllamaConfig.model_fields)

View File

@@ -1,139 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, Union
from pydantic import BaseModel, Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class ChatGPTConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between:obj:` -1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """
    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    # NOT_GIVEN fields are distinguishable from an explicit None and are
    # omitted from the request when the caller does not set them.
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    # default_factory avoids a shared mutable default across instances.
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None
    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.
        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.
        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            # Imported lazily — presumably to avoid a circular import
            # between configs and toolkits (TODO confirm).
            from camel.toolkits import FunctionTool
            tools_schema = []
            for tool in self.tools:
                # Every entry must be a FunctionTool; anything else is a
                # caller error, surfaced immediately.
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
            # NOTE(review): tools_schema is built (validating each tool) but
            # never used — "tools" is then forced to NOT_GIVEN, so the
            # schemas are discarded. This reads as validation-only or as a
            # leftover; confirm whether tools_schema was meant to be
            # assigned here.
            config_dict["tools"] = NOT_GIVEN
        return config_dict
# Names of all OpenAI request parameters accepted by ChatGPTConfig.
OPENAI_API_PARAMS = set(ChatGPTConfig.model_fields)

View File

@@ -1,91 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import ClassVar, Optional, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class QwenConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Qwen API. You can refer to the following link for more details:
    https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api
    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`0.3`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`0.9`)
        presence_penalty (float, optional): Controls the repetition of
            content in the generated results. Positive values reduce the
            repetition of content, while negative values increase it.
            (default: :obj:`0.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        max_tokens (Union[int, NotGiven], optional): Allows the model to
            generate the maximum number of tokens.
            (default: :obj:`NOT_GIVEN`)
        seed (int, optional): Sets the seed parameter to make the text
            generation process more deterministic, typically used to ensure
            that the results are consistent across model runs. By passing the
            same seed value (specified by you) in each model call while
            keeping other parameters unchanged, the model is likely to return
            the same result.
            (default: :obj:`None`)
        stop (str or list, optional): Using the stop parameter, the model will
            automatically stop generating text when it is about to include the
            specified string or token_id. You can use the stop parameter to
            control the output of the model by passing sensitive words.
            (default: :obj:`None`)
        tools (list, optional): Specifies an array of tools that the model can
            call. It can contain one or more tool objects. During a function
            call process, the model will select one tool from the array.
            (default: :obj:`None`)
        extra_body (dict, optional): Additional parameters to be sent to the
            Qwen API. If you want to enable internet search, you can set this
            parameter to `{"enable_search": True}`.
            (default: :obj:`{"enable_search": False}`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`. (default:
            :obj:`True`)
    """
    stream: bool = False
    temperature: float = 0.3
    top_p: float = 0.9
    presence_penalty: float = 0.0
    # NOTE(review): ClassVar fields are shared, class-level mutable dicts and
    # are excluded from pydantic's per-instance model_fields (so they never
    # appear in QWEN_API_PARAMS). Mutating them mutates every instance —
    # confirm this sharing is intended.
    response_format: ClassVar[dict] = {"type": "text"}
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    seed: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    extra_body: ClassVar[dict] = {"enable_search": False}
    def __init__(self, include_usage: bool = True, **kwargs):
        r"""Initialize the config, consuming ``include_usage`` before pydantic
        sees the remaining keyword arguments.

        Args:
            include_usage (bool, optional): Whether to request usage
                information in streamed responses. Only applied when
                ``stream`` is True. (default: :obj:`True`)
            **kwargs: Forwarded to :class:`BaseConfig`.
        """
        super().__init__(**kwargs)
        # Only set stream_options when stream is True
        # Otherwise, it will raise error when calling the API
        # NOTE(review): stream_options is not a declared field here —
        # presumably BaseConfig permits extra attributes; verify.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}
# Names of all Qwen request parameters accepted by QwenConfig.
QWEN_API_PARAMS = set(QwenConfig.model_fields)

View File

@@ -1,74 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Union
from camel.configs.base_config import BaseConfig
class RekaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Reka API.

    Reference: https://docs.reka.ai/api-reference/chat/create

    Args:
        temperature (Optional[float], optional): the temperature to use for
            sampling, e.g. 0.5. Defaults to None.
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. Defaults to None.
        top_k (Optional[int], optional): forces the model to only consider
            the tokens with the `top_k` highest probabilities at the next
            step. Defaults to 1024.
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. Defaults to None.
        stop (Optional[Union[str,list[str]]]): stop generation if this token
            is detected, or if one of these tokens is detected when providing
            a string list.
        seed (Optional[int], optional): the random seed to use for sampling,
            e.g. 42. Defaults to None.
        presence_penalty (float, optional): number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim.
            (default: :obj:`0.0`)
        use_search_engine (Optional[bool]): whether to consider using a
            search engine to complete the request; even when set to `True`
            the model might decide not to search.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    seed: Optional[int] = None
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    use_search_engine: Optional[bool] = False

    def as_dict(self) -> dict[str, Any]:
        r"""Return the configuration as a dictionary, excluding ``tools``
        (Reka does not support tool calling)."""
        serialized = super().as_dict()
        # Reka does not support tool calling; drop the key if present.
        serialized.pop("tools", None)
        return serialized
# Names of all Reka request parameters accepted by RekaConfig.
# Fix: read ``model_fields`` from the class rather than a throwaway
# instance — instance access is deprecated in pydantic >= 2.11, and the
# sibling constants (OPENAI_API_PARAMS etc.) already use class access.
REKA_API_PARAMS = set(RekaConfig.model_fields)

View File

@@ -1,170 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class SambaVerseAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    SambaVerse API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.7`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`0.95`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token; used to remove "long tail" low probability
            responses. (default: :obj:`50`)
        max_tokens (Optional[int], optional): The maximum number of tokens to
            generate, e.g. 100. (default: :obj:`2048`)
        repetition_penalty (Optional[float], optional): The parameter for
            repetition penalty; 1.0 means no penalty.
            (default: :obj:`1.0`)
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected, or if one of these tokens is detected when providing
            a string list. (default: :obj:`""`)
        stream (Optional[bool]): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            Currently the SambaVerse API doesn't support stream mode.
            (default: :obj:`False`)
    """

    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 50
    max_tokens: Optional[int] = 2048
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[Union[str, list[str]]] = ""
    stream: Optional[bool] = False

    def as_dict(self) -> dict[str, Any]:
        r"""Return the configuration as a dictionary, excluding ``tools``
        (SambaNova does not support tool calling)."""
        serialized = super().as_dict()
        # SambaNova does not support tool calling; drop the key if present.
        serialized.pop("tools", None)
        return serialized
# Names of all SambaVerse request parameters accepted by SambaVerseAPIConfig.
# Fix: read ``model_fields`` from the class rather than a throwaway
# instance — instance access is deprecated in pydantic >= 2.11, and the
# sibling constants (OPENAI_API_PARAMS etc.) already use class access.
SAMBA_VERSE_API_PARAMS = set(SambaVerseAPIConfig.model_fields)
class SambaCloudAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI-compatible SambaNova Cloud API.
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between:obj:` -1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """
    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    # NOT_GIVEN fields are distinguishable from an explicit None and are
    # omitted from the request when the caller does not set them.
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    # default_factory avoids a shared mutable default across instances.
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Names of all SambaNova Cloud request parameters accepted by
# SambaCloudAPIConfig.
# Fix: read ``model_fields`` from the class rather than a throwaway
# instance — instance access is deprecated in pydantic >= 2.11, and the
# sibling constants (OPENAI_API_PARAMS etc.) already use class access.
SAMBA_CLOUD_API_PARAMS = set(SambaCloudAPIConfig.model_fields)

View File

@@ -1,107 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class TogetherAIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Together AI API (OpenAI-compatible).

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between :obj:`-1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the config into a dictionary for the API request.

        Returns:
            dict[str, Any]: The config as a dict, with the inherited
                ``tools`` entry removed since Together AI does not
                support tool calling.
        """
        config_dict = super().as_dict()
        if "tools" in config_dict:
            del config_dict["tools"]  # Currently does not support tool calling
        return config_dict
# Names of all config fields accepted by the Together AI API.
# ``set(...)`` replaces the redundant set comprehension (ruff C401).
TOGETHERAI_API_PARAMS = set(TogetherAIConfig.model_fields.keys())

View File

@@ -1,111 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
# flake8: noqa: E501
class VLLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI-compatible vLLM server.

    Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between :obj:`-1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        logprobs: Whether to return log probabilities of the output tokens or
            not. If true, returns the log probabilities of each output token
            returned in the `logits` of `message`. (default: :obj:`None`)
        top_logprobs: An integer between 0 and 20 specifying the number of
            most likely tokens to return at each token position, each with an
            associated log probability. `logprobs` must be set to `true` if
            this parameter is used. (default: :obj:`None`)
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
# Names of all config fields accepted by the vLLM server API.
# ``set(...)`` replaces the redundant set comprehension (ruff C401).
VLLM_API_PARAMS = set(VLLMConfig.model_fields.keys())

View File

@@ -1,58 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class YiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Yi API. You can refer to the following link for more details:
    https://platform.lingyiwanwu.com/docs/api-reference

    Args:
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` or
            specifying a particular tool via
            {"type": "function", "function": {"name": "some_function"}}
            can be used to guide the model to use tools more strongly.
            (default: :obj:`None`)
        max_tokens (int, optional): Specifies the maximum number of tokens
            the model can generate. This sets an upper limit, but does not
            guarantee that this number will always be reached.
            (default: :obj:`5000`)
        top_p (float, optional): Controls the randomness of the generated
            results. Lower values lead to less randomness, while higher
            values increase randomness. (default: :obj:`0.9`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`0.3`)
        stream (bool, optional): If True, enables streaming output.
            (default: :obj:`False`)
    """

    tool_choice: Optional[Union[dict[str, str], str]] = None
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    top_p: float = 0.9
    temperature: float = 0.3
    stream: bool = False
# Names of all config fields accepted by the Yi API.
# ``set(...)`` replaces the redundant set comprehension (ruff C401).
YI_API_PARAMS = set(YiConfig.model_fields.keys())

View File

@@ -1,71 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class ZhipuAIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.

    Reference: https://open.bigmodel.cn/dev/api#glm-4v

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`0.6`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2
    top_p: float = 0.6
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Names of all config fields accepted by the ZhipuAI API.
# ``set(...)`` replaces the redundant set comprehension (ruff C401).
ZHIPUAI_API_PARAMS = set(ZhipuAIConfig.model_fields.keys())

View File

@@ -1,23 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseDatasetManager
from .huggingface import HuggingFaceDatasetManager
from .models import Record
# Public API of the datahubs package.
__all__ = [
    "BaseDatasetManager",
    "Record",
    "HuggingFaceDatasetManager",
]

View File

@@ -1,136 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any, List
from camel.datahubs.models import Record
class BaseDatasetManager(ABC):
    r"""Abstract base class for dataset managers.

    Concrete subclasses implement dataset CRUD operations against a specific
    backend (e.g. the Hugging Face Hub).
    """

    @abstractmethod
    def create_dataset(self, name: str, **kwargs: Any) -> str:
        r"""Creates a new dataset.

        Args:
            name (str): The name of the dataset.
            kwargs (Any): Additional keyword arguments.

        Returns:
            str: The URL of the created dataset.
        """
        pass

    @abstractmethod
    def list_datasets(
        self, username: str, limit: int = 100, **kwargs: Any
    ) -> List[str]:
        r"""Lists all datasets for the current user.

        Args:
            username (str): The username of the user whose datasets to list.
            limit (int): The maximum number of datasets to list.
                (default: :obj:`100`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[str]: A list of dataset ids.
        """
        pass

    @abstractmethod
    def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
        r"""Deletes a dataset.

        Args:
            dataset_name (str): The name of the dataset to delete.
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def add_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Adds records to a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to add to the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def update_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Updates records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to update in the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def list_records(
        self,
        dataset_name: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> List[Record]:
        r"""Lists records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[Record]: The records stored in the dataset.
        """
        pass

    # New method for record deletion
    @abstractmethod
    def delete_record(
        self,
        dataset_name: str,
        record_id: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Deletes a record from the dataset.

        Args:
            dataset_name (str): The name of the dataset.
            record_id (str): The ID of the record to delete.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

View File

@@ -1,433 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import os
import tempfile
from typing import Any, List, Optional
from camel.datahubs.base import BaseDatasetManager
from camel.datahubs.models import Record
from camel.logger import get_logger
from camel.types import HuggingFaceRepoType
from camel.utils import api_keys_required, dependencies_required
logger = get_logger(__name__)
class HuggingFaceDatasetManager(BaseDatasetManager):
    r"""A dataset manager for Hugging Face datasets. This class provides
    methods to create, add, update, delete, and list records in a dataset on
    the Hugging Face Hub.

    Args:
        token (str): The Hugging Face API token. If not provided, the token
            will be read from the environment variable `HUGGING_FACE_TOKEN`.
    """

    @api_keys_required("HUGGING_FACE_TOKEN")
    @dependencies_required('huggingface_hub')
    def __init__(self, token: Optional[str] = None):
        from huggingface_hub import HfApi

        self._api_key = token or os.getenv("HUGGING_FACE_TOKEN")
        self.api = HfApi(token=self._api_key)

    def create_dataset_card(
        self,
        dataset_name: str,
        description: str,
        license: Optional[str] = None,
        version: Optional[str] = None,
        tags: Optional[List[str]] = None,
        authors: Optional[List[str]] = None,
        size_category: Optional[List[str]] = None,
        language: Optional[List[str]] = None,
        task_categories: Optional[List[str]] = None,
        content: Optional[str] = None,
    ) -> None:
        r"""Creates and uploads a dataset card to the Hugging Face Hub in YAML
        format.

        Args:
            dataset_name (str): The name of the dataset.
            description (str): A description of the dataset.
            license (str): The license of the dataset. (default: :obj:`None`)
            version (str): The version of the dataset. (default: :obj:`None`)
            tags (list): A list of tags for the dataset. (default: :obj:`None`)
            authors (list): A list of authors of the dataset. (default:
                :obj:`None`)
            size_category (list): A size category for the dataset. (default:
                :obj:`None`)
            language (list): A list of languages the dataset is in. (default:
                :obj:`None`)
            task_categories (list): A list of task categories. (default:
                :obj:`None`)
            content (str): Custom markdown content that the user wants to add
                to the dataset card. (default: :obj:`None`)
        """
        import yaml

        metadata = {
            "license": license,
            "authors": authors,
            "task_categories": task_categories,
            "language": language,
            "tags": tags,
            "pretty_name": dataset_name,
            "size_categories": size_category,
            "version": version,
            "description": description,
        }
        # Remove keys with None values
        metadata = {k: v for k, v in metadata.items() if v}
        # The card is a YAML front-matter block, optionally followed by
        # custom markdown supplied by the caller.
        card_content = (
            "---\n"
            + yaml.dump(metadata, default_flow_style=False, allow_unicode=True)
            + "\n---"
        )
        if content:
            card_content += f"\n\n# Additional Information\n{content}\n"
        self._upload_file(
            file_content=card_content,
            dataset_name=dataset_name,
            filepath="README.md",
            file_type="md",
        )

    def create_dataset(
        self, name: str, private: bool = False, **kwargs: Any
    ) -> str:
        r"""Creates a new dataset on the Hugging Face Hub.

        Args:
            name (str): The name of the dataset.
            private (bool): Whether the dataset should be private.
                (default: :obj:`False`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            str: The URL of the created dataset.
        """
        from huggingface_hub.errors import RepositoryNotFoundError

        # Only create the repo if it does not already exist (EAFP: probe
        # via repo_info and catch the not-found error).
        try:
            self.api.repo_info(
                repo_id=name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                **kwargs,
            )
        except RepositoryNotFoundError:
            self.api.create_repo(
                repo_id=name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                private=private,
            )
        return f"https://huggingface.co/datasets/{name}"

    def list_datasets(
        self, username: str, limit: int = 100, **kwargs: Any
    ) -> List[str]:
        r"""Lists all datasets for the current user.

        Args:
            username (str): The username of the user whose datasets to list.
            limit (int): The maximum number of datasets to list.
                (default: :obj:`100`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[str]: A list of dataset ids. Returns an empty list (after
                logging) if the listing request fails.
        """
        try:
            return [
                dataset.id
                for dataset in self.api.list_datasets(
                    author=username, limit=limit, **kwargs
                )
            ]
        except Exception as e:
            # Deliberate best-effort: listing failures are logged, not raised.
            logger.error(f"Error listing datasets: {e}")
            return []

    def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
        r"""Deletes a dataset from the Hugging Face Hub.

        Args:
            dataset_name (str): The name of the dataset to delete.
            kwargs (Any): Additional keyword arguments.
        """
        try:
            self.api.delete_repo(
                repo_id=dataset_name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                **kwargs,
            )
            logger.info(f"Dataset '{dataset_name}' deleted successfully.")
        except Exception as e:
            logger.error(f"Error deleting dataset '{dataset_name}': {e}")
            raise

    def add_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Adds records to a dataset on the Hugging Face Hub.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to add to the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.

        Raises:
            ValueError: If the dataset already has a records file.
        """
        existing_records = self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

        if existing_records:
            raise ValueError(
                f"Dataset '{filepath}' already exists. "
                f"Use `update_records` to modify."
            )

        self._upload_records(
            records=records,
            dataset_name=dataset_name,
            filepath=filepath,
            **kwargs,
        )

    def update_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Updates records in a dataset on the Hugging Face Hub.

        Existing records with the same id are overwritten; records with new
        ids are appended.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to update in the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        existing_records = self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

        if not existing_records:
            # Nothing to merge with: fall back to a plain upload.
            logger.warning(
                f"Dataset '{dataset_name}' does not have existing "
                "records. Adding new records."
            )
            self._upload_records(
                records=records,
                dataset_name=dataset_name,
                filepath=filepath,
                **kwargs,
            )
            return

        # Merge by record id; new records win over old ones.
        old_dict = {record.id: record for record in existing_records}
        new_dict = {record.id: record for record in records}
        merged_dict = old_dict.copy()
        merged_dict.update(new_dict)

        self._upload_records(
            records=list(merged_dict.values()),
            dataset_name=dataset_name,
            filepath=filepath,
            **kwargs,
        )

    def delete_record(
        self,
        dataset_name: str,
        record_id: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Deletes a record from the dataset.

        Args:
            dataset_name (str): The name of the dataset.
            record_id (str): The ID of the record to delete.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.

        Raises:
            ValueError: If the dataset does not have an existing file to delete
                records from.
        """
        existing_records = self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

        if not existing_records:
            raise ValueError(
                f"Dataset '{dataset_name}' does not have an existing file to "
                f"delete records from."
            )

        filtered_records = [
            record for record in existing_records if record.id != record_id
        ]

        self._upload_records(
            records=filtered_records,
            dataset_name=dataset_name,
            filepath=filepath,
            **kwargs,
        )

    def list_records(
        self,
        dataset_name: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> List[Record]:
        r"""Lists all records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[Record]: A list of records in the dataset.
        """
        return self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

    def _download_records(
        self, dataset_name: str, filepath: str, **kwargs: Any
    ) -> List[Record]:
        r"""Downloads and parses the records file from a dataset repo.

        Returns an empty list when the file does not exist yet; re-raises
        any other download/parse error after logging it.
        """
        from huggingface_hub import hf_hub_download
        from huggingface_hub.errors import EntryNotFoundError

        try:
            downloaded_file_path = hf_hub_download(
                repo_id=dataset_name,
                filename=filepath,
                repo_type=HuggingFaceRepoType.DATASET.value,
                token=self._api_key,
                **kwargs,
            )

            with open(downloaded_file_path, "r") as f:
                records_data = json.load(f)

            return [Record(**record) for record in records_data]
        except EntryNotFoundError:
            logger.info(f"No records found for dataset '{dataset_name}'.")
            return []
        except Exception as e:
            logger.error(f"Error downloading or processing records: {e}")
            # Bare re-raise preserves the original traceback.
            raise

    def _upload_records(
        self,
        records: List[Record],
        dataset_name: str,
        filepath: str,
        **kwargs: Any,
    ):
        r"""Serializes records to JSON and uploads them to the dataset repo.

        The records are written to a local temporary file which is always
        removed afterwards, whether the upload succeeds or fails.
        """
        with tempfile.NamedTemporaryFile(
            delete=False, mode="w", newline="", encoding="utf-8"
        ) as f:
            json.dump([record.model_dump() for record in records], f)
            temp_file_path = f.name

        try:
            self.api.upload_file(
                path_or_fileobj=temp_file_path,
                path_in_repo=filepath,
                repo_id=dataset_name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                **kwargs,
            )
        except Exception as e:
            logger.error(f"Error uploading records file: {e}")
            raise
        finally:
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

    def _upload_file(
        self,
        file_content: str,
        dataset_name: str,
        filepath: str,
        file_type: str = "json",
        **kwargs: Any,
    ):
        r"""Writes content to a temporary file and uploads it to the repo.

        Args:
            file_content (str): The content to upload. For ``file_type``
                ``"json"`` this may be a JSON string or a JSON-serializable
                object; for ``"md"``/``"txt"`` it is written verbatim.
            dataset_name (str): The dataset repo id.
            filepath (str): Destination path inside the repo.
            file_type (str): One of ``"json"``, ``"md"`` or ``"txt"``.
                (default: :obj:`"json"`)
            kwargs (Any): Additional keyword arguments.

        Raises:
            ValueError: If the content is not valid for the given file type,
                or the file type is unsupported.
        """
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=f".{file_type}"
        ) as f:
            if file_type == "json":
                if isinstance(file_content, str):
                    try:
                        json_content = json.loads(file_content)
                    except json.JSONDecodeError:
                        raise ValueError(
                            "Invalid JSON string provided for file_content."
                        )
                else:
                    try:
                        json.dumps(file_content)
                        json_content = file_content
                    except (TypeError, ValueError):
                        raise ValueError(
                            "file_content is not JSON serializable."
                        )
                json.dump(json_content, f)
            elif file_type == "md" or file_type == "txt":
                f.write(file_content)
            else:
                raise ValueError(f"Unsupported file type: {file_type}")
            temp_file_path = f.name

        try:
            self.api.upload_file(
                path_or_fileobj=temp_file_path,
                path_in_repo=filepath,
                repo_id=dataset_name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                **kwargs,
            )
            logger.info(f"File uploaded successfully: {filepath}")
        except Exception as e:
            logger.error(f"Error uploading file: {e}")
            raise
        finally:
            # Fix: cleanup previously ran only on the success path, leaking
            # the temp file whenever the upload raised. ``finally`` makes the
            # cleanup unconditional, matching ``_upload_records``.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

View File

@@ -1,22 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, Optional
from pydantic import BaseModel
class Record(BaseModel):
    r"""A single dataset record exchanged with a dataset manager."""

    # Unique identifier of the record; may be None until one is assigned.
    id: Optional[str] = None
    # Optional arbitrary metadata attached to the record.
    metadata: Optional[Dict[str, Any]] = None
    # The record payload itself (required).
    content: Dict[str, Any]

View File

@@ -1,28 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseEmbedding
from .mistral_embedding import MistralEmbedding
from .openai_compatible_embedding import OpenAICompatibleEmbedding
from .openai_embedding import OpenAIEmbedding
from .sentence_transformers_embeddings import SentenceTransformerEncoder
from .vlm_embedding import VisionLanguageEmbedding
# Public embedding implementations re-exported by this package.
__all__ = [
    "BaseEmbedding",
    "OpenAIEmbedding",
    "SentenceTransformerEncoder",
    "VisionLanguageEmbedding",
    "MistralEmbedding",
    "OpenAICompatibleEmbedding",
]

View File

@@ -1,67 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
T = TypeVar('T')


class BaseEmbedding(ABC, Generic[T]):
    r"""Abstract base class for text embedding functionalities."""

    @abstractmethod
    def embed_list(
        self,
        objs: list[T],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Embed a batch of objects.

        Args:
            objs (list[T]): The objects for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: One embedding vector (a list of
                floating-point numbers) per input object, in input order.
        """
        ...

    def embed(
        self,
        obj: T,
        **kwargs: Any,
    ) -> list[float]:
        r"""Embed a single object by delegating to :meth:`embed_list`.

        Args:
            obj (T): The object for which to generate the embedding.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[float]: The embedding vector for ``obj``.
        """
        batch_result = self.embed_list([obj], **kwargs)
        return batch_result[0]

    @abstractmethod
    def get_output_dim(self) -> int:
        r"""Report the dimensionality of the produced embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        ...

View File

@@ -1,89 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
import os
from typing import Any
from camel.embeddings.base import BaseEmbedding
from camel.types import EmbeddingModelType
from camel.utils import api_keys_required
class MistralEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities using Mistral's models.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be
            used for text embeddings.
            (default: :obj:`MISTRAL_EMBED`)
        api_key (str, optional): The API key for authenticating with the
            Mistral service. (default: :obj:`None`)
        dimensions (int, optional): The text embedding output dimensions.
            (default: :obj:`None`)

    Raises:
        RuntimeError: If an unsupported model type is specified.
    """

    def __init__(
        self,
        model_type: EmbeddingModelType = (EmbeddingModelType.MISTRAL_EMBED),
        api_key: str | None = None,
        dimensions: int | None = None,
    ) -> None:
        # Imported lazily so ``mistralai`` is only required when this
        # backend is actually instantiated.
        from mistralai import Mistral

        if not model_type.is_mistral:
            raise ValueError("Invalid Mistral embedding model type.")
        self.model_type = model_type
        # An explicit integer overrides the model's native dimensionality.
        if dimensions is not None:
            assert isinstance(dimensions, int)
            self.output_dim = dimensions
        else:
            self.output_dim = model_type.output_dim
        self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
        self._client = Mistral(api_key=self._api_key)

    @api_keys_required("MISTRAL_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Embed a batch of texts with the configured Mistral model.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: One embedding vector (a list of
                floating-point numbers) per input text, in input order.
        """
        # TODO: count tokens
        embedding_response = self._client.embeddings.create(
            inputs=objs,
            model=self.model_type.value,
            **kwargs,
        )
        return [entry.embedding for entry in embedding_response.data]  # type: ignore[misc,union-attr]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim

View File

@@ -1,91 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
import os
from typing import Any, Optional
from openai import OpenAI
from camel.embeddings.base import BaseEmbedding
from camel.utils import api_keys_required
class OpenAICompatibleEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities supporting OpenAI
    compatibility.
    Args:
        model_type (str): The model type to be used for text embeddings.
        api_key (str): The API key for authenticating with the model service.
        url (str): The url to the model service.
    """
    def __init__(
        self,
        model_type: str,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        self.model_type = model_type
        # Unknown until the first successful `embed_list` call; queried
        # lazily because arbitrary OpenAI-compatible servers do not
        # advertise their embedding dimensionality up front.
        self.output_dim: Optional[int] = None
        # NOTE(review): the env var names below misspell "COMPATIBILITY"
        # ("OPENAI_COMPATIBILIY_..."); the same spelling is baked into the
        # `api_keys_required` decorator argument, so it is kept as-is —
        # existing deployments may already set the misspelled variables.
        self._api_key = api_key or os.environ.get(
            "OPENAI_COMPATIBILIY_API_KEY"
        )
        self._url = url or os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL")
        self._client = OpenAI(
            timeout=60,
            max_retries=3,
            api_key=self._api_key,
            base_url=self._url,
        )
    @api_keys_required("OPENAI_COMPATIBILIY_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.
        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.
        Returns:
            list[list[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.
        """
        response = self._client.embeddings.create(
            input=objs,
            model=self.model_type,
            **kwargs,
        )
        # Cache the dimensionality observed in the first result so
        # `get_output_dim` can report it later.
        self.output_dim = len(response.data[0].embedding)
        return [data.embedding for data in response.data]
    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.
        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        # The dimension is only known after a successful API round trip.
        if self.output_dim is None:
            raise ValueError(
                "Output dimension is not yet determined. Call "
                "'embed_list' first."
            )
        return self.output_dim

View File

@@ -1,99 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
import os
from typing import Any
from openai import OpenAI
from camel.embeddings.base import BaseEmbedding
from camel.types import NOT_GIVEN, EmbeddingModelType, NotGiven
from camel.utils import api_keys_required
class OpenAIEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities using OpenAI's models.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be
            used for text embeddings.
            (default: :obj:`TEXT_EMBEDDING_3_SMALL`)
        api_key (str, optional): The API key for authenticating with the
            OpenAI service. (default: :obj:`None`)
        dimensions (int, optional): The text embedding output dimensions.
            (default: :obj:`NOT_GIVEN`)

    Raises:
        RuntimeError: If an unsupported model type is specified.
    """

    def __init__(
        self,
        model_type: EmbeddingModelType = (
            EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
        ),
        api_key: str | None = None,
        dimensions: int | NotGiven = NOT_GIVEN,
    ) -> None:
        # Reject non-OpenAI model types up front.
        if not model_type.is_openai:
            raise ValueError("Invalid OpenAI embedding model type.")
        self.model_type = model_type
        # An explicit integer overrides the model's native dimensionality.
        if dimensions == NOT_GIVEN:
            self.output_dim = model_type.output_dim
        else:
            assert isinstance(dimensions, int)
            self.output_dim = dimensions
        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.client = OpenAI(timeout=60, max_retries=3, api_key=self._api_key)

    @api_keys_required("OPENAI_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Embed a batch of texts with the configured OpenAI model.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: One embedding vector (a list of
                floating-point numbers) per input text, in input order.
        """
        # TODO: count tokens
        # The legacy ada-002 endpoint does not accept a ``dimensions``
        # argument, so it is only forwarded for the newer models.
        if self.model_type == EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
            api_response = self.client.embeddings.create(
                input=objs,
                model=self.model_type.value,
                **kwargs,
            )
            return [item.embedding for item in api_response.data]
        api_response = self.client.embeddings.create(
            input=objs,
            model=self.model_type.value,
            dimensions=self.output_dim,
            **kwargs,
        )
        return [item.embedding for item in api_response.data]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim

View File

@@ -1,80 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any
from numpy import ndarray
from camel.embeddings.base import BaseEmbedding
class SentenceTransformerEncoder(BaseEmbedding[str]):
    r"""Generates text embeddings via the `Sentence Transformers` library.

    References:
        https://www.sbert.net/
    """

    def __init__(
        self,
        model_name: str = "intfloat/e5-large-v2",
        **kwargs,
    ):
        r"""Initialize the encoder with the given transformer model.

        Args:
            model_name (str, optional): The name of the model to use.
                (default: :obj:`intfloat/e5-large-v2`)
            **kwargs (optional): Additional arguments of
                :class:`SentenceTransformer`, such as :obj:`prompts` etc.
        """
        # Imported lazily so the dependency is only required when used.
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(model_name, **kwargs)

    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Embed each text in ``objs`` using the loaded model.

        Args:
            objs (list[str]): The texts for which to generate the
                embeddings.

        Returns:
            list[list[float]]: One L2-normalized embedding (a list of
                floating-point numbers) per input text, in input order.
        """
        if not objs:
            raise ValueError("Input text list is empty")
        encoded = self.model.encode(
            objs, normalize_embeddings=True, **kwargs
        )
        assert isinstance(encoded, ndarray)
        return encoded.tolist()

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embeddings.
        """
        dim = self.model.get_sentence_embedding_dimension()
        assert isinstance(dim, int)
        return dim

View File

@@ -1,149 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional, Union
from PIL import Image
from camel.embeddings import BaseEmbedding
from camel.logger import get_logger
logger = get_logger(__name__)
class VisionLanguageEmbedding(BaseEmbedding[Union[str, Image.Image]]):
    r"""Provides image embedding functionalities using multimodal model.

    Args:
        model_name : The model type to be used for generating embeddings.
            And the default value is: obj:`openai/clip-vit-base-patch32`.

    Raises:
        RuntimeError: If an unsupported model type is specified.
    """

    def __init__(
        self, model_name: str = "openai/clip-vit-base-patch32"
    ) -> None:
        r"""Initializes the: obj: `VisionLanguageEmbedding` class with a
        specified model and return the dimension of embeddings.

        Args:
            model_name (str, optional): The version name of the model to use.
                (default: :obj:`openai/clip-vit-base-patch32`)
        """
        # Imported lazily so ``transformers`` is only required when used.
        from transformers import AutoModel, AutoProcessor

        try:
            self.model = AutoModel.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
        except Exception as e:
            # Chain the original exception so the root cause (e.g. a
            # download or cache failure from ``from_pretrained``) is
            # preserved in the traceback.
            raise RuntimeError(
                f"Failed to load model '{model_name}': {e}"
            ) from e

        # NOTE(review): these "valid kwargs" lists are populated but never
        # consulted by ``embed_list``; they are kept for backward
        # compatibility with code that may inspect them externally.
        self.valid_processor_kwargs = []
        self.valid_model_kwargs = []
        try:
            self.valid_processor_kwargs = (
                self.processor.image_processor._valid_processor_keys
            )
            self.valid_model_kwargs = [
                "pixel_values",
                "return_dict",
                "interpolate_pos_encoding",
            ]
        except Exception:
            # Best-effort: non-CLIP-style processors simply leave the
            # lists empty.
            logger.warning("not typically processor and model structure")
        # Embedding dimensionality; resolved lazily on first use.
        self.dim: Optional[int] = None

    def embed_list(
        self, objs: List[Union[Image.Image, str]], **kwargs: Any
    ) -> List[List[float]]:
        """Generates embeddings for the given images or texts.

        Args:
            objs (List[Image.Image|str]): The list of images or texts for
                which to generate the embeddings.
            **kwargs (Any): May contain ``image_processor_kwargs`` (extra
                kwargs for the image processor), ``tokenizer_kwargs`` (extra
                kwargs for the text tokenizer/processor), and
                ``model_kwargs`` (extra kwargs for the main model).

        Returns:
            List[List[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.

        Raises:
            ValueError: If the input list is empty, an element is neither
                `Image.Image` nor `str`, or the produced embeddings have
                inconsistent dimensionality.
        """
        if not objs:
            raise ValueError("Input objs list is empty.")

        image_processor_kwargs: Optional[dict] = kwargs.get(
            'image_processor_kwargs', {}
        )
        tokenizer_kwargs: Optional[dict] = kwargs.get('tokenizer_kwargs', {})
        model_kwargs: Optional[dict] = kwargs.get('model_kwargs', {})

        result_list = []
        for obj in objs:
            if isinstance(obj, Image.Image):
                image_input = self.processor(
                    images=obj,
                    return_tensors="pt",
                    padding=True,
                    **image_processor_kwargs,
                )
                # squeeze(dim=0) drops the singleton batch dimension.
                image_feature = (
                    self.model.get_image_features(
                        **image_input, **model_kwargs
                    )
                    .squeeze(dim=0)
                    .tolist()
                )
                result_list.append(image_feature)
            elif isinstance(obj, str):
                text_input = self.processor(
                    text=obj,
                    return_tensors="pt",
                    padding=True,
                    **tokenizer_kwargs,
                )
                text_feature = (
                    self.model.get_text_features(**text_input, **model_kwargs)
                    .squeeze(dim=0)
                    .tolist()
                )
                result_list.append(text_feature)
            else:
                raise ValueError("Input type is not image nor text.")

        # Cache the observed dimensionality and verify all entries agree
        # (image and text towers of a mismatched model could differ).
        self.dim = len(result_list[0])
        if any(len(result) != self.dim for result in result_list):
            raise ValueError("Dimensionality is not consistent.")
        return result_list

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        if self.dim is None:
            # Probe the text tower once with a dummy input to discover the
            # embedding size without requiring a prior embed_list call.
            text = 'dimension'
            inputs = self.processor(text=[text], return_tensors="pt")
            self.dim = self.model.get_text_features(**inputs).shape[1]
        return self.dim

View File

@@ -1,375 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Dict, Generator, List, Optional, Set, Tuple
from camel.messages import BaseMessage
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType
class SystemMessageGenerator:
    r"""System message generator for agents.

    Args:
        task_type (TaskType, optional): The task type.
            (default: :obj:`TaskType.AI_SOCIETY`)
        sys_prompts (Optional[Dict[RoleType, str]], optional): The prompts of
            the system messages for each role type. (default: :obj:`None`)
        sys_msg_meta_dict_keys (Optional[Set[str]], optional): The set of keys
            of the meta dictionary used to fill the prompts.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        task_type: TaskType = TaskType.AI_SOCIETY,
        sys_prompts: Optional[Dict[RoleType, str]] = None,
        sys_msg_meta_dict_keys: Optional[Set[str]] = None,
    ) -> None:
        self.sys_prompts: Dict[RoleType, str]
        if sys_prompts is not None:
            # The caller supplied ready-made prompts; use them verbatim.
            self.sys_prompts = sys_prompts
            self.sys_msg_meta_dict_keys = sys_msg_meta_dict_keys or set()
        else:
            # Load one template per supported role and collect the union
            # of every placeholder key across those templates.
            templates = {
                role: PromptTemplateGenerator().get_system_prompt(
                    task_type,
                    role,
                )
                for role in (
                    RoleType.ASSISTANT,
                    RoleType.USER,
                    RoleType.CRITIC,
                    RoleType.EMBODIMENT,
                )
            }
            self.sys_prompts = dict(templates)
            meta_keys: Set[str] = set()
            for template in templates.values():
                meta_keys |= template.key_words
            self.sys_msg_meta_dict_keys = meta_keys
        # Always provide a fallback prompt for the default role.
        if RoleType.DEFAULT not in self.sys_prompts:
            self.sys_prompts[RoleType.DEFAULT] = "You are a helpful assistant."

    def validate_meta_dict_keys(self, meta_dict: Dict[str, str]) -> None:
        r"""Validates the keys of the meta_dict.

        Args:
            meta_dict (Dict[str, str]): The dictionary to validate.
        """
        provided_keys = set(meta_dict.keys())
        if not provided_keys.issubset(self.sys_msg_meta_dict_keys):
            raise ValueError(
                "The keys of the meta_dict should be in "
                f"{self.sys_msg_meta_dict_keys}. "
                f"Got {provided_keys} instead."
            )

    def from_dict(
        self,
        meta_dict: Dict[str, str],
        role_tuple: Tuple[str, RoleType] = ("", RoleType.DEFAULT),
    ) -> BaseMessage:
        r"""Generates a system message from a dictionary.

        Args:
            meta_dict (Dict[str, str]): The dictionary containing the
                information to generate the system message.
            role_tuple (Tuple[str, RoleType], optional): The tuple containing
                the role name and role type. (default: ("", RoleType.DEFAULT))

        Returns:
            BaseMessage: The generated system message.
        """
        self.validate_meta_dict_keys(meta_dict)
        role_name, role_type = role_tuple
        filled_prompt = self.sys_prompts[role_type].format(**meta_dict)
        return BaseMessage(
            role_name=role_name,
            role_type=role_type,
            meta_dict=meta_dict,
            content=filled_prompt,
        )

    def from_dicts(
        self,
        meta_dicts: List[Dict[str, str]],
        role_tuples: List[Tuple[str, RoleType]],
    ) -> List[BaseMessage]:
        r"""Generates a list of system messages from a list of dictionaries.

        Args:
            meta_dicts (List[Dict[str, str]]): A list of dictionaries
                containing the information to generate the system messages.
            role_tuples (List[Tuple[str, RoleType]]): A list of tuples
                containing the role name and role type for each system message.

        Returns:
            List[BaseMessage]: A list of generated system messages.

        Raises:
            ValueError: If the number of meta_dicts and role_tuples are
                different.
        """
        if len(meta_dicts) != len(role_tuples):
            raise ValueError(
                "The number of meta_dicts and role_types should be the same."
            )
        return [
            self.from_dict(meta_dict, role_tuple)
            for meta_dict, role_tuple in zip(meta_dicts, role_tuples)
        ]
class RoleNameGenerator:
    r"""Role name generator for role-playing workers.

    Args:
        assistant_role_names_path (str, optional): The path to the file
            containing the assistant role names.
            (default: :obj:`"data/ai_society/assistant_roles.txt"`)
        user_role_names_path (str, optional): The path to the file
            containing the user role names.
            (default: :obj:`"data/ai_society/user_roles.txt"`)
        assistant_role_names (Optional[List[str]], optional): The list of
            assistant role names. (default: :obj:`None`)
        user_role_names (Optional[List[str]], optional): The list of user role
            names. (default: :obj:`None`)
    """

    def __init__(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
        assistant_role_names: Optional[List[str]] = None,
        user_role_names: Optional[List[str]] = None,
    ) -> None:
        # Explicitly provided lists take precedence; otherwise the names
        # are loaded from the corresponding role files.
        if assistant_role_names is not None:
            self.assistant_role_names = assistant_role_names
        else:
            self.assistant_role_names = self._load_names(
                assistant_role_names_path
            )
        if user_role_names is not None:
            self.user_role_names = user_role_names
        else:
            self.user_role_names = self._load_names(user_role_names_path)

    @staticmethod
    def _load_names(path: str) -> List[str]:
        # Each line looks like "<index> <role name>"; drop the index token.
        with open(path, "r") as handle:
            raw_lines: List[str] = handle.read().splitlines()
        return [" ".join(line.split(" ")[1:]) for line in raw_lines]

    def from_role_files(self) -> Generator[Tuple, None, None]:
        r"""Generate role names from the file.

        Returns:
            Generator[Tuple, None, None]: A generator that yields tuples of
                assistant role names and user role names.
        """
        # Yield the full Cartesian product of assistant and user roles.
        for assistant_name in self.assistant_role_names:
            for user_name in self.user_role_names:
                yield (assistant_name, user_name)
class AISocietyTaskPromptGenerator:
    r"""Task prompt generator for AI society tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`10`)
    """

    def __init__(
        self,
        num_tasks: int = 10,
    ) -> None:
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(
                TaskType.AI_SOCIETY
            )
        )
        self.num_tasks = num_tasks

    # TODO: Return role names for user and assistant with the generator.
    def from_role_files(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks from role files.

        Args:
            assistant_role_names_path (str, optional): The path to the file
                containing the assistant role names.
                (default: :obj:`"data/ai_society/assistant_roles.txt"`)
            user_role_names_path (str, optional): The path to the file
                containing the user role names.
                (default: :obj:`"data/ai_society/user_roles.txt"`)

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        # Build the role pairs from files and reuse the shared formatting
        # logic in `from_role_generator`.
        role_pairs = RoleNameGenerator(
            assistant_role_names_path, user_role_names_path
        ).from_role_files()
        yield from self.from_role_generator(role_pairs)

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks from a role generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        for assistant_role, user_role in role_generator:
            task_prompt = self.generate_tasks_prompt.format(
                assistant_role=assistant_role,
                user_role=user_role,
                num_tasks=self.num_tasks,
            )
            yield (task_prompt, (assistant_role, user_role))
class SingleTxtGenerator:
    r"""Single text generator for role-playing workers.

    Args:
        text_file_path (str): The path to the file containing the text data.
    """

    def __init__(
        self,
        text_file_path: str,
    ) -> None:
        # Each line looks like "<index> <text>"; strip the index token.
        with open(text_file_path, "r") as handle:
            raw_entries: List[str] = handle.read().splitlines()
        self.data_list = [
            " ".join(entry.split(" ")[1:]) for entry in raw_entries
        ]

    def from_role_files(self) -> Generator[str, None, None]:
        r"""Generate text from the file.

        Returns:
            Generator[str, None, None]: A generator that yields the text data.
        """
        yield from self.data_list
class CodeTaskPromptGenerator:
    r"""Code task prompt generator for code tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`50`)
    """

    def __init__(
        self,
        num_tasks: int = 50,
    ) -> None:
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(TaskType.CODE)
        )
        self.num_tasks = num_tasks

    def from_role_files(
        self,
        languages_path: str = "data/code/languages.txt",
        domains_path: str = "data/code/domains.txt",
    ) -> Generator[Tuple[TextPrompt, str, str], None, None]:
        r"""Generate tasks from role files.

        Args:
            languages_path (str, optional): The path to the file containing
                the language names. (default: :obj:`"data/code/languages.txt"`)
            domains_path (str, optional): The path to the file containing
                the domain names. (default: :obj:`"data/code/domains.txt"`)

        Returns:
            Generator[Tuple[TextPrompt, str, str], None, None]: A generator
                that yields tuples of task prompts, language names, and domain
                names.
        """
        # A fresh domain generator is created for every language so the
        # full (language, domain) cross product is produced.
        for language_name in SingleTxtGenerator(
            languages_path
        ).from_role_files():
            for domain_name in SingleTxtGenerator(
                domains_path
            ).from_role_files():
                task_prompt = self.generate_tasks_prompt.format(
                    language=language_name,
                    domain=domain_name,
                    num_tasks=self.num_tasks,
                )
                yield task_prompt, language_name, domain_name

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[str, None, None]:
        r"""Generate tasks from a role generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[str, None, None]: A generator that yields the task
                prompts.
        """
        raise NotImplementedError

View File

@@ -1,138 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, Sequence
from colorama import Fore
from camel.messages import BaseMessage
from camel.responses import ChatAgentResponse
from camel.utils import print_text_animated
class Human:
    r"""A class representing a human user.
    Args:
        name (str): The name of the human user.
            (default: :obj:`"Kill Switch Engineer"`).
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)
    Attributes:
        name (str): The name of the human user.
        logger_color (Any): The color of the menu options displayed to the
            user.
        input_button (str): The text displayed for the input button.
        kill_button (str): The text displayed for the kill button.
        options_dict (Dict[str, str]): A dictionary containing the options
            displayed to the user.
    """
    def __init__(
        self,
        name: str = "Kill Switch Engineer",
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.name = name
        self.logger_color = logger_color
        self.input_button = f"Input by {self.name}."
        self.kill_button = "Stop!!!"
        # Maps the 1-based menu number (as a string) to the option text;
        # rebuilt on every call to `display_options`.
        self.options_dict: Dict[str, str] = dict()
    def display_options(self, messages: Sequence[BaseMessage]) -> None:
        r"""Displays the options to the user.
        Args:
            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
        Returns:
            None
        """
        # One numbered option per proposal message, plus a free-text input
        # option and a kill switch appended at the end.
        options = [message.content for message in messages]
        options.append(self.input_button)
        options.append(self.kill_button)
        print_text_animated(
            self.logger_color + "\n> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for index, option in enumerate(options):
            # \x1b[3m / \x1b[0m render the option text in italics.
            print_text_animated(
                self.logger_color
                + f"\x1b[3mOption {index + 1}:\n{option}\x1b[0m\n"
            )
            self.options_dict[str(index + 1)] = option
    def get_input(self) -> str:
        r"""Gets the input from the user.
        Returns:
            str: The user's input.
        """
        # Re-prompt until the user enters one of the displayed menu numbers.
        while True:
            human_input = input(
                self.logger_color
                + f"Please enter your choice ([1-{len(self.options_dict)}]): "
            )
            print("\n")
            if human_input in self.options_dict:
                break
            print_text_animated(
                self.logger_color + "\n> Invalid choice. Please try again.\n"
            )
        return human_input
    def parse_input(self, human_input: str) -> str:
        r"""Parses the user's validated menu choice into the chosen content.
        Args:
            human_input (str): The user's input.
        Returns:
            content: A `str` object representing the user's input.
        """
        if self.options_dict[human_input] == self.input_button:
            # Free-text option: ask the user to type the message themselves.
            content = input(self.logger_color + "Please enter your message: ")
        elif self.options_dict[human_input] == self.kill_button:
            # Kill switch: terminates the whole process via SystemExit.
            exit(self.logger_color + f"Killed by {self.name}.")
        else:
            content = self.options_dict[human_input]
        return content
    def reduce_step(
        self, messages: Sequence[BaseMessage]
    ) -> ChatAgentResponse:
        r"""Performs one step of the conversation by displaying options to the
        user, getting their input, and parsing their choice.
        Args:
            messages (Sequence[BaseMessage]): A list of BaseMessage objects.
        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object representing the
                user's choice.
        """
        # Template message carrying over role info; the chosen content is
        # filled in via `create_new_instance` below.
        meta_chat_message = BaseMessage(
            role_name=messages[0].role_name,
            role_type=messages[0].role_type,
            meta_dict=messages[0].meta_dict,
            content="",
        )
        self.display_options(messages)
        human_input = self.get_input()
        content = self.parse_input(human_input)
        message = meta_chat_message.create_new_instance(content)
        return ChatAgentResponse(msgs=[message], terminated=False, info={})

View File

@@ -1,29 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseInterpreter
from .docker_interpreter import DockerInterpreter
from .internal_python_interpreter import InternalPythonInterpreter
from .interpreter_error import InterpreterError
from .ipython_interpreter import JupyterKernelInterpreter
from .subprocess_interpreter import SubprocessInterpreter
# Explicit public API of the `camel.interpreters` package; star-imports
# expose exactly these names.
__all__ = [
    'BaseInterpreter',
    'InterpreterError',
    'InternalPythonInterpreter',
    'SubprocessInterpreter',
    'DockerInterpreter',
    'JupyterKernelInterpreter',
]

View File

@@ -1,49 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any, Dict, List
class BaseInterpreter(ABC):
    r"""An abstract base class for code interpreters.

    Concrete subclasses execute code strings of one or more supported
    languages (see :meth:`supported_code_types`) and return the textual
    result of the run.
    """

    @abstractmethod
    def run(self, code: str, code_type: str) -> str:
        r"""Executes the given code based on its type.

        Args:
            code (str): The code to be executed.
            code_type (str): The type of the code, which must be one of the
                types returned by `supported_code_types()`.

        Returns:
            str: The result of the code execution. If the execution fails,
                this should include sufficient information to diagnose and
                correct the issue.

        Raises:
            InterpreterError: If the code execution encounters errors that
                could be resolved by modifying or regenerating the code.
        """
        pass

    @abstractmethod
    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        pass

    @abstractmethod
    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter."""
        pass

View File

@@ -1,245 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import io
import shlex
import tarfile
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
from colorama import Fore
from camel.interpreters.base import BaseInterpreter
from camel.interpreters.interpreter_error import InterpreterError
from camel.logger import get_logger
from camel.utils import is_docker_running
if TYPE_CHECKING:
from docker.models.containers import Container
logger = get_logger(__name__)
class DockerInterpreter(BaseInterpreter):
    r"""A class for executing code files or code strings in a docker container.

    This class handles the execution of code in different scripting languages
    (currently Python and Bash) within a docker container, capturing their
    stdout and stderr streams, and allowing user checking before executing
    code strings.

    Args:
        require_confirm (bool, optional): If `True`, prompt user before
            running code strings for security. Defaults to `True`.
        print_stdout (bool, optional): If `True`, print the standard
            output of the executed code. Defaults to `False`.
        print_stderr (bool, optional): If `True`, print the standard error
            of the executed code. Defaults to `True`.
    """

    # Shell command template used to execute a file of each code type.
    _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python {file_name}",
        "bash": "bash {file_name}",
    }

    # File extension associated with each supported code type.
    _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "py",
        "bash": "sh",
    }

    # Aliases accepted from callers, normalized to a canonical code type.
    _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python",
        "py3": "python",
        "python3": "python",
        "py": "python",
        "shell": "bash",
        "bash": "bash",
        "sh": "bash",
    }

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr
        # Lazy initialization of the container; see `_initialize_if_needed`.
        self._container: Optional[Container] = None

    def __del__(self) -> None:
        r"""Destructor for the DockerInterpreter class.

        This method ensures that the Docker container is removed when the
        interpreter is deleted.
        """
        if self._container is not None:
            self._container.remove(force=True)

    def _initialize_if_needed(self) -> None:
        r"""Start the backing container on first use.

        Raises:
            InterpreterError: If the Docker daemon is not running.
        """
        if self._container is not None:
            return

        if not is_docker_running():
            raise InterpreterError(
                "Docker daemon is not running. Please install/start docker "
                "and try again."
            )

        import docker

        client = docker.from_env()
        # `tail -f /dev/null` keeps the container alive so code can be
        # exec'd into it repeatedly.
        self._container = client.containers.run(
            "python:3.10",
            detach=True,
            name=f"camel-interpreter-{uuid.uuid4()}",
            command="tail -f /dev/null",
        )

    def _create_file_in_container(self, content: str) -> Path:
        r"""Copy ``content`` into a uniquely named file inside the container.

        Args:
            content (str): The text to write into the container.

        Returns:
            Path: The path of the created file inside the container.

        Raises:
            InterpreterError: If the container has not been initialized.
        """
        # Get a random name for the file.
        filename = str(uuid.uuid4())
        # Fix: encode once and size the tar entry by the *byte* length;
        # `len(content)` under-counts for non-ASCII text.
        payload = content.encode('utf-8')
        # Create a tar archive in memory (the docker API transfers files
        # as tar streams).
        tar_stream = io.BytesIO()
        with tarfile.open(fileobj=tar_stream, mode='w') as tar:
            tarinfo = tarfile.TarInfo(name=filename)
            tarinfo.size = len(payload)
            tar.addfile(tarinfo, io.BytesIO(payload))
        tar_stream.seek(0)

        # Copy the tar into the container.
        if self._container is None:
            raise InterpreterError(
                "Container is not initialized. Try running the code again."
            )
        self._container.put_archive("/tmp", tar_stream)
        # Fix: return the actual generated file path; this previously
        # returned a hard-coded placeholder path that never existed.
        return Path(f"/tmp/{filename}")

    def _run_file_in_container(
        self,
        file: Path,
        code_type: str,
    ) -> str:
        r"""Execute a file that already exists inside the container.

        Args:
            file (Path): Path of the file inside the container.
            code_type (str): The (possibly aliased) code type of the file.

        Returns:
            str: Captured stdout, followed by stderr (if any) wrapped in
                ``(stderr: ...)``.

        Raises:
            InterpreterError: If the container has not been initialized.
        """
        code_type = self._check_code_type(code_type)
        commands = shlex.split(
            self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
                file_name=file.as_posix()
            )
        )
        if self._container is None:
            raise InterpreterError(
                "Container is not initialized. Try running the code again."
            )
        stdout, stderr = self._container.exec_run(
            commands,
            demux=True,
        ).output

        if self.print_stdout and stdout:
            print("======stdout======")
            print(Fore.GREEN + stdout.decode() + Fore.RESET)
            print("==================")
        if self.print_stderr and stderr:
            print("======stderr======")
            print(Fore.RED + stderr.decode() + Fore.RESET)
            print("==================")
        exec_result = f"{stdout.decode()}" if stdout else ""
        exec_result += f"(stderr: {stderr.decode()})" if stderr else ""
        return exec_result

    def run(
        self,
        code: str,
        code_type: str,
    ) -> str:
        r"""Executes the given code in the container attached to the
        interpreter, and captures the stdout and stderr streams.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            InterpreterError: If the user declines to run the code, or the
                code type is unsupported, or there is an error in the docker
                API/container.
        """
        import docker.errors

        code_type = self._check_code_type(code_type)

        # Show the code to the user for security checking.
        if self.require_confirm:
            # Fix: the second fragment was not an f-string, so `{code}` was
            # logged literally instead of the code itself.
            logger.info(
                f"The following {code_type} code will run on your "
                f"computer: {code}"
            )
            while True:
                choice = input("Running code? [Y/n]:").lower()
                if choice in ["y", "yes", "ye", ""]:
                    break
                elif choice not in ["no", "n"]:
                    continue
                raise InterpreterError(
                    "Execution halted: User opted not to run the code. "
                    "This choice stops the current operation and any "
                    "further code execution."
                )

        self._initialize_if_needed()

        try:
            temp_file_path = self._create_file_in_container(code)
            result = self._run_file_in_container(temp_file_path, code_type)
        except docker.errors.APIError as e:
            raise InterpreterError(
                f"Execution halted due to docker API error: {e.explanation}. "
                "This choice stops the current operation and any "
                "further code execution."
            ) from e
        except docker.errors.DockerException as e:
            # Fix: "exceptoin" typo in the user-facing error message.
            raise InterpreterError(
                f"Execution halted due to docker exception: {e}. "
                "This choice stops the current operation and any "
                "further code execution."
            ) from e
        return result

    def _check_code_type(self, code_type: str) -> str:
        r"""Normalize a code-type alias to its canonical form.

        Raises:
            InterpreterError: If the alias is not supported.
        """
        if code_type not in self._CODE_TYPE_MAPPING:
            raise InterpreterError(
                f"Unsupported code type {code_type}. Currently "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_EXTENSION_MAPPING.keys())}."
            )
        return self._CODE_TYPE_MAPPING[code_type]

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return list(self._CODE_EXTENSION_MAPPING.keys())

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter.

        Raises:
            RuntimeError: Always; this interpreter has no action space.
        """
        # Fix: the message previously named `SubprocessInterpreter`.
        raise RuntimeError(
            f"{self.__class__.__name__} doesn't support `action_space`."
        )

View File

@@ -1,516 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import ast
import difflib
import importlib
import typing
from typing import Any, ClassVar, Dict, List, Optional
from camel.interpreters.base import BaseInterpreter
from camel.interpreters.interpreter_error import InterpreterError
class InternalPythonInterpreter(BaseInterpreter):
    r"""A customized python interpreter to control the execution of
    LLM-generated codes. The interpreter makes sure the code can only execute
    functions given in action space and import white list. It also supports
    fuzzy variable matching to retrieve uncertain input variable name.

    .. highlight:: none

    This class is adapted from the hugging face implementation
    `python_interpreter.py <https://github.com/huggingface/transformers/blob/8f
    093fb799246f7dd9104ff44728da0c53a9f67a/src/transformers/tools/python_interp
    reter.py>`_. The original license applies::

        Copyright 2023 The HuggingFace Inc. team. All rights reserved.

        Licensed under the Apache License, Version 2.0 (the "License");
        you may not use this file except in compliance with the License.
        You may obtain a copy of the License at

            http://www.apache.org/licenses/LICENSE-2.0

        Unless required by applicable law or agreed to in writing, software
        distributed under the License is distributed on an "AS IS" BASIS,
        WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
        implied. See the License for the specific language governing
        permissions and limitations under the License.

    We have modified the original code to suit our requirements. We have
    encapsulated the original functions within a class and saved the
    interpreter state after execution. We have added support for "import"
    statements, "for" statements, and several binary and unary operators. We
    have added import white list to keep `import` statement safe. Additionally,
    we have modified the variable matching logic and introduced the
    :obj:`fuzz_state` for fuzzy matching.

    Modifications copyright (C) 2023 CAMEL-AI.org

    Args:
        action_space (Dict[str, Any], optional): A dictionary that maps action
            names to their corresponding functions or objects. The interpreter
            can only execute functions that are either directly listed in this
            dictionary or are member functions of objects listed in this
            dictionary. The concept of :obj:`action_space` is derived from
            EmbodiedAgent, representing the actions that an agent is capable of
            performing. If `None`, set to empty dict. (default: :obj:`None`)
        import_white_list (List[str], optional): A list that stores
            the Python modules or functions that can be imported in the code.
            All submodules and functions of the modules listed in this list are
            importable. Any other import statements will be rejected. The
            module and its submodule or function name are separated by a period
            (:obj:`.`). (default: :obj:`None`)
        unsafe_mode (bool, optional): If `True`, the interpreter runs the code
            by `eval()` without any security check. (default: :obj:`False`)
        raise_error (bool, optional): Raise error if the interpreter fails.
            (default: :obj:`False`)
    """

    _CODE_TYPES: ClassVar[List[str]] = ["python", "py", "python3", "python2"]

    def __init__(
        self,
        action_space: Optional[Dict[str, Any]] = None,
        import_white_list: Optional[List[str]] = None,
        unsafe_mode: bool = False,
        raise_error: bool = False,
    ) -> None:
        self.action_space = action_space or dict()
        # `state` holds named variables visible to the executed code; it is
        # seeded from the action space and mutated by assignments/imports.
        self.state = self.action_space.copy()
        # `fuzz_state` holds variables accessible via fuzzy name matching.
        self.fuzz_state: Dict[str, Any] = dict()
        self.import_white_list = import_white_list or list()
        self.raise_error = raise_error
        self.unsafe_mode = unsafe_mode

    def run(self, code: str, code_type: str) -> str:
        r"""Executes the given code with specified code type in the
        interpreter.

        This method takes a string of code and its type, checks if the code
        type is supported, and then executes the code. If `unsafe_mode` is
        set to `False`, the code is executed in a controlled environment using
        the `execute` method. If `unsafe_mode` is `True`, the code is executed
        using `eval()` with the action space as the global context. An
        `InterpreterError` is raised if the code type is unsupported or if any
        runtime error occurs during execution.

        Args:
            code (str): The python code to be executed.
            code_type (str): The type of the code, which should be one of the
                supported code types (`python`, `py`, `python3`, `python2`).

        Returns:
            str: The string representation of the output of the executed code.

        Raises:
            InterpreterError: If the `code_type` is not supported or if any
                runtime error occurs during the execution of the code.
        """
        if code_type not in self._CODE_TYPES:
            raise InterpreterError(
                f"Unsupported code type {code_type}. "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_TYPES)}."
            )
        if not self.unsafe_mode:
            return str(self.execute(code))
        else:
            # Security: `eval` runs arbitrary code with no sandboxing; this
            # path is only reached when the caller explicitly opted into
            # `unsafe_mode`.
            return str(eval(code, self.action_space))

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter."""
        self.action_space.update(action_space)

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return self._CODE_TYPES

    def execute(
        self,
        code: str,
        state: Optional[Dict[str, Any]] = None,
        fuzz_state: Optional[Dict[str, Any]] = None,
        keep_state: bool = True,
    ) -> Any:
        r"""Execute the input python codes in a security environment.

        Args:
            code (str): Generated python code to be executed.
            state (Optional[Dict[str, Any]], optional): External variables that
                may be used in the generated code. (default: :obj:`None`)
            fuzz_state (Optional[Dict[str, Any]], optional): External variables
                that do not have certain variable names. The interpreter will
                use fuzzy matching to access these variables. For example, if
                :obj:`fuzz_state` has a variable :obj:`image`, the generated
                code can use :obj:`input_image` to access it. (default:
                :obj:`None`)
            keep_state (bool, optional): If :obj:`True`, :obj:`state` and
                :obj:`fuzz_state` will be kept for later execution. Otherwise,
                they will be cleared. (default: :obj:`True`)

        Returns:
            Any: The value of the last statement (excluding "import") in the
                code. For this interpreter, the value of an expression is its
                value, the value of an "assign" statement is the assigned
                value, and the value of an "if" and "for" block statement is
                the value of the last statement in the block.
        """
        if state is not None:
            self.state.update(state)
        if fuzz_state is not None:
            self.fuzz_state.update(fuzz_state)

        try:
            expression = ast.parse(code)
        except SyntaxError as e:
            if self.raise_error:
                raise InterpreterError(f"Syntax error in code: {e}")
            else:
                import traceback

                return traceback.format_exc()

        result = None
        for idx, node in enumerate(expression.body):
            try:
                line_result = self._execute_ast(node)
            except InterpreterError as e:
                if not keep_state:
                    self.clear_state()
                msg = (
                    f"Evaluation of the code stopped at node {idx}. "
                    f"See:\n{e}"
                )
                # More information can be provided by `ast.unparse()`,
                # which is new in python 3.9.
                if self.raise_error:
                    raise InterpreterError(msg)
                else:
                    import traceback

                    return traceback.format_exc()
            if line_result is not None:
                result = line_result

        if not keep_state:
            self.clear_state()

        return result

    def clear_state(self) -> None:
        r"""Initialize :obj:`state` and :obj:`fuzz_state`."""
        self.state = self.action_space.copy()
        self.fuzz_state = {}

    # ast.Index is deprecated after python 3.9, which cannot pass type check,
    # but is still necessary for older versions.
    @typing.no_type_check
    def _execute_ast(self, expression: ast.AST) -> Any:
        if isinstance(expression, ast.Assign):
            # Assignment -> evaluate the assignment which should
            # update the state. We return the variable assigned as it may
            # be used to determine the final result.
            return self._execute_assign(expression)
        elif isinstance(expression, ast.Attribute):
            value = self._execute_ast(expression.value)
            return getattr(value, expression.attr)
        elif isinstance(expression, ast.BinOp):
            # Binary Operator -> return the result value
            return self._execute_binop(expression)
        elif isinstance(expression, ast.Call):
            # Function call -> return the value of the function call
            return self._execute_call(expression)
        elif isinstance(expression, ast.Compare):
            # Compare -> return True or False
            return self._execute_condition(expression)
        elif isinstance(expression, ast.Constant):
            # Constant -> just return the value
            return expression.value
        elif isinstance(expression, ast.Dict):
            # Dict -> evaluate all keys and values
            result: Dict = {}
            for k, v in zip(expression.keys, expression.values):
                if k is not None:
                    result[self._execute_ast(k)] = self._execute_ast(v)
                else:
                    # `{**other}` style expansion has a None key.
                    result.update(self._execute_ast(v))
            return result
        elif isinstance(expression, ast.Expr):
            # Expression -> evaluate the content
            return self._execute_ast(expression.value)
        elif isinstance(expression, ast.For):
            return self._execute_for(expression)
        elif isinstance(expression, ast.FormattedValue):
            # Formatted value (part of f-string) -> evaluate the content
            # and return
            return self._execute_ast(expression.value)
        elif isinstance(expression, ast.If):
            # If -> execute the right branch
            return self._execute_if(expression)
        elif isinstance(expression, ast.Import):
            # Import -> add imported names in self.state and return None.
            self._execute_import(expression)
            return None
        elif isinstance(expression, ast.ImportFrom):
            self._execute_import_from(expression)
            return None
        elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
            # cannot pass type check
            return self._execute_ast(expression.value)
        elif isinstance(expression, ast.JoinedStr):
            return "".join(
                [str(self._execute_ast(v)) for v in expression.values]
            )
        elif isinstance(expression, ast.List):
            # List -> evaluate all elements
            return [self._execute_ast(elt) for elt in expression.elts]
        elif isinstance(expression, ast.Name):
            # Name -> pick up the value in the state
            return self._execute_name(expression)
        elif isinstance(expression, ast.Subscript):
            # Subscript -> return the value of the indexing
            return self._execute_subscript(expression)
        elif isinstance(expression, ast.Tuple):
            return tuple([self._execute_ast(elt) for elt in expression.elts])
        elif isinstance(expression, ast.UnaryOp):
            # Unary Operator -> return the result value
            return self._execute_unaryop(expression)
        else:
            # For now we refuse anything else. Let's add things as we need
            # them.
            raise InterpreterError(
                f"{expression.__class__.__name__} is not supported."
            )

    def _execute_assign(self, assign: ast.Assign) -> Any:
        targets = assign.targets
        result = self._execute_ast(assign.value)

        for target in targets:
            self._assign(target, result)
        return result

    def _assign(self, target: ast.expr, value: Any):
        if isinstance(target, ast.Name):
            self.state[target.id] = value
        elif isinstance(target, ast.Tuple):
            if not isinstance(value, tuple):
                # Fix: add the missing space before the type name in the
                # error message ("but gotX" -> "but got X").
                raise InterpreterError(
                    f"Expected type tuple, but got "
                    f"{value.__class__.__name__} instead."
                )
            if len(target.elts) != len(value):
                raise InterpreterError(
                    f"Expected {len(target.elts)} values but got"
                    f" {len(value)}."
                )
            for t, v in zip(target.elts, value):
                self.state[self._execute_ast(t)] = v
        else:
            raise InterpreterError(
                f"Unsupported variable type. Expected "
                f"ast.Name or ast.Tuple, got "
                f"{target.__class__.__name__} instead."
            )

    def _execute_call(self, call: ast.Call) -> Any:
        callable_func = self._execute_ast(call.func)

        # Todo deal with args
        args = [self._execute_ast(arg) for arg in call.args]
        kwargs = {
            keyword.arg: self._execute_ast(keyword.value)
            for keyword in call.keywords
        }
        return callable_func(*args, **kwargs)

    def _execute_subscript(self, subscript: ast.Subscript):
        index = self._execute_ast(subscript.slice)
        value = self._execute_ast(subscript.value)
        if not isinstance(subscript.ctx, ast.Load):
            raise InterpreterError(
                f"{subscript.ctx.__class__.__name__} is not supported for "
                "subscript."
            )
        if isinstance(value, (list, tuple)):
            return value[int(index)]
        if index in value:
            return value[index]
        if isinstance(index, str) and isinstance(value, dict):
            # Fall back to fuzzy key matching on string-keyed dicts.
            close_matches = difflib.get_close_matches(
                index,
                [key for key in list(value.keys()) if isinstance(key, str)],
            )
            if len(close_matches) > 0:
                return value[close_matches[0]]

        raise InterpreterError(f"Could not index {value} with '{index}'.")

    def _execute_name(self, name: ast.Name):
        if isinstance(name.ctx, ast.Store):
            return name.id
        elif isinstance(name.ctx, ast.Load):
            return self._get_value_from_state(name.id)
        else:
            raise InterpreterError(f"{name.ctx} is not supported.")

    def _execute_condition(self, condition: ast.Compare):
        if len(condition.ops) > 1:
            raise InterpreterError(
                "Cannot evaluate conditions with multiple operators"
            )

        left = self._execute_ast(condition.left)
        comparator = condition.ops[0]
        right = self._execute_ast(condition.comparators[0])

        if isinstance(comparator, ast.Eq):
            return left == right
        elif isinstance(comparator, ast.NotEq):
            return left != right
        elif isinstance(comparator, ast.Lt):
            return left < right
        elif isinstance(comparator, ast.LtE):
            return left <= right
        elif isinstance(comparator, ast.Gt):
            return left > right
        elif isinstance(comparator, ast.GtE):
            return left >= right
        elif isinstance(comparator, ast.Is):
            return left is right
        elif isinstance(comparator, ast.IsNot):
            return left is not right
        elif isinstance(comparator, ast.In):
            return left in right
        elif isinstance(comparator, ast.NotIn):
            return left not in right
        else:
            raise InterpreterError(f"Unsupported operator: {comparator}")

    def _execute_if(self, if_statement: ast.If):
        result = None
        if not isinstance(if_statement.test, ast.Compare):
            # Fix: "Campare"/"get" typos in the error message.
            raise InterpreterError(
                "Only Compare expr supported in if statement, got"
                f" {if_statement.test.__class__.__name__}"
            )
        if self._execute_condition(if_statement.test):
            for line in if_statement.body:
                line_result = self._execute_ast(line)
                if line_result is not None:
                    result = line_result
        else:
            for line in if_statement.orelse:
                line_result = self._execute_ast(line)
                if line_result is not None:
                    result = line_result

        return result

    def _execute_for(self, for_statement: ast.For):
        result = None
        for value in self._execute_ast(for_statement.iter):
            self._assign(for_statement.target, value)
            for line in for_statement.body:
                line_result = self._execute_ast(line)
                if line_result is not None:
                    result = line_result

        return result

    def _execute_import(self, import_module: ast.Import) -> None:
        for module in import_module.names:
            self._validate_import(module.name)
            alias = module.asname or module.name
            self.state[alias] = importlib.import_module(module.name)

    def _execute_import_from(self, import_from: ast.ImportFrom):
        if import_from.module is None:
            raise InterpreterError("\"from . import\" is not supported.")
        for import_name in import_from.names:
            full_name = import_from.module + f".{import_name.name}"
            self._validate_import(full_name)
            imported_module = importlib.import_module(import_from.module)
            alias = import_name.asname or import_name.name
            self.state[alias] = getattr(imported_module, import_name.name)

    def _validate_import(self, full_name: str):
        r"""Raise unless some prefix of ``full_name`` is white-listed."""
        # Fix: dropped the dead `found_name` flag (the early `return` made
        # it unreachable) and repaired the broken grammar of the message.
        tmp_name = ""
        for name in full_name.split("."):
            tmp_name += name if tmp_name == "" else f".{name}"
            if tmp_name in self.import_white_list:
                return

        raise InterpreterError(
            f"It is not permitted to import modules other than those in "
            f"the import white list (tried to import {full_name})."
        )

    def _execute_binop(self, binop: ast.BinOp):
        left = self._execute_ast(binop.left)
        operator = binop.op
        right = self._execute_ast(binop.right)

        if isinstance(operator, ast.Add):
            return left + right
        elif isinstance(operator, ast.Sub):
            return left - right
        elif isinstance(operator, ast.Mult):
            return left * right
        elif isinstance(operator, ast.Div):
            return left / right
        elif isinstance(operator, ast.FloorDiv):
            return left // right
        elif isinstance(operator, ast.Mod):
            return left % right
        elif isinstance(operator, ast.Pow):
            return left**right
        elif isinstance(operator, ast.LShift):
            return left << right
        elif isinstance(operator, ast.RShift):
            return left >> right
        elif isinstance(operator, ast.MatMult):
            return left @ right
        else:
            raise InterpreterError(f"Operator not supported: {operator}")

    def _execute_unaryop(self, unaryop: ast.UnaryOp):
        operand = self._execute_ast(unaryop.operand)
        operator = unaryop.op

        if isinstance(operator, ast.UAdd):
            return +operand
        elif isinstance(operator, ast.USub):
            return -operand
        elif isinstance(operator, ast.Not):
            return not operand
        else:
            raise InterpreterError(f"Operator not supported: {operator}")

    def _get_value_from_state(self, key: str) -> Any:
        if key in self.state:
            return self.state[key]
        else:
            # Fuzzy fallback: the closest-named variable in `fuzz_state`.
            close_matches = difflib.get_close_matches(
                key, list(self.fuzz_state.keys()), n=1
            )
            if close_matches:
                return self.fuzz_state[close_matches[0]]
            else:
                raise InterpreterError(f"The variable `{key}` is not defined.")

View File

@@ -1,19 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# TODO: Do we need a file to store this error class?
class InterpreterError(Exception):
    r"""Exception raised for errors that can be solved by regenerating code.

    Interpreters raise this (rather than a generic exception) so that
    callers can catch it and ask for a corrected version of the code.
    """

View File

@@ -1,168 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import queue
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from camel.interpreters.base import BaseInterpreter
from camel.interpreters.interpreter_error import InterpreterError
if TYPE_CHECKING:
from jupyter_client import BlockingKernelClient, KernelManager
TIMEOUT = 30
class JupyterKernelInterpreter(BaseInterpreter):
    r"""A class for executing code strings in a Jupyter Kernel.

    Args:
        require_confirm (bool, optional): If `True`, prompt user before
            running code strings for security. Defaults to `True`.
        print_stdout (bool, optional): If `True`, print the standard
            output of the executed code. Defaults to `False`.
        print_stderr (bool, optional): If `True`, print the standard error
            of the executed code. Defaults to `True`.
    """

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        # NOTE(review): `require_confirm` is stored but never checked before
        # execution in `run` — confirm whether a confirmation prompt was
        # intended here, as in the other interpreters.
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr
        # Both are created lazily on first use; see `_initialize_if_needed`.
        self.kernel_manager: Optional[KernelManager] = None
        self.client: Optional[BlockingKernelClient] = None

    def __del__(self) -> None:
        r"""Clean up the kernel and client."""
        if self.kernel_manager:
            self.kernel_manager.shutdown_kernel()
        if self.client:
            self.client.stop_channels()

    def _initialize_if_needed(self) -> None:
        r"""Initialize the kernel manager and client if they are not already
        initialized.
        """
        if self.kernel_manager is not None:
            return

        from jupyter_client.manager import start_new_kernel

        self.kernel_manager, self.client = start_new_kernel()

    @staticmethod
    def _clean_ipython_output(output: str) -> str:
        r"""Remove ANSI escape sequences from the output."""
        ansi_escape = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]')
        return ansi_escape.sub('', output)

    def _execute(self, code: str, timeout: float) -> str:
        r"""Execute the code in the Jupyter kernel and return the result.

        Args:
            code (str): The code to execute.
            timeout (float): Seconds to wait for each IOPub message before
                reporting a timeout.

        Returns:
            str: Concatenated error/stream/display output, with ANSI escape
                sequences removed.

        Raises:
            InterpreterError: If the Jupyter client is not initialized.
        """
        if not self.kernel_manager or not self.client:
            raise InterpreterError("Jupyter client is not initialized.")

        self.client.execute(code)
        outputs = []
        while True:
            try:
                msg = self.client.get_iopub_msg(timeout=timeout)
                msg_content = msg["content"]
                msg_type = msg.get("msg_type", None)

                # The kernel reports "idle" once execution has finished.
                if msg_content.get("execution_state", None) == "idle":
                    break

                if msg_type == "error":
                    # Fix: removed leftover debug `print` statements that
                    # dumped the raw message content to stdout.
                    traceback = "\n".join(msg_content["traceback"])
                    outputs.append(traceback)
                elif msg_type == "stream":
                    outputs.append(msg_content["text"])
                elif msg_type in ["execute_result", "display_data"]:
                    outputs.append(msg_content["data"]["text/plain"])
                    if "image/png" in msg_content["data"]:
                        outputs.append(
                            f"\n![image](data:image/png;base64,"
                            f"{msg_content['data']['image/png']})\n"
                        )
            except queue.Empty:
                outputs.append("Time out")
                break
            except Exception as e:
                outputs.append(f"Exception occurred: {e!s}")
                break

        exec_result = "\n".join(outputs)
        return self._clean_ipython_output(exec_result)

    def run(self, code: str, code_type: str) -> str:
        r"""Executes the given code in the Jupyter kernel.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured result of the
                executed code.

        Raises:
            InterpreterError: If there is an error when doing code execution.
        """
        self._initialize_if_needed()

        if code_type == "bash":
            # Route shell code through IPython's `%%bash` cell magic.
            code = f"%%bash\n({code})"
        try:
            result = self._execute(code, timeout=TIMEOUT)
        except Exception as e:
            # Fix: chain the original exception so its traceback survives.
            raise InterpreterError(f"Execution failed: {e!s}") from e

        return result

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter.

        Returns:
            List[str]: Supported code types.
        """
        return ["python", "bash"]

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates the action space for the interpreter.

        Args:
            action_space (Dict[str, Any]): A dictionary representing the
                new or updated action space.

        Raises:
            RuntimeError: Always raised because `JupyterKernelInterpreter`
                does not support updating the action space.
        """
        # Fix: the message previously named `SubprocessInterpreter`.
        raise RuntimeError(
            f"{self.__class__.__name__} doesn't support `action_space`."
        )

View File

@@ -1,212 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import shlex
import subprocess
import tempfile
from pathlib import Path
from typing import Any, ClassVar, Dict, List
from colorama import Fore
from camel.interpreters.base import BaseInterpreter
from camel.interpreters.interpreter_error import InterpreterError
from camel.logger import get_logger
import os
logger = get_logger(__name__)
class SubprocessInterpreter(BaseInterpreter):
    r"""SubprocessInterpreter is a class for executing code files or code
    strings in a subprocess.

    This class handles the execution of code in different scripting languages
    (currently Python, Bash, and Node.js) within a subprocess, capturing their
    stdout and stderr streams, and allowing user checking before executing
    code strings.

    Args:
        require_confirm (bool, optional): If True, prompt user before running
            code strings for security. (default: :obj:`True`)
        print_stdout (bool, optional): If True, print the standard output of
            the executed code. (default: :obj:`False`)
        print_stderr (bool, optional): If True, print the standard error of
            the executed code. (default: :obj:`True`)
    """

    # Shell command template used to execute a file of each code type.
    _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python {file_name}",
        "bash": "bash {file_name}",
        "node": "node {file_name}",
    }

    # File extension used when writing a code string to a temporary file.
    _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "py",
        "bash": "sh",
        "node": "js",
    }

    # Aliases accepted from callers, normalized to a canonical code type.
    _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python",
        "py3": "python",
        "python3": "python",
        "py": "python",
        "shell": "bash",
        "bash": "bash",
        "sh": "bash",
        "node": "node",
        "javascript": "node",
        "js": "node",
    }

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr

    def run_file(
        self,
        file: Path,
        code_type: str,
    ) -> str:
        r"""Executes a code file in a subprocess and captures its output.

        Args:
            file (Path): The path object of the file to run.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash', 'node').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            RuntimeError: If the provided file path does not point to a file.
            InterpreterError: If the code type provided is not supported.
        """
        if not file.is_file():
            raise RuntimeError(f"{file} is not a file.")
        code_type = self._check_code_type(code_type)
        cmd = shlex.split(
            self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
                file_name=str(file)
            )
        )
        proc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        stdout, stderr = proc.communicate()
        if self.print_stdout and stdout:
            print("======stdout======")
            print(Fore.GREEN + stdout + Fore.RESET)
            print("==================")
        if self.print_stderr and stderr:
            print("======stderr======")
            print(Fore.RED + stderr + Fore.RESET)
            print("==================")
        exec_result = f"{stdout}"
        exec_result += f"(stderr: {stderr})" if stderr else ""
        return exec_result

    def run(
        self,
        code: str,
        code_type: str,
    ) -> str:
        r"""Generates a temporary file with the given code, executes it, and
        deletes the file afterward.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash', 'node').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            InterpreterError: If the user declines to run the code or if the
                code type is unsupported.
        """
        code_type = self._check_code_type(code_type)
        if self.require_confirm:
            # Fixed: the "{code}" placeholder was previously in a plain
            # (non-f) string fragment and was logged literally.
            logger.info(
                f"The following {code_type} code will run on your "
                f"computer: {code}"
            )
            while True:
                choice = input("Running code? [Y/n]:").lower()
                if choice in ["y", "yes", "ye", ""]:
                    break
                elif choice in ["no", "n"]:
                    raise InterpreterError(
                        "Execution halted: User opted not to run the code. "
                        "This choice stops the current operation and any "
                        "further code execution."
                    )
        temp_file_path = self._create_temp_file(
            code=code, extension=self._CODE_EXTENSION_MAPPING[code_type]
        )
        # Ensure the temporary file is removed even if execution raises.
        try:
            return self.run_file(temp_file_path, code_type)
        finally:
            temp_file_path.unlink()

    def _create_temp_file(self, code: str, extension: str) -> Path:
        r"""Write ``code`` to a new temporary file and return its path.

        The file is created with ``delete=False`` so it survives the
        ``with`` block; the caller is responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=f".{extension}"
        ) as f:
            f.write(code)
            name = f.name
        return Path(name)

    def _check_code_type(self, code_type: str) -> str:
        r"""Normalize a caller-supplied code type to its canonical form.

        Raises:
            InterpreterError: If the code type (or alias) is unknown.
        """
        if code_type not in self._CODE_TYPE_MAPPING:
            raise InterpreterError(
                f"Unsupported code type {code_type}. Currently "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_EXTENSION_MAPPING.keys())}."
            )
        return self._CODE_TYPE_MAPPING[code_type]

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return list(self._CODE_EXTENSION_MAPPING.keys())

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter.

        Raises:
            RuntimeError: Always, since this interpreter has no action
                space concept.
        """
        raise RuntimeError(
            "SubprocessInterpreter doesn't support `action_space`."
        )

View File

@@ -1,29 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .apify_reader import Apify
from .base_io import File
from .chunkr_reader import ChunkrReader
from .firecrawl_reader import Firecrawl
from .jina_url_reader import JinaURLReader
from .unstructured_io import UnstructuredIO
# Public API of this package: loaders/readers re-exported for `import *`
# and for documentation tooling.
__all__ = [
    'File',
    'UnstructuredIO',
    'JinaURLReader',
    'Firecrawl',
    'Apify',
    'ChunkrReader',
]

View File

@@ -1,223 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from apify_client.clients import DatasetClient
from camel.utils import api_keys_required
class Apify:
    r"""Thin wrapper around the Apify client for running actors and
    working with datasets.

    Apify is a platform that allows you to automate any web workflow.

    Args:
        api_key (Optional[str]): API key for authenticating with the Apify
            API. Falls back to the ``APIFY_API_KEY`` environment variable.
    """

    @api_keys_required("APIFY_API_KEY")
    def __init__(
        self,
        api_key: Optional[str] = None,
    ) -> None:
        # Imported lazily so `apify_client` is only required when this
        # class is actually used.
        from apify_client import ApifyClient

        self._api_key = api_key or os.environ.get("APIFY_API_KEY")
        self.client = ApifyClient(token=self._api_key)

    def run_actor(
        self,
        actor_id: str,
        run_input: Optional[dict] = None,
        content_type: Optional[str] = None,
        build: Optional[str] = None,
        max_items: Optional[int] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
        webhooks: Optional[list] = None,
        wait_secs: Optional[int] = None,
    ) -> Optional[dict]:
        r"""Run an actor on the Apify platform.

        Args:
            actor_id (str): The ID of the actor to run.
            run_input (Optional[dict]): The input data for the actor. Defaults
                to `None`.
            content_type (str, optional): The content type of the input.
            build (str, optional): Specifies the Actor build to run. It can be
                either a build tag or build number. By default, the run uses
                the build specified in the default run configuration for the
                Actor (typically latest).
            max_items (int, optional): Maximum number of results that will be
                returned by this run. If the Actor is charged per result, you
                will not be charged for more results than the given limit.
            memory_mbytes (int, optional): Memory limit for the run, in
                megabytes. By default, the run uses a memory limit specified in
                the default run configuration for the Actor.
            timeout_secs (int, optional): Optional timeout for the run, in
                seconds. By default, the run uses timeout specified in the
                default run configuration for the Actor.
            webhooks (list, optional): Optional webhooks
                (https://docs.apify.com/webhooks) associated with the Actor
                run, which can be used to receive a notification, e.g. when the
                Actor finished or failed. If you already have a webhook set up
                for the Actor, you do not have to add it again here.
            wait_secs (int, optional): The maximum number of seconds the server
                waits for finish. If not provided, waits indefinitely.

        Returns:
            Optional[dict]: The output data from the actor if successful.
                Use its ``'defaultDatasetId'`` entry with the dataset
                methods below to retrieve the run's results.

        Raises:
            RuntimeError: If the actor fails to run.
        """
        try:
            return self.client.actor(actor_id).call(
                run_input=run_input,
                content_type=content_type,
                build=build,
                max_items=max_items,
                memory_mbytes=memory_mbytes,
                timeout_secs=timeout_secs,
                webhooks=webhooks,
                wait_secs=wait_secs,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to run actor {actor_id}: {e}") from e

    def get_dataset_client(
        self,
        dataset_id: str,
    ) -> "DatasetClient":
        r"""Get a dataset client from the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to get the client for.

        Returns:
            DatasetClient: The dataset client.

        Raises:
            RuntimeError: If the dataset client fails to be retrieved.
        """
        try:
            return self.client.dataset(dataset_id)
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset {dataset_id}: {e}"
            ) from e

    def get_dataset(
        self,
        dataset_id: str,
    ) -> Optional[dict]:
        r"""Get a dataset from the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to get.

        Returns:
            dict: The dataset.

        Raises:
            RuntimeError: If the dataset fails to be retrieved.
        """
        try:
            return self.get_dataset_client(dataset_id).get()
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset {dataset_id}: {e}"
            ) from e

    def update_dataset(
        self,
        dataset_id: str,
        name: str,
    ) -> dict:
        r"""Update a dataset on the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to update.
            name (str): The new name for the dataset.

        Returns:
            dict: The updated dataset.

        Raises:
            RuntimeError: If the dataset fails to be updated.
        """
        try:
            return self.get_dataset_client(dataset_id).update(name=name)
        except Exception as e:
            raise RuntimeError(
                f"Failed to update dataset {dataset_id}: {e}"
            ) from e

    def get_dataset_items(
        self,
        dataset_id: str,
    ) -> List:
        r"""Get items from a dataset on the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to get items from.

        Returns:
            list: The items in the dataset.

        Raises:
            RuntimeError: If the items fail to be retrieved.
        """
        try:
            items = self.get_dataset_client(dataset_id).list_items().items
            return items
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset items {dataset_id}: {e}"
            ) from e

    def get_datasets(
        self,
        unnamed: Optional[bool] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        desc: Optional[bool] = None,
    ) -> List[dict]:
        r"""Get all named datasets from the Apify platform.

        Args:
            unnamed (bool, optional): Whether to include unnamed key-value
                stores in the list
            limit (int, optional): How many key-value stores to retrieve
            offset (int, optional): What key-value store to include as first
                when retrieving the list
            desc (bool, optional): Whether to sort the key-value stores in
                descending order based on their modification date

        Returns:
            List[dict]: The datasets.

        Raises:
            RuntimeError: If the datasets fail to be retrieved.
        """
        try:
            return (
                self.client.datasets()
                .list(unnamed=unnamed, limit=limit, offset=offset, desc=desc)
                .items
            )
        except Exception as e:
            raise RuntimeError(f"Failed to get datasets: {e}") from e

View File

@@ -1,328 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import re
from abc import ABC, abstractmethod
from copy import deepcopy
from hashlib import md5
from io import BytesIO
from typing import Any, Dict, List, Optional
from camel.utils import dependencies_required
class File(ABC):
    r"""Represents an uploaded file comprised of Documents.

    Args:
        name (str): The name of the file.
        file_id (str): The unique identifier of the file.
        metadata (Dict[str, Any], optional): Additional metadata
            associated with the file. Defaults to None.
        docs (List[Dict[str, Any]], optional): A list of documents
            contained within the file. Defaults to None.
        raw_bytes (bytes, optional): The raw bytes content of the file.
            Defaults to b"".
    """

    def __init__(
        self,
        name: str,
        file_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        docs: Optional[List[Dict[str, Any]]] = None,
        raw_bytes: bytes = b"",
    ):
        self.name = name
        self.file_id = file_id
        # Fall back to fresh containers so instances never share state.
        self.metadata = metadata or {}
        self.docs = docs or []
        self.raw_bytes = raw_bytes

    @classmethod
    @abstractmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "File":
        r"""Creates a File object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                file.
            filename (str): The name of the file.

        Returns:
            File: A File object.
        """
        pass

    @classmethod
    def from_raw_bytes(cls, raw_bytes: bytes, filename: str) -> "File":
        r"""Creates a File object from raw bytes.

        Args:
            raw_bytes (bytes): The raw bytes content of the file.
            filename (str): The name of the file.

        Returns:
            File: A File object.
        """
        # Wrap the bytes in a stream and delegate to the subclass parser.
        buffer = BytesIO(raw_bytes)
        return cls.from_bytes(buffer, filename)

    @staticmethod
    def create_file(file: BytesIO, filename: str) -> "File":
        r"""Reads an uploaded file and returns a File object of the
        matching concrete subclass, chosen by filename extension.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                file.
            filename (str): The name of the file.

        Returns:
            File: A File object.

        Raises:
            NotImplementedError: If no handler exists for the extension.
        """
        handlers: Dict[str, Any] = {
            "docx": DocxFile,
            "pdf": PdfFile,
            "txt": TxtFile,
            "json": JsonFile,
            "html": HtmlFile,
        }
        suffix = filename.split(".")[-1].lower()
        if suffix not in handlers:
            raise NotImplementedError(f"File type {suffix} not supported")
        return handlers[suffix].from_bytes(file, filename)

    @staticmethod
    def create_file_from_raw_bytes(raw_bytes: bytes, filename: str) -> "File":
        r"""Reads raw bytes and returns a File object.

        Args:
            raw_bytes (bytes): The raw bytes content of the file.
            filename (str): The name of the file.

        Returns:
            File: A File object.
        """
        return File.create_file(BytesIO(raw_bytes), filename)

    def __repr__(self) -> str:
        return (
            f"File(name={self.name}, id={self.file_id}, "
            f"metadata={self.metadata}, docs={self.docs})"
        )

    def __str__(self) -> str:
        return (
            f"File(name={self.name}, id={self.file_id}, metadata="
            f"{self.metadata})"
        )

    def copy(self) -> "File":
        r"""Create a deep copy of this File"""
        # Deep-copy the mutable members; raw bytes are immutable and shared.
        cls = type(self)
        return cls(
            name=self.name,
            file_id=self.file_id,
            metadata=deepcopy(self.metadata),
            docs=deepcopy(self.docs),
            raw_bytes=self.raw_bytes,
        )
def strip_consecutive_newlines(text: str) -> str:
    r"""Strips consecutive newlines from a string.

    Every run of whitespace that contains at least one newline is
    collapsed into a single ``"\n"``.

    Args:
        text (str): The string to strip.

    Returns:
        str: The string with consecutive newlines stripped.
    """
    pattern = re.compile(r"\s*\n\s*")
    return pattern.sub("\n", text)
class DocxFile(File):
    # Concrete ``File`` for Microsoft Word (.docx) uploads; text extraction
    # is delegated to the ``docx2txt`` package.
    @classmethod
    @dependencies_required('docx2txt')
    def from_bytes(cls, file: BytesIO, filename: str) -> "DocxFile":
        r"""Creates a DocxFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                docx file.
            filename (str): The name of the file.

        Returns:
            DocxFile: A DocxFile object with a single document containing
                the extracted, newline-normalized text.
        """
        import docx2txt

        text = docx2txt.process(file)
        text = strip_consecutive_newlines(text)
        # Create a dictionary with the extracted text
        doc = {"page_content": text.strip()}
        # Calculate a unique identifier for the file
        file_id = md5(file.getvalue()).hexdigest()
        # Reset the file pointer to the beginning
        file.seek(0)
        return cls(
            name=filename,
            file_id=file_id,
            docs=[doc],
            raw_bytes=file.getvalue(),
        )
class PdfFile(File):
    # Concrete ``File`` for PDF uploads; one document per page, extracted
    # with PyMuPDF (``fitz``).
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "PdfFile":
        r"""Creates a PdfFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                pdf file.
            filename (str): The name of the file.

        Returns:
            PdfFile: A PdfFile object whose ``docs`` hold one entry per
                page, each with a 1-based ``"page"`` number.

        Raises:
            ImportError: If PyMuPDF is not installed.
        """
        # Use fitz to extract text from pdf files
        try:
            import fitz
        except ImportError:
            raise ImportError(
                "Please install `PyMuPDF` first. "
                "You can install it by running "
                "`pip install PyMuPDF`."
            )
        pdf = fitz.open(stream=file.read(), filetype="pdf")
        docs = []
        for i, page in enumerate(pdf):
            # sort=True asks fitz for reading-order text extraction.
            text = page.get_text(sort=True)
            text = strip_consecutive_newlines(text)
            # Create a dictionary with the extracted text
            doc = {"page_content": text.strip(), "page": i + 1}
            docs.append(doc)
        # Calculate a unique identifier for the file
        file_id = md5(file.getvalue()).hexdigest()
        # Reset the file pointer to the beginning
        file.seek(0)
        return cls(
            name=filename,
            file_id=file_id,
            docs=docs,
            raw_bytes=file.getvalue(),
        )
class TxtFile(File):
    # Concrete ``File`` for plain-text uploads.
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "TxtFile":
        r"""Creates a TxtFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                txt file.
            filename (str): The name of the file.

        Returns:
            TxtFile: A TxtFile object.
        """
        # Decode as UTF-8 and collapse whitespace runs around newlines.
        content = strip_consecutive_newlines(file.read().decode("utf-8"))
        document = {"page_content": content.strip()}
        # Fingerprint the raw bytes so identical uploads share an ID.
        digest = md5(file.getvalue()).hexdigest()
        # Leave the stream rewound for any subsequent readers.
        file.seek(0)
        return cls(
            name=filename,
            file_id=digest,
            docs=[document],
            raw_bytes=file.getvalue(),
        )
class JsonFile(File):
    # Concrete ``File`` for JSON uploads.
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "JsonFile":
        r"""Creates a JsonFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                json file.
            filename (str): The name of the file.

        Returns:
            JsonFile: A JsonFile object.
        """
        # Parse first so malformed JSON fails fast, then re-serialize into
        # a canonical string for the document payload.
        payload = json.load(file)
        document = {"page_content": json.dumps(payload)}
        # Fingerprint the raw bytes so identical uploads share an ID.
        digest = md5(file.getvalue()).hexdigest()
        # Leave the stream rewound for any subsequent readers.
        file.seek(0)
        return cls(
            name=filename,
            file_id=digest,
            docs=[document],
            raw_bytes=file.getvalue(),
        )
class HtmlFile(File):
    # Concrete ``File`` for HTML uploads; tags are stripped and only the
    # visible text is kept, via BeautifulSoup.
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "HtmlFile":
        r"""Creates a HtmlFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                html file.
            filename (str): The name of the file.

        Returns:
            HtmlFile: A HtmlFile object.

        Raises:
            ImportError: If beautifulsoup4 is not installed.
        """
        # Parse the HTML data from the file
        try:
            from bs4 import BeautifulSoup
        except ImportError:
            raise ImportError(
                "Please install `beautifulsoup4` first. "
                "You can install it by running "
                "`pip install beautifulsoup4`."
            )
        soup = BeautifulSoup(file, "html.parser")
        text = soup.get_text()
        text = strip_consecutive_newlines(text)
        # Create a dictionary with the parsed data
        doc = {"page_content": text.strip()}
        # Calculate a unique identifier for the file
        file_id = md5(file.getvalue()).hexdigest()
        # Reset the file pointer to the beginning
        file.seek(0)
        return cls(
            name=filename,
            file_id=file_id,
            docs=[doc],
            raw_bytes=file.getvalue(),
        )

View File

@@ -1,162 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging
import os
import time
from typing import IO, Any, Optional, Union
import requests
from camel.utils import api_keys_required
logger = logging.getLogger(__name__)
class ChunkrReader:
    r"""Chunkr Reader for processing documents and returning content
    in various formats.

    Args:
        api_key (Optional[str], optional): The API key for Chunkr API. If not
            provided, it will be retrieved from the environment variable
            `CHUNKR_API_KEY`. (default: :obj:`None`)
        url (Optional[str], optional): The url to the Chunkr service.
            (default: :obj:`https://api.chunkr.ai/api/v1/task`)
        timeout (int, optional): The maximum time in seconds to wait for the
            API responses. (default: :obj:`30`)
        **kwargs (Any): Additional keyword arguments for request headers.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = "https://api.chunkr.ai/api/v1/task",
        timeout: int = 30,
        **kwargs: Any,
    ) -> None:
        self._api_key = api_key or os.getenv('CHUNKR_API_KEY')
        # NOTE(review): the CHUNKR_API_URL env var takes precedence over the
        # explicit ``url`` argument — confirm this ordering is intended.
        self._url = os.getenv('CHUNKR_API_URL') or url
        self._headers = {
            "Authorization": f"{self._api_key}",
            **kwargs,
        }
        self.timeout = timeout

    def submit_task(
        self,
        file_path: str,
        model: str = "Fast",
        ocr_strategy: str = "Auto",
        target_chunk_length: str = "512",
    ) -> str:
        r"""Submits a file to the Chunkr API and returns the task ID.

        Args:
            file_path (str): The path to the file to be uploaded.
            model (str, optional): The model to be used for the task.
                (default: :obj:`Fast`)
            ocr_strategy (str, optional): The OCR strategy. Defaults to 'Auto'.
            target_chunk_length (str, optional): The target chunk length.
                (default: :obj:`512`)

        Returns:
            str: The task ID.

        Raises:
            ValueError: If the submission fails or no task ID is returned.
        """
        with open(file_path, 'rb') as file:
            # Every field is sent as a multipart form part; the ``None``
            # first element omits an explicit per-part filename.
            files: dict[
                str, Union[tuple[None, IO[bytes]], tuple[None, str]]
            ] = {
                'file': (
                    None,
                    file,
                ),  # Properly pass the file as a binary stream
                'model': (None, model),
                'ocr_strategy': (None, ocr_strategy),
                'target_chunk_length': (None, target_chunk_length),
            }
            try:
                response = requests.post(
                    self._url,  # type: ignore[arg-type]
                    headers=self._headers,
                    files=files,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                task_id = response.json().get('task_id')
                if not task_id:
                    raise ValueError("Task ID not returned in the response.")
                logger.info(f"Task submitted successfully. Task ID: {task_id}")
                return task_id
            except Exception as e:
                logger.error(f"Failed to submit task: {e}")
                raise ValueError(f"Failed to submit task: {e}") from e

    def get_task_output(self, task_id: str, max_retries: int = 5) -> str:
        r"""Polls the Chunkr API to check the task status and returns the task
        result.

        Polls at a fixed 5-second interval, up to ``max_retries`` attempts.

        Args:
            task_id (str): The task ID to check the status for.
            max_retries (int, optional): Maximum number of retry attempts.
                (default: :obj:`5`)

        Returns:
            str: The formatted task result in JSON format.

        Raises:
            ValueError: If the task status cannot be retrieved.
            RuntimeError: If the maximum number of retries is reached without
                a successful task completion.
        """
        url_get = f"{self._url}/{task_id}"
        attempts = 0
        while attempts < max_retries:
            try:
                response = requests.get(
                    url_get, headers=self._headers, timeout=self.timeout
                )
                response.raise_for_status()
                task_status = response.json().get('status')

                if task_status == "Succeeded":
                    logger.info(f"Task {task_id} completed successfully.")
                    return self._pretty_print_response(response.json())
                else:
                    # Any non-"Succeeded" status (including failures) is
                    # retried until the attempt budget is exhausted.
                    logger.info(
                        f"Task {task_id} is still {task_status}. Retrying "
                        "in 5 seconds..."
                    )
            except Exception as e:
                logger.error(f"Failed to retrieve task status: {e}")
                raise ValueError(f"Failed to retrieve task status: {e}") from e

            attempts += 1
            time.sleep(5)

        logger.error(f"Max retries reached for task {task_id}.")
        raise RuntimeError(f"Max retries reached for task {task_id}.")

    def _pretty_print_response(self, response_json: dict) -> str:
        r"""Pretty prints the JSON response.

        Args:
            response_json (dict): The response JSON to pretty print.

        Returns:
            str: Formatted JSON as a string.
        """
        return json.dumps(response_json, indent=4)

View File

@@ -1,202 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel
class Firecrawl:
    r"""Firecrawl allows you to turn entire websites into LLM-ready markdown.

    Args:
        api_key (Optional[str]): API key for authenticating with the Firecrawl
            API.
        api_url (Optional[str]): Base URL for the Firecrawl API.

    References:
        https://docs.firecrawl.dev/introduction
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
    ) -> None:
        # Imported lazily so `firecrawl` is only required when used.
        from firecrawl import FirecrawlApp

        self._api_key = api_key or os.environ.get("FIRECRAWL_API_KEY")
        self._api_url = api_url or os.environ.get("FIRECRAWL_API_URL")

        self.app = FirecrawlApp(api_key=self._api_key, api_url=self._api_url)

    def crawl(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Crawl a URL and all accessible subpages. Customize the crawl by
        setting different parameters, and receive the full response or a job
        ID based on the specified options.

        Args:
            url (str): The URL to crawl.
            params (Optional[Dict[str, Any]]): Additional parameters for the
                crawl request. Defaults to `None`.
            **kwargs (Any): Additional keyword arguments, such as
                `poll_interval`, `idempotency_key`.

        Returns:
            Any: The crawl job ID or the crawl results if waiting until
                completion.

        Raises:
            RuntimeError: If the crawling process fails.
        """
        try:
            crawl_response = self.app.crawl_url(
                url=url,
                params=params,
                **kwargs,
            )
            return crawl_response
        except Exception as e:
            # Chain the cause so the underlying client error is preserved
            # (consistent with the other readers in this package).
            raise RuntimeError(f"Failed to crawl the URL: {e}") from e

    def markdown_crawl(self, url: str) -> str:
        r"""Crawl a URL and all accessible subpages and return the content in
        Markdown format.

        Args:
            url (str): The URL to crawl.

        Returns:
            str: The content of the URL in Markdown format.

        Raises:
            RuntimeError: If the crawling process fails.
        """
        try:
            crawl_result = self.app.crawl_url(
                url,
                {'formats': ['markdown']},
            )
            if not isinstance(crawl_result, list):
                raise ValueError("Unexpected response format")
            markdown_contents = [
                result.get('markdown', '') for result in crawl_result
            ]
            return '\n'.join(markdown_contents)
        except Exception as e:
            raise RuntimeError(
                f"Failed to crawl the URL and retrieve markdown: {e}"
            ) from e

    def check_crawl_job(self, job_id: str) -> Dict:
        r"""Check the status of a crawl job.

        Args:
            job_id (str): The ID of the crawl job.

        Returns:
            Dict: The response including status of the crawl job.

        Raises:
            RuntimeError: If the check process fails.
        """
        try:
            return self.app.check_crawl_status(job_id)
        except Exception as e:
            raise RuntimeError(
                f"Failed to check the crawl job status: {e}"
            ) from e

    def scrape(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict:
        r"""To scrape a single URL. This function supports advanced scraping
        by setting different parameters and returns the full scraped data as a
        dictionary.

        Reference: https://docs.firecrawl.dev/advanced-scraping-guide

        Args:
            url (str): The URL to read.
            params (Optional[Dict[str, Any]]): Additional parameters for the
                scrape request.

        Returns:
            Dict: The scraped data.

        Raises:
            RuntimeError: If the scrape process fails.
        """
        try:
            return self.app.scrape_url(url=url, params=params)
        except Exception as e:
            raise RuntimeError(f"Failed to scrape the URL: {e}") from e

    def structured_scrape(self, url: str, response_format: BaseModel) -> Dict:
        r"""Use LLM to extract structured data from given URL.

        Args:
            url (str): The URL to read.
            response_format (BaseModel): A pydantic model
                that includes value types and field descriptions used to
                generate a structured response by LLM. This schema helps
                in defining the expected output format.

        Returns:
            Dict: The content of the URL.

        Raises:
            RuntimeError: If the scrape process fails.
        """
        try:
            data = self.app.scrape_url(
                url,
                {
                    'formats': ['extract'],
                    'extract': {'schema': response_format.model_json_schema()},
                },
            )
            return data.get("extract", {})
        except Exception as e:
            raise RuntimeError(
                f"Failed to perform structured scrape: {e}"
            ) from e

    def map_site(
        self, url: str, params: Optional[Dict[str, Any]] = None
    ) -> list:
        r"""Map a website to retrieve all accessible URLs.

        Args:
            url (str): The URL of the site to map.
            params (Optional[Dict[str, Any]]): Additional parameters for the
                map request. Defaults to `None`.

        Returns:
            list: A list containing the URLs found on the site.

        Raises:
            RuntimeError: If the mapping process fails.
        """
        try:
            return self.app.map_url(url=url, params=params)
        except Exception as e:
            raise RuntimeError(f"Failed to map the site: {e}") from e

View File

@@ -1,99 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Optional
from warnings import warn
from camel.types.enums import JinaReturnFormat
JINA_ENDPOINT = "https://r.jina.ai/"
class JinaURLReader:
    r"""URL Reader provided by Jina AI. The output is cleaner and more
    LLM-friendly than the URL Reader of UnstructuredIO. Can be configured to
    replace the UnstructuredIO URL Reader in the pipeline.

    Args:
        api_key (Optional[str], optional): The API key for Jina AI. If not
            provided, the reader will have a lower rate limit. Defaults to
            None.
        return_format (ReturnFormat, optional): The level of detail
            of the returned content, which is optimized for LLMs. For
            now screenshots are not supported. Defaults to
            ReturnFormat.DEFAULT.
        json_response (bool, optional): Whether to return the response
            in JSON format. Defaults to False.
        timeout (int, optional): The maximum time in seconds to wait for
            the page to be rendered. Defaults to 30.
        **kwargs (Any): Additional keyword arguments, including proxies,
            cookies, etc. It should align with the HTTP Header field and
            value pairs listed in the reference.

    References:
        https://jina.ai/reader
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        return_format: JinaReturnFormat = JinaReturnFormat.DEFAULT,
        json_response: bool = False,
        timeout: int = 30,
        **kwargs: Any,
    ) -> None:
        api_key = api_key or os.getenv('JINA_API_KEY')
        if not api_key:
            warn(
                "JINA_API_KEY not set. This will result in a low rate limit "
                "of Jina URL Reader. Get API key here: https://jina.ai/reader."
            )

        # if the following field not provided, it will be None
        api_field = f"Bearer {api_key}" if api_key else None
        json_field = "application/json" if json_response else None

        raw_headers = {
            "Authorization": api_field,
            "X-Return-Format": return_format.value,
            "Accept": json_field,
            "X-Timeout": str(timeout),
            **kwargs,
        }

        # eliminate None values
        self._headers = {k: v for k, v in raw_headers.items() if v}
        # Keep the render timeout so the HTTP request itself can be bounded;
        # without a client-side timeout, requests.get may hang indefinitely.
        self._timeout = timeout

    def read_content(self, url: str) -> str:
        r"""Reads the content of a URL and returns it as a string with
        given form.

        Args:
            url (str): The URL to read.

        Returns:
            str: The content of the URL.

        Raises:
            ValueError: If the request fails or the server responds with an
                HTTP error status.
        """
        import requests

        full_url = f"{JINA_ENDPOINT}{url}"
        try:
            # Allow a small buffer over the render timeout so the
            # server-side X-Timeout can expire first.
            resp = requests.get(
                full_url, headers=self._headers, timeout=self._timeout + 5
            )
            resp.raise_for_status()
        except Exception as e:
            raise ValueError(f"Failed to read content from {url}: {e}") from e

        return resp.text

View File

@@ -1,471 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import uuid
import warnings
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
if TYPE_CHECKING:
from unstructured.documents.elements import Element
import pdb
class UnstructuredIO:
    r"""A class to handle various functionalities provided by the
    Unstructured library, including version checking, parsing, cleaning,
    extracting, staging, chunking data, and integrating with cloud
    services like S3 and Azure for data connection.

    References:
        https://docs.unstructured.io/
    """

    @staticmethod
    def create_element_from_text(
        text: str,
        element_id: Optional[str] = None,
        embeddings: Optional[List[float]] = None,
        filename: Optional[str] = None,
        file_directory: Optional[str] = None,
        last_modified: Optional[str] = None,
        filetype: Optional[str] = None,
        parent_id: Optional[str] = None,
    ) -> "Element":
        r"""Creates a Text element from a given text input, with optional
        metadata and embeddings.

        Args:
            text (str): The text content for the element.
            element_id (Optional[str], optional): Unique identifier for the
                element. (default: :obj:`None`)
            embeddings (List[float], optional): A list of float
                numbers representing the text embeddings.
                (default: :obj:`None`)
            filename (Optional[str], optional): The name of the file the
                element is associated with. (default: :obj:`None`)
            file_directory (Optional[str], optional): The directory path where
                the file is located. (default: :obj:`None`)
            last_modified (Optional[str], optional): The last modified date of
                the file. (default: :obj:`None`)
            filetype (Optional[str], optional): The type of the file.
                (default: :obj:`None`)
            parent_id (Optional[str], optional): The identifier of the parent
                element. (default: :obj:`None`)

        Returns:
            Element: An instance of Text with the provided content and
                metadata.
        """
        from unstructured.documents.elements import ElementMetadata, Text

        metadata = ElementMetadata(
            filename=filename,
            file_directory=file_directory,
            last_modified=last_modified,
            filetype=filetype,
            parent_id=parent_id,
        )

        return Text(
            text=text,
            element_id=element_id or str(uuid.uuid4()),
            metadata=metadata,
            embeddings=embeddings,
        )

    @staticmethod
    def parse_file_or_url(
        input_path: str,
        **kwargs: Any,
    ) -> Union[List["Element"], None]:
        r"""Loads a file or a URL and parses its contents into elements.

        Args:
            input_path (str): Path to the file or URL to be parsed.
            **kwargs: Extra kwargs passed to the partition function.

        Returns:
            Union[List[Element],None]: List of elements after parsing the file
                or URL if success.

        Raises:
            FileNotFoundError: If the file does not exist at the path
                specified.

        Notes:
            Supported file types:
            "csv", "doc", "docx", "epub", "image", "md", "msg", "odt",
            "org", "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx".

        References:
            https://unstructured-io.github.io/unstructured/
        """
        import os
        from urllib.parse import urlparse

        from unstructured.partition.auto import partition

        # An input with both a scheme and a network location is treated
        # as a URL; anything else is treated as a local file path.
        parsed_url = urlparse(input_path)
        is_url = all([parsed_url.scheme, parsed_url.netloc])

        # Handling URL
        if is_url:
            try:
                elements = partition(url=input_path, **kwargs)
                return elements
            except Exception:
                warnings.warn(f"Failed to parse the URL: {input_path}")
                return None

        # Handling file
        else:
            # Check if the file exists
            if not os.path.exists(input_path):
                raise FileNotFoundError(
                    f"The file {input_path} was not found."
                )

            # Read the file
            try:
                with open(input_path, "rb") as f:
                    elements = partition(file=f, **kwargs)
                    return elements
            except Exception:
                warnings.warn(f"Failed to partition the file: {input_path}")
                return None

    @staticmethod
    def parse_bytes(
        file: IO[bytes], **kwargs: Any
    ) -> Union[List["Element"], None]:
        r"""Parses a bytes stream and converts its contents into elements.

        Args:
            file (IO[bytes]): The file in bytes format to be parsed.
            **kwargs: Extra kwargs passed to the partition function.

        Returns:
            Union[List[Element], None]: List of elements after parsing the file
                if successful, otherwise `None`.

        Notes:
            Supported file types:
            "csv", "doc", "docx", "epub", "image", "md", "msg", "odt",
            "org", "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx".

        References:
            https://docs.unstructured.io/open-source/core-functionality/partitioning
        """
        from unstructured.partition.auto import partition

        try:
            # Use partition to process the bytes stream
            elements = partition(file=file, **kwargs)
            return elements
        except Exception as e:
            warnings.warn(f"Failed to partition the file stream: {e}")
            return None

    @staticmethod
    def clean_text_data(
        text: str,
        clean_options: Optional[List[Tuple[str, Dict[str, Any]]]] = None,
    ) -> str:
        r"""Cleans text data using a variety of cleaning functions provided by
        the `unstructured` library.

        This function applies multiple text cleaning utilities by calling the
        `unstructured` library's cleaning bricks for operations like
        replacing Unicode quotes, removing extra whitespace, dashes, non-ascii
        characters, and more.

        If no cleaning options are provided, a default set of cleaning
        operations is applied. These defaults including operations
        "replace_unicode_quotes", "clean_non_ascii_chars",
        "group_broken_paragraphs", and "clean_extra_whitespace".

        Args:
            text (str): The text to be cleaned.
            clean_options (dict): A dictionary specifying which cleaning
                options to apply. The keys should match the names of the
                cleaning functions, and the values should be dictionaries
                containing the parameters for each function. Supported types:
                'clean_extra_whitespace', 'clean_bullets',
                'clean_ordered_bullets', 'clean_postfix', 'clean_prefix',
                'clean_dashes', 'clean_trailing_punctuation',
                'clean_non_ascii_chars', 'group_broken_paragraphs',
                'remove_punctuation', 'replace_unicode_quotes',
                'bytes_string_to_string', 'translate_text'.

        Returns:
            str: The cleaned text.

        Raises:
            ValueError: If a cleaning option does not correspond to a
                valid cleaning function in `unstructured`.

        Notes:
            The 'options' dictionary keys must correspond to valid cleaning
            brick names from the `unstructured` library.
            Each brick's parameters must be provided in a nested dictionary
            as the value for the key.

        References:
            https://unstructured-io.github.io/unstructured/
        """
        from unstructured.cleaners.core import (
            bytes_string_to_string,
            clean_bullets,
            clean_dashes,
            clean_extra_whitespace,
            clean_non_ascii_chars,
            clean_ordered_bullets,
            clean_postfix,
            clean_prefix,
            clean_trailing_punctuation,
            group_broken_paragraphs,
            remove_punctuation,
            replace_unicode_quotes,
        )
        from unstructured.cleaners.translate import translate_text

        cleaning_functions: Any = {
            "clean_extra_whitespace": clean_extra_whitespace,
            "clean_bullets": clean_bullets,
            "clean_ordered_bullets": clean_ordered_bullets,
            "clean_postfix": clean_postfix,
            "clean_prefix": clean_prefix,
            "clean_dashes": clean_dashes,
            "clean_trailing_punctuation": clean_trailing_punctuation,
            "clean_non_ascii_chars": clean_non_ascii_chars,
            "group_broken_paragraphs": group_broken_paragraphs,
            "remove_punctuation": remove_punctuation,
            "replace_unicode_quotes": replace_unicode_quotes,
            "bytes_string_to_string": bytes_string_to_string,
            "translate_text": translate_text,
        }

        # Define default clean options if none are provided
        if clean_options is None:
            clean_options = [
                ("replace_unicode_quotes", {}),
                ("clean_non_ascii_chars", {}),
                ("group_broken_paragraphs", {}),
                ("clean_extra_whitespace", {}),
            ]

        cleaned_text = text
        for func_name, params in clean_options:
            if func_name in cleaning_functions:
                cleaned_text = cleaning_functions[func_name](
                    cleaned_text, **params
                )
            else:
                raise ValueError(
                    f"'{func_name}' is not a valid function in "
                    "`Unstructured IO`."
                )

        return cleaned_text

    @staticmethod
    def extract_data_from_text(
        text: str,
        extract_type: Literal[
            'extract_datetimetz',
            'extract_email_address',
            'extract_ip_address',
            'extract_ip_address_name',
            'extract_mapi_id',
            'extract_ordered_bullets',
            'extract_text_after',
            'extract_text_before',
            'extract_us_phone_number',
        ],
        **kwargs,
    ) -> Any:
        r"""Extracts various types of data from text using functions from
        unstructured.cleaners.extract.

        Args:
            text (str): Text to extract data from.
            extract_type (Literal['extract_datetimetz',
                'extract_email_address', 'extract_ip_address',
                'extract_ip_address_name', 'extract_mapi_id',
                'extract_ordered_bullets', 'extract_text_after',
                'extract_text_before', 'extract_us_phone_number']): Type of
                data to extract.
            **kwargs: Additional keyword arguments for specific
                extraction functions.

        Returns:
            Any: The extracted data, type depends on extract_type.

        Raises:
            ValueError: If `extract_type` is not one of the supported
                extraction function names.

        References:
            https://unstructured-io.github.io/unstructured/
        """
        from unstructured.cleaners.extract import (
            extract_datetimetz,
            extract_email_address,
            extract_ip_address,
            extract_ip_address_name,
            extract_mapi_id,
            extract_ordered_bullets,
            extract_text_after,
            extract_text_before,
            extract_us_phone_number,
        )

        extraction_functions: Any = {
            "extract_datetimetz": extract_datetimetz,
            "extract_email_address": extract_email_address,
            "extract_ip_address": extract_ip_address,
            "extract_ip_address_name": extract_ip_address_name,
            "extract_mapi_id": extract_mapi_id,
            "extract_ordered_bullets": extract_ordered_bullets,
            "extract_text_after": extract_text_after,
            "extract_text_before": extract_text_before,
            "extract_us_phone_number": extract_us_phone_number,
        }

        if extract_type not in extraction_functions:
            raise ValueError(f"Unsupported extract_type: {extract_type}")

        return extraction_functions[extract_type](text, **kwargs)

    @staticmethod
    def stage_elements(
        elements: List[Any],
        stage_type: Literal[
            'convert_to_csv',
            'convert_to_dataframe',
            'convert_to_dict',
            'dict_to_elements',
            'stage_csv_for_prodigy',
            'stage_for_prodigy',
            'stage_for_baseplate',
            'stage_for_datasaur',
            'stage_for_label_box',
            'stage_for_label_studio',
            'stage_for_weaviate',
        ],
        **kwargs,
    ) -> Union[str, List[Dict], Any]:
        r"""Stages elements for various platforms based on the
        specified staging type.

        This function applies multiple staging utilities to format data
        for different NLP annotation and machine learning tools. It uses
        the 'unstructured.staging' module's functions for operations like
        converting to CSV, DataFrame, dictionary, or formatting for
        specific platforms like Prodigy, etc.

        Args:
            elements (List[Any]): List of Element objects to be staged.
            stage_type (Literal['convert_to_csv', 'convert_to_dataframe',
                'convert_to_dict', 'dict_to_elements',
                'stage_csv_for_prodigy', 'stage_for_prodigy',
                'stage_for_baseplate', 'stage_for_datasaur',
                'stage_for_label_box', 'stage_for_label_studio',
                'stage_for_weaviate']): Type of staging to perform.
            **kwargs: Additional keyword arguments specific to
                the staging type.

        Returns:
            Union[str, List[Dict], Any]: Staged data in the
                format appropriate for the specified staging type.

        Raises:
            ValueError: If the staging type is not supported or a required
                argument is missing.

        References:
            https://unstructured-io.github.io/unstructured/
        """
        from unstructured.staging import (
            base,
            baseplate,
            datasaur,
            label_box,
            label_studio,
            prodigy,
            weaviate,
        )

        # Wrappers normalize each staging helper's extra positional
        # arguments (metadata/entities) into keyword arguments.
        staging_functions: Any = {
            "convert_to_csv": base.convert_to_csv,
            "convert_to_dataframe": base.convert_to_dataframe,
            "convert_to_dict": base.convert_to_dict,
            "dict_to_elements": base.dict_to_elements,
            "stage_csv_for_prodigy": lambda els, **kw: (
                prodigy.stage_csv_for_prodigy(els, kw.get('metadata', []))
            ),
            "stage_for_prodigy": lambda els, **kw: (
                prodigy.stage_for_prodigy(els, kw.get('metadata', []))
            ),
            "stage_for_baseplate": baseplate.stage_for_baseplate,
            "stage_for_datasaur": lambda els, **kw: (
                datasaur.stage_for_datasaur(els, kw.get('entities', []))
            ),
            "stage_for_label_box": lambda els, **kw: (
                label_box.stage_for_label_box(els, **kw)
            ),
            "stage_for_label_studio": lambda els, **kw: (
                label_studio.stage_for_label_studio(els, **kw)
            ),
            "stage_for_weaviate": weaviate.stage_for_weaviate,
        }

        if stage_type not in staging_functions:
            raise ValueError(f"Unsupported stage type: {stage_type}")

        return staging_functions[stage_type](elements, **kwargs)

    @staticmethod
    def chunk_elements(
        elements: List["Element"], chunk_type: str, **kwargs
    ) -> List["Element"]:
        r"""Chunks elements by titles.

        Args:
            elements (List[Element]): List of Element objects to be chunked.
            chunk_type (str): Type chunk going to apply. Supported types:
                'chunk_by_title'.
            **kwargs: Additional keyword arguments for chunking.

        Returns:
            List[Element]: List of chunked sections.

        Raises:
            ValueError: If the chunk type is not supported.

        References:
            https://unstructured-io.github.io/unstructured/
        """
        from unstructured.chunking.title import chunk_by_title

        chunking_functions = {
            "chunk_by_title": chunk_by_title,
        }

        if chunk_type not in chunking_functions:
            raise ValueError(f"Unsupported chunk type: {chunk_type}")

        return chunking_functions[chunk_type](elements, **kwargs)

View File

@@ -1,112 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
import sys
# Create a private logger for the whole library; get_logger() below hands
# out children of this root ('camel.<name>'), so level/handler changes here
# propagate to every camel logger.
_logger = logging.getLogger('camel')
def _configure_library_logging():
    """Configure root logging for the Camel library.

    Does nothing when logging is disabled via ``CAMEL_LOGGING_DISABLED`` or
    when a logging configuration is already in place.
    """
    # Respect the kill switch set by disable_logging().
    if os.environ.get('CAMEL_LOGGING_DISABLED', 'False').lower() == 'true':
        return

    already_configured = logging.root.handlers or _logger.handlers
    if already_configured:
        _logger.debug("Existing logger configuration found, using that.")
        return

    # LOGLEVEL env var picks the default level (INFO if unset).
    logging.basicConfig(
        level=os.environ.get('LOGLEVEL', 'INFO').upper(),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        stream=sys.stdout,
    )
    logging.setLoggerClass(logging.Logger)
    _logger.info("Camel library logging has been configured.")
def disable_logging():
    r"""Disable all logging for the Camel library.

    This function sets the log level to a value higher than CRITICAL,
    effectively disabling all log messages, and adds a NullHandler to
    suppress any potential warnings about no handlers being found.
    """
    os.environ['CAMEL_LOGGING_DISABLED'] = 'true'
    _logger.setLevel(logging.CRITICAL + 1)

    # Avoid adding multiple NullHandlers
    has_null_handler = any(
        isinstance(existing, logging.NullHandler)
        for existing in _logger.handlers
    )
    if not has_null_handler:
        _logger.addHandler(logging.NullHandler())

    _logger.debug("Logging has been disabled.")
def enable_logging():
    r"""Enable logging for the Camel library.

    Re-enables logging if it was previously disabled and applies the
    default configuration. If logging is already configured, that
    configuration is left untouched.
    """
    os.environ['CAMEL_LOGGING_DISABLED'] = 'false'
    _configure_library_logging()
def set_log_level(level):
    r"""Set the logging level for the Camel library.

    Args:
        level (Union[str, int]): The logging level to set. This can be a string
            (e.g., 'INFO') or a logging level constant (e.g., logging.INFO,
            logging.DEBUG).
            See https://docs.python.org/3/library/logging.html#levels

    Raises:
        ValueError: If the provided level is not a valid logging level.
    """
    valid_levels = ('NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL')

    if isinstance(level, str):
        # Normalize level names so e.g. 'info' is accepted.
        normalized = level.upper()
        if normalized not in valid_levels:
            raise ValueError(
                f"Invalid logging level."
                f" Choose from: {', '.join(valid_levels)}"
            )
        level = normalized
    elif not isinstance(level, int):
        raise ValueError(
            "Logging level must be an option from the logging module."
        )

    _logger.setLevel(level)
    _logger.debug(f"Logging level set to: {logging.getLevelName(level)}")
def get_logger(name):
    r"""Get a logger with the specified name, prefixed with 'camel.'.

    Args:
        name (str): The name to be appended to 'camel.' to create the logger.

    Returns:
        logging.Logger: A logger instance with the name 'camel.{name}'.
    """
    qualified_name = f'camel.{name}'
    return logging.getLogger(qualified_name)
# Lazy configuration: Only configure logging if explicitly enabled.
# This runs at import time; users can opt out by exporting
# CAMEL_LOGGING_DISABLED=true before importing the library.
if os.environ.get('CAMEL_LOGGING_DISABLED', 'False').strip().lower() != 'true':
    _configure_library_logging()

View File

@@ -1,38 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .agent_memories import (
ChatHistoryMemory,
LongtermAgentMemory,
VectorDBMemory,
)
from .base import AgentMemory, BaseContextCreator, MemoryBlock
from .blocks.chat_history_block import ChatHistoryBlock
from .blocks.vectordb_block import VectorDBBlock
from .context_creators.score_based import ScoreBasedContextCreator
from .records import ContextRecord, MemoryRecord
# Public API of the memories package. Quoting is kept consistent
# (single quotes throughout; 'AgentMemory' was previously the lone
# double-quoted entry).
__all__ = [
    'MemoryRecord',
    'ContextRecord',
    'MemoryBlock',
    'AgentMemory',
    'BaseContextCreator',
    'ScoreBasedContextCreator',
    'ChatHistoryMemory',
    'VectorDBMemory',
    'ChatHistoryBlock',
    'VectorDBBlock',
    'LongtermAgentMemory',
]

View File

@@ -1,176 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import List, Optional
from camel.memories.base import AgentMemory, BaseContextCreator
from camel.memories.blocks import ChatHistoryBlock, VectorDBBlock
from camel.memories.records import ContextRecord, MemoryRecord
from camel.storages import BaseKeyValueStorage, BaseVectorStorage
from camel.types import OpenAIBackendRole
class ChatHistoryMemory(AgentMemory):
    r"""An agent memory wrapper of :obj:`ChatHistoryBlock`.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        storage (BaseKeyValueStorage, optional): A storage backend for storing
            chat history. If `None`, an :obj:`InMemoryKeyValueStorage`
            will be used. (default: :obj:`None`)
        window_size (int, optional): The number of recent chat messages to
            retrieve. If not provided, the entire chat history will be
            retrieved. (default: :obj:`None`)
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        storage: Optional[BaseKeyValueStorage] = None,
        window_size: Optional[int] = None,
    ) -> None:
        # Validate the window before initializing any state.
        if window_size is not None:
            if not isinstance(window_size, int):
                raise TypeError("`window_size` must be an integer or None.")
            if window_size < 0:
                raise ValueError("`window_size` must be non-negative.")
        self._context_creator = context_creator
        self._window_size = window_size
        self._chat_history_block = ChatHistoryBlock(storage=storage)

    def retrieve(self) -> List[ContextRecord]:
        r"""Fetches up to `window_size` recent records from the history."""
        return self._chat_history_block.retrieve(self._window_size)

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Appends the given records to the chat history."""
        self._chat_history_block.write_records(records)

    def get_context_creator(self) -> BaseContextCreator:
        r"""Returns the context creator used by this memory."""
        return self._context_creator

    def clear(self) -> None:
        r"""Removes all records from the chat history."""
        self._chat_history_block.clear()
class VectorDBMemory(AgentMemory):
    r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries
    messages stored in the vector database. Notice that the most recent
    messages will not be added to the context.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        storage (BaseVectorStorage, optional): A vector storage storage. If
            `None`, an :obj:`QdrantStorage` will be used.
            (default: :obj:`None`)
        retrieve_limit (int, optional): The maximum number of messages
            to be added into the context. (default: :obj:`3`)
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        storage: Optional[BaseVectorStorage] = None,
        retrieve_limit: int = 3,
    ) -> None:
        self._context_creator = context_creator
        self._retrieve_limit = retrieve_limit
        self._vectordb_block = VectorDBBlock(storage=storage)
        # Tracks the most recent user input; used as the retrieval query.
        self._current_topic: str = ""

    def retrieve(self) -> List[ContextRecord]:
        r"""Retrieves records similar to the current topic from the vector
        database, up to `retrieve_limit` entries."""
        return self._vectordb_block.retrieve(
            self._current_topic,
            limit=self._retrieve_limit,
        )

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Writes records to the vector database and updates the current
        topic from the last user message seen."""
        # Assume the last user input is the current topic.
        for record in records:
            if record.role_at_backend == OpenAIBackendRole.USER:
                self._current_topic = record.message.content
        self._vectordb_block.write_records(records)

    def get_context_creator(self) -> BaseContextCreator:
        r"""Returns the context creator used by this memory."""
        return self._context_creator

    def clear(self) -> None:
        r"""Removes all records from the vector database memory.

        Without this override the class would remain abstract
        (``MemoryBlock.clear`` is an abstractmethod) and could not be
        instantiated, unlike its siblings ``ChatHistoryMemory`` and
        ``LongtermAgentMemory`` which both implement ``clear``.
        """
        self._vectordb_block.clear()
        # Reset the topic so stale queries are not reused after a wipe.
        self._current_topic = ""
class LongtermAgentMemory(AgentMemory):
    r"""An implementation of the :obj:`AgentMemory` abstract base class for
    augmenting ChatHistoryMemory with VectorDBMemory.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        chat_history_block (Optional[ChatHistoryBlock], optional): A chat
            history block. If `None`, a :obj:`ChatHistoryBlock` will be used.
            (default: :obj:`None`)
        vector_db_block (Optional[VectorDBBlock], optional): A vector database
            block. If `None`, a :obj:`VectorDBBlock` will be used.
            (default: :obj:`None`)
        retrieve_limit (int, optional): The maximum number of messages
            to be added into the context. (default: :obj:`3`)
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        chat_history_block: Optional[ChatHistoryBlock] = None,
        vector_db_block: Optional[VectorDBBlock] = None,
        retrieve_limit: int = 3,
    ) -> None:
        self.chat_history_block = chat_history_block or ChatHistoryBlock()
        self.vector_db_block = vector_db_block or VectorDBBlock()
        self.retrieve_limit = retrieve_limit
        self._context_creator = context_creator
        # Most recent user input; used as the vector retrieval query.
        self._current_topic: str = ""

    def get_context_creator(self) -> BaseContextCreator:
        r"""Returns the context creator used by the memory.

        Returns:
            BaseContextCreator: The context creator used by the memory.
        """
        return self._context_creator

    def retrieve(self) -> List[ContextRecord]:
        r"""Retrieves context records from both the chat history and the vector
        database.

        Returns:
            List[ContextRecord]: A list of context records retrieved from both
                the chat history and the vector database.
        """
        history_records = self.chat_history_block.retrieve()
        semantic_records = self.vector_db_block.retrieve(
            self._current_topic, self.retrieve_limit
        )
        # Splice the vector hits in right after the first history record
        # (presumably the system message — confirm against ChatHistoryBlock).
        return history_records[:1] + semantic_records + history_records[1:]

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Converts the provided chat messages into vector representations and
        writes them to the vector database.

        Args:
            records (List[MemoryRecord]): Messages to be added to the vector
                database.
        """
        self.vector_db_block.write_records(records)
        self.chat_history_block.write_records(records)

        # The last user message seen becomes the current topic.
        for record in records:
            if record.role_at_backend == OpenAIBackendRole.USER:
                self._current_topic = record.message.content

    def clear(self) -> None:
        r"""Removes all records from the memory."""
        self.chat_history_block.clear()
        self.vector_db_block.clear()

View File

@@ -1,140 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import List, Tuple
from camel.memories.records import ContextRecord, MemoryRecord
from camel.messages import OpenAIMessage
from camel.utils import BaseTokenCounter
class MemoryBlock(ABC):
    r"""An abstract class serves as the fundamental component within the agent
    memory system. This class is equipped with "write" and "clear" functions.
    However, it intentionally does not define a retrieval interface, as the
    structure of the data to be retrieved may vary in different types of
    memory blocks.
    """

    @abstractmethod
    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Writes records to the memory, appending them to existing ones.

        Args:
            records (List[MemoryRecord]): Records to be added to the memory.
        """
        ...

    @abstractmethod
    def clear(self) -> None:
        r"""Clears all messages from the memory."""
        ...

    def write_record(self, record: MemoryRecord) -> None:
        r"""Writes a record to the memory, appending it to existing ones.

        Convenience wrapper around :meth:`write_records` for a single record.

        Args:
            record (MemoryRecord): Record to be added to the memory.
        """
        self.write_records([record])
class BaseContextCreator(ABC):
    r"""An abstract base class defining the interface for context creation
    strategies.

    This class provides a foundational structure for different strategies to
    generate conversational context from a list of context records. The
    primary goal is to create a context that is aligned with a specified token
    count limit, allowing subclasses to define their specific approach.

    Subclasses should implement the :obj:`token_counter`,:obj: `token_limit`,
    and :obj:`create_context` methods to provide specific context creation
    logic.

    Attributes:
        token_counter (BaseTokenCounter): A token counter instance responsible
            for counting tokens in a message.
        token_limit (int): The maximum number of tokens allowed in the
            generated context.
    """

    @property
    @abstractmethod
    def token_counter(self) -> BaseTokenCounter:
        ...

    @property
    @abstractmethod
    def token_limit(self) -> int:
        ...

    @abstractmethod
    def create_context(
        self,
        records: List[ContextRecord],
    ) -> Tuple[List[OpenAIMessage], int]:
        r"""An abstract method to create conversational context from the chat
        history.

        Constructs the context from provided records. The specifics of how this
        is done and how the token count is managed should be provided by
        subclasses implementing this method. The output messages order
        should keep same as the input order.

        Args:
            records (List[ContextRecord]): A list of context records from
                which to generate the context.

        Returns:
            Tuple[List[OpenAIMessage], int]: A tuple containing the constructed
                context in OpenAIMessage format and the total token count.
        """
        ...
class AgentMemory(MemoryBlock, ABC):
    r"""Represents a specialized form of `MemoryBlock`, uniquely designed for
    direct integration with an agent. Two key abstract functions, "retrieve"
    and "get_context_creator", are used for generating model context based on
    the memory records stored within the AgentMemory.
    """

    @abstractmethod
    def retrieve(self) -> List[ContextRecord]:
        r"""Get a record list from the memory for creating model context.

        Returns:
            List[ContextRecord]: A record list for creating model context.
        """
        ...

    @abstractmethod
    def get_context_creator(self) -> BaseContextCreator:
        r"""Gets context creator.

        Returns:
            BaseContextCreator: A model context creator.
        """
        ...

    def get_context(self) -> Tuple[List[OpenAIMessage], int]:
        r"""Gets chat context with a proper size for the agent from the memory.

        Returns:
            (List[OpenAIMessage], int): A tuple containing the constructed
                context in OpenAIMessage format and the total token count.
        """
        creator = self.get_context_creator()
        return creator.create_context(self.retrieve())

View File

@@ -1,21 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .chat_history_block import ChatHistoryBlock
from .vectordb_block import VectorDBBlock
# Public API of the memory blocks subpackage.
__all__ = [
    'ChatHistoryBlock',
    'VectorDBBlock',
]

View File

@@ -1,115 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import warnings
from typing import List, Optional
from camel.memories.base import MemoryBlock
from camel.memories.records import ContextRecord, MemoryRecord
from camel.storages import BaseKeyValueStorage, InMemoryKeyValueStorage
from camel.types import OpenAIBackendRole
class ChatHistoryBlock(MemoryBlock):
    r"""An implementation of the :obj:`MemoryBlock` abstract base class for
    maintaining a record of chat histories.

    This memory block helps manage conversation histories with a key-value
    storage backend, either provided by the user or using a default
    in-memory storage. It offers a windowed approach to retrieving chat
    histories, allowing users to specify how many recent messages they'd
    like to fetch.

    Args:
        storage (BaseKeyValueStorage, optional): A storage mechanism for
            storing chat history. If `None`, an :obj:`InMemoryKeyValueStorage`
            will be used. (default: :obj:`None`)
        keep_rate (float, optional): In historical messages, the score of the
            last message is 1.0, and with each step taken backward, the score
            of the message is multiplied by the `keep_rate`. Higher `keep_rate`
            leads to higher possibility to keep history messages during
            context creation.
    """

    def __init__(
        self,
        storage: Optional[BaseKeyValueStorage] = None,
        keep_rate: float = 0.9,
    ) -> None:
        if keep_rate > 1 or keep_rate < 0:
            raise ValueError("`keep_rate` should be in [0,1]")
        self.storage = storage or InMemoryKeyValueStorage()
        self.keep_rate = keep_rate

    def retrieve(
        self,
        window_size: Optional[int] = None,
    ) -> List[ContextRecord]:
        r"""Retrieves records with a proper size for the agent from the memory
        based on the window size or fetches the entire chat history if no
        window size is specified.

        Args:
            window_size (int, optional): Specifies the number of recent chat
                messages to retrieve. If not provided, the entire chat history
                will be retrieved. (default: :obj:`None`)

        Returns:
            List[ContextRecord]: A list of retrieved records.

        Raises:
            ValueError: If `window_size` is negative.
        """
        if window_size is not None and window_size < 0:
            raise ValueError("`window_size` should be non-negative.")
        record_dicts = self.storage.load()
        if len(record_dicts) == 0:
            warnings.warn("The `ChatHistoryMemory` is empty.")
            return list()

        # Bug fix: the previous `record_dicts[-window_size:]` slice returned
        # the *entire* history when `window_size == 0`, because `[-0:]` is
        # the same as `[0:]`. An explicit zero window now yields no records.
        if window_size is None:
            selected_dicts = record_dicts
        elif window_size == 0:
            selected_dicts = []
        else:
            selected_dicts = record_dicts[-window_size:]
        chat_records: List[MemoryRecord] = [
            MemoryRecord.from_dict(record_dict)
            for record_dict in selected_dicts
        ]

        # We assume that, in the chat history memory, the closer the record is
        # to the current message, the more score it will be.
        output_records = []
        score = 1.0
        for record in reversed(chat_records):
            if record.role_at_backend == OpenAIBackendRole.SYSTEM:
                # System messages are always kept.
                output_records.append(
                    ContextRecord(memory_record=record, score=1.0)
                )
            else:
                # Other messages' score drops down gradually
                score *= self.keep_rate
                output_records.append(
                    ContextRecord(memory_record=record, score=score)
                )

        output_records.reverse()
        return output_records

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Writes memory records to the memory. Additionally, performs
        validation checks on the messages.

        Args:
            records (List[MemoryRecord]): Memory records to be added to the
                memory.
        """
        stored_records = [record.to_dict() for record in records]
        self.storage.save(stored_records)

    def clear(self) -> None:
        r"""Clears all chat messages from the memory."""
        self.storage.clear()

View File

@@ -1,103 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import List, Optional
from camel.embeddings import BaseEmbedding, OpenAIEmbedding
from camel.memories.base import MemoryBlock
from camel.memories.records import ContextRecord, MemoryRecord
from camel.storages.vectordb_storages import (
BaseVectorStorage,
QdrantStorage,
VectorDBQuery,
VectorRecord,
)
class VectorDBBlock(MemoryBlock):
    r"""A :obj:`MemoryBlock` implementation that stores and retrieves
    information as vector embeddings within a vector database.

    Args:
        storage (Optional[BaseVectorStorage], optional): The storage mechanism
            for the vector database. Defaults to in-memory :obj:`Qdrant` if
            not provided. (default: :obj:`None`)
        embedding (Optional[BaseEmbedding], optional): Embedding mechanism to
            convert chat messages into vector representations. Defaults to
            :obj:`OpenAiEmbedding` if not provided. (default: :obj:`None`)
    """

    def __init__(
        self,
        storage: Optional[BaseVectorStorage] = None,
        embedding: Optional[BaseEmbedding] = None,
    ) -> None:
        self.embedding = embedding or OpenAIEmbedding()
        self.vector_dim = self.embedding.get_output_dim()
        self.storage = storage or QdrantStorage(vector_dim=self.vector_dim)

    def retrieve(
        self,
        keyword: str,
        limit: int = 3,
    ) -> List[ContextRecord]:
        r"""Query the vector database for records similar to ``keyword``.

        Args:
            keyword (str): This string will be converted into a vector
                representation to query the database.
            limit (int, optional): The maximum number of similar messages to
                retrieve. (default: :obj:`3`).

        Returns:
            List[ContextRecord]: Memory records ranked by similarity to the
                embedded keyword; hits without a payload are skipped.
        """
        embedded_keyword = self.embedding.embed(keyword)
        query_results = self.storage.query(
            VectorDBQuery(query_vector=embedded_keyword, top_k=limit)
        )
        retrieved: List[ContextRecord] = []
        for hit in query_results:
            if hit.record.payload is None:
                continue
            retrieved.append(
                ContextRecord(
                    memory_record=MemoryRecord.from_dict(hit.record.payload),
                    score=hit.similarity,
                )
            )
        return retrieved

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Embed the given chat records and persist them to the vector
        database.

        Args:
            records (List[MemoryRecord]): Memory records to be added to the
                memory.
        """
        vector_records = []
        for memory_record in records:
            vector_records.append(
                VectorRecord(
                    vector=self.embedding.embed(memory_record.message.content),
                    payload=memory_record.to_dict(),
                    id=str(memory_record.uuid),
                )
            )
        self.storage.add(vector_records)

    def clear(self) -> None:
        r"""Removes all records from the vector database memory."""
        self.storage.clear()

View File

@@ -1,19 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .score_based import ScoreBasedContextCreator
__all__ = [
'ScoreBasedContextCreator',
]

View File

@@ -1,142 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import List, Tuple
from pydantic import BaseModel
from camel.memories.base import BaseContextCreator
from camel.memories.records import ContextRecord
from camel.messages import OpenAIMessage
from camel.utils import BaseTokenCounter
class _ContextUnit(BaseModel):
    r"""Internal bookkeeping unit pairing a context record with its original
    position and its pre-computed token cost."""

    idx: int  # Position in the input record list; used to restore order.
    record: ContextRecord
    num_tokens: int  # Token count of the record's OpenAI message.
class ScoreBasedContextCreator(BaseContextCreator):
    r"""A default context creation strategy, inheriting from
    :obj:`BaseContextCreator`.

    Builds a conversational context from chat history records while keeping
    the total token count within a configured limit. When the limit would be
    exceeded, the lowest-scored messages are discarded first.

    Args:
        token_counter (BaseTokenCounter): An instance responsible for counting
            tokens in a message.
        token_limit (int): The maximum number of tokens allowed in the
            generated context.
    """

    def __init__(
        self, token_counter: BaseTokenCounter, token_limit: int
    ) -> None:
        self._token_counter = token_counter
        self._token_limit = token_limit

    @property
    def token_counter(self) -> BaseTokenCounter:
        return self._token_counter

    @property
    def token_limit(self) -> int:
        return self._token_limit

    def create_context(
        self,
        records: List[ContextRecord],
    ) -> Tuple[List[OpenAIMessage], int]:
        r"""Creates a conversational context from chat history while
        respecting token limits.

        Deduplicates the provided records, then prunes the lowest-scored
        messages until the remaining context fits within the limit.

        Args:
            records (List[ContextRecord]): A list of message records from
                which to generate the context.

        Returns:
            Tuple[List[OpenAIMessage], int]: A tuple containing the
                constructed context in OpenAIMessage format and the total
                token count.

        Raises:
            RuntimeError: If it's impossible to create a valid context
                without exceeding the token limit.
        """
        # Deduplicate by UUID, keeping the first occurrence and recording
        # each record's original position and token cost.
        seen_uuids = set()
        units: List[_ContextUnit] = []
        for position, record in enumerate(records):
            record_uuid = record.memory_record.uuid
            if record_uuid in seen_uuids:
                continue
            seen_uuids.add(record_uuid)
            message_tokens = self.token_counter.count_tokens_from_messages(
                [record.memory_record.to_openai_message()]
            )
            units.append(
                _ContextUnit(
                    idx=position,
                    record=record,
                    num_tokens=message_tokens,
                )
            )

        # TODO: optimize the process, may give information back to memory
        # Everything fits: return the full context untouched.
        remaining_tokens = sum(unit.num_tokens for unit in units)
        if remaining_tokens <= self.token_limit:
            return self._create_output(units)

        # Walk the units from lowest to highest score, dropping them until
        # the remaining total fits within the limit.
        units = sorted(units, key=lambda unit: unit.record.score)
        cut_index = None
        for index, unit in enumerate(units):
            if unit.record.score == 1:
                # Score-1 records (e.g. system messages) must never be
                # pruned, so the limit cannot be satisfied.
                raise RuntimeError(
                    "Cannot create context: exceed token limit.",
                    remaining_tokens,
                )
            remaining_tokens -= unit.num_tokens
            if remaining_tokens <= self.token_limit:
                cut_index = index
                break
        if cut_index is None:
            raise RuntimeError(
                "Cannot create context: exceed token limit.",
                remaining_tokens,
            )
        return self._create_output(units[cut_index + 1 :])

    def _create_output(
        self, context_units: List[_ContextUnit]
    ) -> Tuple[List[OpenAIMessage], int]:
        r"""Restore the original record order and emit the OpenAI messages
        together with their total token count."""
        ordered = sorted(context_units, key=lambda unit: unit.idx)
        messages = [
            unit.record.memory_record.to_openai_message() for unit in ordered
        ]
        return messages, sum(unit.num_tokens for unit in ordered)

View File

@@ -1,95 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from dataclasses import asdict
from typing import Any, ClassVar, Dict
from uuid import UUID, uuid4
from pydantic import BaseModel, ConfigDict, Field
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
from camel.types import OpenAIBackendRole
class MemoryRecord(BaseModel):
    r"""The basic message storing unit in the CAMEL memory system.

    Attributes:
        message (BaseMessage): The main content of the record.
        role_at_backend (OpenAIBackendRole): An enumeration value representing
            the role this message played at the OpenAI backend. Note that this
            value is different from the :obj:`RoleType` used in the CAMEL role
            playing system.
        uuid (UUID, optional): A universally unique identifier for this record.
            This is used to uniquely identify this record in the memory system.
            If not given, it will be assigned with a random UUID.
        extra_info (Dict[str, str], optional): A dictionary of additional
            key-value pairs that provide more information. If not given, it
            will be an empty `Dict`.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    message: BaseMessage
    role_at_backend: OpenAIBackendRole
    uuid: UUID = Field(default_factory=uuid4)
    extra_info: Dict[str, str] = Field(default_factory=dict)

    # Maps serialized class names back to their message classes so that
    # `from_dict` can rebuild the right type.
    _MESSAGE_TYPES: ClassVar[dict] = {
        "BaseMessage": BaseMessage,
        "FunctionCallingMessage": FunctionCallingMessage,
    }

    @classmethod
    def from_dict(cls, record_dict: Dict[str, Any]) -> "MemoryRecord":
        r"""Reconstruct a :obj:`MemoryRecord` from the input dict.

        Args:
            record_dict(Dict[str, Any]): A dict generated by :meth:`to_dict`.
        """
        message_payload: Dict = record_dict["message"].copy()
        class_name = message_payload.pop("__class__")
        reconstructed_message = cls._MESSAGE_TYPES[class_name](
            **message_payload
        )
        return cls(
            uuid=UUID(record_dict["uuid"]),
            message=reconstructed_message,
            role_at_backend=record_dict["role_at_backend"],
            extra_info=record_dict["extra_info"],
        )

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the :obj:`MemoryRecord` to a dict for serialization
        purposes.
        """
        message_dict = {
            "__class__": self.message.__class__.__name__,
            **asdict(self.message),
        }
        return {
            "uuid": str(self.uuid),
            "message": message_dict,
            "role_at_backend": self.role_at_backend,
            "extra_info": self.extra_info,
        }

    def to_openai_message(self) -> OpenAIMessage:
        r"""Converts the record to an :obj:`OpenAIMessage` object."""
        return self.message.to_openai_message(self.role_at_backend)
class ContextRecord(BaseModel):
    r"""The result of memory retrieving."""

    memory_record: MemoryRecord  # The underlying record fetched from memory.
    score: float  # Retrieval score used by context creators when pruning.

View File

@@ -1,63 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Union
from camel.types import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from .conversion import (
AlpacaItem,
HermesFunctionFormatter,
ShareGPTMessage,
)
from .conversion.conversation_models import (
ShareGPTConversation,
)
from .conversion.sharegpt.function_call_formatter import (
FunctionCallFormatter,
)
# Type aliases mapping CAMEL message names onto the OpenAI chat-completion
# parameter types imported above.
OpenAISystemMessage = ChatCompletionSystemMessageParam
OpenAIAssistantMessage = Union[
    ChatCompletionAssistantMessageParam,
    ChatCompletionToolMessageParam,
]
OpenAIUserMessage = ChatCompletionUserMessageParam
OpenAIToolMessageParam = ChatCompletionToolMessageParam
OpenAIMessage = ChatCompletionMessageParam

# Imported after the aliases above because these modules import the aliases
# from `camel.messages` (hence the E402 suppressions).
from .base import BaseMessage  # noqa: E402
from .func_message import FunctionCallingMessage  # noqa: E402

__all__ = [
    'OpenAISystemMessage',
    'OpenAIAssistantMessage',
    'OpenAIUserMessage',
    'OpenAIToolMessageParam',
    'OpenAIMessage',
    'FunctionCallFormatter',
    'HermesFunctionFormatter',
    'ShareGPTConversation',
    'ShareGPTMessage',
    'BaseMessage',
    'FunctionCallingMessage',
    'AlpacaItem',
]

View File

@@ -1,541 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import base64
import io
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import numpy as np
from PIL import Image
from pydantic import BaseModel
from camel.messages import (
FunctionCallFormatter,
HermesFunctionFormatter,
OpenAIAssistantMessage,
OpenAIMessage,
OpenAISystemMessage,
OpenAIUserMessage,
)
from camel.messages.conversion import ShareGPTMessage
from camel.prompts import CodePrompt, TextPrompt
from camel.types import (
OpenAIBackendRole,
OpenAIImageType,
OpenAIVisionDetailType,
RoleType,
)
from camel.utils import Constants
@dataclass
class BaseMessage:
    r"""Base class for message objects used in CAMEL chat system.

    Args:
        role_name (str): The name of the user or assistant role.
        role_type (RoleType): The type of role, either :obj:`RoleType.
            ASSISTANT` or :obj:`RoleType.USER`.
        meta_dict (Optional[Dict[str, str]]): Additional metadata dictionary
            for the message.
        content (str): The content of the message.
        video_bytes (Optional[bytes]): Optional bytes of a video associated
            with the message. (default: :obj:`None`)
        image_list (Optional[List[Image.Image]]): Optional list of PIL Image
            objects associated with the message. (default: :obj:`None`)
        image_detail (Literal["auto", "low", "high"]): Detail level of the
            images associated with the message. (default: :obj:`auto`)
        video_detail (Literal["auto", "low", "high"]): Detail level of the
            videos associated with the message. (default: :obj:`low`)
        parsed (Optional[Union[Type[BaseModel], dict]]): Optional object which
            is parsed from the content. (default: :obj:`None`)
    """

    role_name: str
    role_type: RoleType
    meta_dict: Optional[Dict[str, Any]]
    content: str
    video_bytes: Optional[bytes] = None
    image_list: Optional[List[Image.Image]] = None
    image_detail: Literal["auto", "low", "high"] = "auto"
    video_detail: Literal["auto", "low", "high"] = "low"
    parsed: Optional[Union[Type[BaseModel], dict]] = None

    @classmethod
    def make_user_message(
        cls,
        role_name: str,
        content: str,
        meta_dict: Optional[Dict[str, str]] = None,
        video_bytes: Optional[bytes] = None,
        image_list: Optional[List[Image.Image]] = None,
        image_detail: Union[
            OpenAIVisionDetailType, str
        ] = OpenAIVisionDetailType.AUTO,
        video_detail: Union[
            OpenAIVisionDetailType, str
        ] = OpenAIVisionDetailType.LOW,
    ) -> "BaseMessage":
        r"""Create a new user message.

        Args:
            role_name (str): The name of the user role.
            content (str): The content of the message.
            meta_dict (Optional[Dict[str, str]]): Additional metadata
                dictionary for the message.
            video_bytes (Optional[bytes]): Optional bytes of a video
                associated with the message.
            image_list (Optional[List[Image.Image]]): Optional list of PIL
                Image objects associated with the message.
            image_detail (Union[OpenAIVisionDetailType, str]): Detail level of
                the images associated with the message.
            video_detail (Union[OpenAIVisionDetailType, str]): Detail level of
                the videos associated with the message.

        Returns:
            BaseMessage: The new user message.
        """
        return cls(
            role_name,
            RoleType.USER,
            meta_dict,
            content,
            video_bytes,
            image_list,
            OpenAIVisionDetailType(image_detail).value,
            OpenAIVisionDetailType(video_detail).value,
        )

    @classmethod
    def make_assistant_message(
        cls,
        role_name: str,
        content: str,
        meta_dict: Optional[Dict[str, str]] = None,
        video_bytes: Optional[bytes] = None,
        image_list: Optional[List[Image.Image]] = None,
        image_detail: Union[
            OpenAIVisionDetailType, str
        ] = OpenAIVisionDetailType.AUTO,
        video_detail: Union[
            OpenAIVisionDetailType, str
        ] = OpenAIVisionDetailType.LOW,
    ) -> "BaseMessage":
        r"""Create a new assistant message.

        Args:
            role_name (str): The name of the assistant role.
            content (str): The content of the message.
            meta_dict (Optional[Dict[str, str]]): Additional metadata
                dictionary for the message.
            video_bytes (Optional[bytes]): Optional bytes of a video
                associated with the message.
            image_list (Optional[List[Image.Image]]): Optional list of PIL
                Image objects associated with the message.
            image_detail (Union[OpenAIVisionDetailType, str]): Detail level of
                the images associated with the message.
            video_detail (Union[OpenAIVisionDetailType, str]): Detail level of
                the videos associated with the message.

        Returns:
            BaseMessage: The new assistant message.
        """
        return cls(
            role_name,
            RoleType.ASSISTANT,
            meta_dict,
            content,
            video_bytes,
            image_list,
            OpenAIVisionDetailType(image_detail).value,
            OpenAIVisionDetailType(video_detail).value,
        )

    def create_new_instance(self, content: str) -> "BaseMessage":
        r"""Create a new instance of the :obj:`BaseMessage` with updated
        content.

        Args:
            content (str): The new content value.

        Returns:
            BaseMessage: The new instance of :obj:`BaseMessage`.
        """
        return self.__class__(
            role_name=self.role_name,
            role_type=self.role_type,
            meta_dict=self.meta_dict,
            content=content,
        )

    def __add__(self, other: Any) -> Union["BaseMessage", Any]:
        r"""Addition operator override for :obj:`BaseMessage`.

        Args:
            other (Any): The value to be added with.

        Returns:
            Union[BaseMessage, Any]: The result of the addition.
        """
        if isinstance(other, BaseMessage):
            combined_content = self.content.__add__(other.content)
        elif isinstance(other, str):
            combined_content = self.content.__add__(other)
        else:
            raise TypeError(
                f"Unsupported operand type(s) for +: '{type(self)}' and "
                f"'{type(other)}'"
            )
        return self.create_new_instance(combined_content)

    def __mul__(self, other: Any) -> Union["BaseMessage", Any]:
        r"""Multiplication operator override for :obj:`BaseMessage`.

        Args:
            other (Any): The value to be multiplied with.

        Returns:
            Union[BaseMessage, Any]: The result of the multiplication.
        """
        if isinstance(other, int):
            multiplied_content = self.content.__mul__(other)
            return self.create_new_instance(multiplied_content)
        else:
            raise TypeError(
                f"Unsupported operand type(s) for *: '{type(self)}' and "
                f"'{type(other)}'"
            )

    def __len__(self) -> int:
        r"""Length operator override for :obj:`BaseMessage`.

        Returns:
            int: The length of the content.
        """
        return len(self.content)

    def __contains__(self, item: str) -> bool:
        r"""Contains operator override for :obj:`BaseMessage`.

        Args:
            item (str): The item to check for containment.

        Returns:
            bool: :obj:`True` if the item is contained in the content,
                :obj:`False` otherwise.
        """
        return item in self.content

    def extract_text_and_code_prompts(
        self,
    ) -> Tuple[List[TextPrompt], List[CodePrompt]]:
        r"""Extract text and code prompts from the message content.

        An unterminated code fence (an opening ``` with no closing ```)
        is treated as running to the end of the content instead of raising
        an :obj:`IndexError`.

        Returns:
            Tuple[List[TextPrompt], List[CodePrompt]]: A tuple containing a
                list of text prompts and a list of code prompts extracted
                from the content.
        """
        text_prompts: List[TextPrompt] = []
        code_prompts: List[CodePrompt] = []

        lines = self.content.split("\n")
        idx = 0
        start_idx = 0
        while idx < len(lines):
            # Scan forward to the next opening code fence.
            while idx < len(lines) and (
                not lines[idx].lstrip().startswith("```")
            ):
                idx += 1
            text = "\n".join(lines[start_idx:idx]).strip()
            text_prompts.append(TextPrompt(text))

            if idx >= len(lines):
                break

            # The text after ``` on the fence line is the language tag.
            code_type = lines[idx].strip()[3:].strip()
            idx += 1
            start_idx = idx
            # Bug fix: bound the scan for the closing fence so that an
            # unterminated code block no longer raises IndexError; the
            # remaining lines are captured as code.
            while idx < len(lines) and (
                not lines[idx].lstrip().startswith("```")
            ):
                idx += 1
            code = "\n".join(lines[start_idx:idx]).strip()
            code_prompts.append(CodePrompt(code, code_type=code_type))

            idx += 1
            start_idx = idx

        return text_prompts, code_prompts

    @classmethod
    def from_sharegpt(
        cls,
        message: ShareGPTMessage,
        function_format: Optional[FunctionCallFormatter[Any, Any]] = None,
        role_mapping=None,
    ) -> "BaseMessage":
        r"""Convert ShareGPT message to BaseMessage or FunctionCallingMessage.
        Note tool calls and responses have an 'assistant' role in CAMEL

        Args:
            message (ShareGPTMessage): ShareGPT message to convert.
            function_format (FunctionCallFormatter, optional): Function call
                formatter to use. (default: :obj:`HermesFunctionFormatter()`.
            role_mapping (Dict[str, List[str, RoleType]], optional): Role
                mapping to use. Defaults to a CAMEL specific mapping.

        Returns:
            BaseMessage: Converted message.
        """
        # Local import avoids a circular import with camel.messages.
        from camel.messages import FunctionCallingMessage

        if role_mapping is None:
            role_mapping = {
                "system": ["system", RoleType.USER],
                "human": ["user", RoleType.USER],
                "gpt": ["assistant", RoleType.ASSISTANT],
                "tool": ["assistant", RoleType.ASSISTANT],
            }
        role_name, role_type = role_mapping[message.from_]

        if function_format is None:
            function_format = HermesFunctionFormatter()

        # Check if this is a function-related message
        if message.from_ == "gpt":
            func_info = function_format.extract_tool_calls(message.value)
            if (
                func_info and len(func_info) == 1
            ):  # TODO: Handle multiple tool calls
                # Including cleaned content is useful to
                # remind consumers of non-considered content
                clean_content = re.sub(
                    r"<tool_call>.*?</tool_call>",
                    "",
                    message.value,
                    flags=re.DOTALL,
                ).strip()
                return FunctionCallingMessage(
                    role_name=role_name,
                    role_type=role_type,
                    meta_dict=None,
                    content=clean_content,
                    func_name=func_info[0].__dict__["name"],
                    args=func_info[0].__dict__["arguments"],
                )
        elif message.from_ == "tool":
            func_r_info = function_format.extract_tool_response(message.value)
            if func_r_info:
                return FunctionCallingMessage(
                    role_name=role_name,
                    role_type=role_type,
                    meta_dict=None,
                    content="",
                    func_name=func_r_info.__dict__["name"],
                    result=func_r_info.__dict__["content"],
                )

        # Regular message
        return cls(
            role_name=role_name,
            role_type=role_type,
            meta_dict=None,
            content=message.value,
        )

    def to_sharegpt(
        self,
        function_format: Optional[FunctionCallFormatter] = None,
    ) -> ShareGPTMessage:
        r"""Convert BaseMessage to ShareGPT message

        Args:
            function_format (FunctionCallFormatter): Function call formatter
                to use. Defaults to Hermes.
        """
        if function_format is None:
            function_format = HermesFunctionFormatter()

        # Convert role type to ShareGPT 'from' field
        if self.role_type == RoleType.USER:
            from_ = "system" if self.role_name == "system" else "human"
        else:  # RoleType.ASSISTANT
            from_ = "gpt"

        # Function conversion code in FunctionCallingMessage
        return ShareGPTMessage(from_=from_, value=self.content)  # type: ignore[call-arg]

    def to_openai_message(
        self,
        role_at_backend: OpenAIBackendRole,
    ) -> OpenAIMessage:
        r"""Converts the message to an :obj:`OpenAIMessage` object.

        Args:
            role_at_backend (OpenAIBackendRole): The role of the message in
                OpenAI chat system.

        Returns:
            OpenAIMessage: The converted :obj:`OpenAIMessage` object.
        """
        if role_at_backend == OpenAIBackendRole.SYSTEM:
            return self.to_openai_system_message()
        elif role_at_backend == OpenAIBackendRole.USER:
            return self.to_openai_user_message()
        elif role_at_backend == OpenAIBackendRole.ASSISTANT:
            return self.to_openai_assistant_message()
        else:
            raise ValueError(f"Unsupported role: {role_at_backend}.")

    def to_openai_system_message(self) -> OpenAISystemMessage:
        r"""Converts the message to an :obj:`OpenAISystemMessage` object.

        Returns:
            OpenAISystemMessage: The converted :obj:`OpenAISystemMessage`
                object.
        """
        return {"role": "system", "content": self.content}

    def to_openai_user_message(self) -> OpenAIUserMessage:
        r"""Converts the message to an :obj:`OpenAIUserMessage` object.

        Returns:
            OpenAIUserMessage: The converted :obj:`OpenAIUserMessage` object.
        """
        hybrid_content: List[Any] = []
        hybrid_content.append(
            {
                "type": "text",
                "text": self.content,
            }
        )
        if self.image_list and len(self.image_list) > 0:
            for image in self.image_list:
                if image.format is None:
                    raise ValueError(
                        f"Image's `format` is `None`, please "
                        f"transform the `PIL.Image.Image` to one of "
                        f"following supported formats, such as "
                        f"{list(OpenAIImageType)}"
                    )

                image_type: str = image.format.lower()
                # NOTE(review): membership test on the enum class assumes
                # `OpenAIImageType` supports value containment checks for
                # plain strings — confirm for targeted Python versions.
                if image_type not in OpenAIImageType:
                    raise ValueError(
                        f"Image type {image.format} "
                        f"is not supported by OpenAI vision model"
                    )
                with io.BytesIO() as buffer:
                    image.save(fp=buffer, format=image.format)
                    encoded_image = base64.b64encode(buffer.getvalue()).decode(
                        "utf-8"
                    )
                image_prefix = f"data:image/{image_type};base64,"
                hybrid_content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"{image_prefix}{encoded_image}",
                            "detail": self.image_detail,
                        },
                    }
                )

        if self.video_bytes:
            import imageio.v3 as iio

            base64_frames: List[str] = []
            frame_count = 0
            # read video bytes
            video = iio.imiter(
                self.video_bytes, plugin=Constants.VIDEO_DEFAULT_PLUG_PYAV
            )
            for frame in video:
                frame_count += 1
                # Only sample every N-th frame to bound payload size.
                if (
                    frame_count % Constants.VIDEO_IMAGE_EXTRACTION_INTERVAL
                    == 0
                ):
                    # convert frame to numpy array
                    frame_array = np.asarray(frame)
                    frame_image = Image.fromarray(frame_array)

                    # Get the dimensions of the frame
                    width, height = frame_image.size

                    # resize the frame to the default image size,
                    # preserving the aspect ratio
                    new_width = Constants.VIDEO_DEFAULT_IMAGE_SIZE
                    aspect_ratio = width / height
                    new_height = int(new_width / aspect_ratio)
                    resized_img = frame_image.resize((new_width, new_height))

                    # encode the image to base64
                    with io.BytesIO() as buffer:
                        image_format = OpenAIImageType.JPEG.value
                        image_format = image_format.upper()
                        resized_img.save(fp=buffer, format=image_format)
                        encoded_image = base64.b64encode(
                            buffer.getvalue()
                        ).decode("utf-8")

                    base64_frames.append(encoded_image)

            for encoded_image in base64_frames:
                item = {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encoded_image}",
                        "detail": self.video_detail,
                    },
                }

                hybrid_content.append(item)

        if len(hybrid_content) > 1:
            return {
                "role": "user",
                "content": hybrid_content,
            }
        # This return just for str message
        else:
            return {
                "role": "user",
                "content": self.content,
            }

    def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
        r"""Converts the message to an :obj:`OpenAIAssistantMessage` object.

        Returns:
            OpenAIAssistantMessage: The converted :obj:`OpenAIAssistantMessage`
                object.
        """
        return {"role": "assistant", "content": self.content}

    def to_dict(self) -> Dict:
        r"""Converts the message to a dictionary.

        Returns:
            dict: The converted dictionary.
        """
        # meta_dict entries are flattened into the top-level dictionary.
        return {
            "role_name": self.role_name,
            "role_type": self.role_type.name,
            **(self.meta_dict or {}),
            "content": self.content,
        }

View File

@@ -1,31 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .alpaca import AlpacaItem
from .conversation_models import (
ShareGPTConversation,
ShareGPTMessage,
ToolCall,
ToolResponse,
)
from .sharegpt import HermesFunctionFormatter
__all__ = [
'ShareGPTMessage',
'ShareGPTConversation',
'HermesFunctionFormatter',
'AlpacaItem',
'ToolCall',
'ToolResponse',
]

View File

@@ -1,122 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from pydantic import BaseModel, Field, field_validator
class AlpacaItem(BaseModel):
    r"""Represents an instruction-response item in the Alpaca format.

    Appropriate for both cases where input field is empty, or populated.
    Provides parsing from string format using the class method from_string().

    Args:
        instruction (str): The instruction/question/prompt
        input (str): Input context or examples (put empty string if none)
        output (str): The response/answer to the instruction
    """

    instruction: str = Field(description="The instruction/question/prompt")
    input: str = Field(
        description="Optional context or input for the task."
        " For example, when the instruction is \"Summarize the "
        "following article\", the input is the article."
    )
    output: str = Field(description="The response/answer to the instruction")

    # NOTE(review): only `instruction` and `output` are checked for
    # section markers; `input` is not — confirm that is intentional.
    @field_validator('instruction', 'output')
    def no_section_markers(cls, value: str) -> str:
        r"""Ensures fields don't contain section markers like '###
        Response:'

        Raises:
            ValueError: If the value embeds an Alpaca section header,
                which would corrupt the round-trip through to_string().

        Returns:
            str: The value with surrounding whitespace stripped.
        """
        if (
            '### Response' in value
            or '### Instruction' in value
            or '### Input' in value
        ):
            raise ValueError("Field cannot contain section markers")
        return value.strip()

    @classmethod
    def from_string(cls, text: str) -> "AlpacaItem":
        r"""Creates an AlpacaItem from a formatted string.

        Args:
            text: String in either of these formats:
                With input:
                ### Instruction:
                {instruction}
                ### Input:
                {input}
                ### Response:
                {response}

                Without input:
                ### Instruction:
                {instruction}
                ### Response:
                {response}

        Returns:
            AlpacaItem: Parsed instance

        Raises:
            ValueError: text doesn't match expected format or sections missing
        """
        # Strip and standardize newlines
        text = text.strip().replace('\r\n', '\n')
        # Try to extract sections using regex; each section runs until the
        # next '###' header or end of string (DOTALL spans newlines).
        instruction_match = re.search(
            r'###\s*Instruction:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )
        input_match = re.search(
            r'###\s*Input:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )
        response_match = re.search(
            r'###\s*Response:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )
        # Input is optional; instruction and response are mandatory.
        if not instruction_match or not response_match:
            raise ValueError(
                "Text must contain '### Instruction:'"
                " and '### Response:' sections"
            )
        return cls(
            instruction=instruction_match.group(1).strip(),
            input=input_match.group(1).strip() if input_match else "",
            output=response_match.group(1).strip(),
        )

    def to_string(self) -> str:
        r"""Converts the AlpacaItem to its string representation.

        The Input section is always emitted, even when ``input`` is the
        empty string.

        Returns:
            str: Formatted string representation with sections markers
        """
        return "\n".join(
            [
                "### Instruction:",
                self.instruction,
                "",
                "### Input:",
                self.input,
                "",
                "### Response:",
                self.output,
            ]
        )

View File

@@ -1,178 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
from typing import Any, Dict, List, Literal
from pydantic import (
BaseModel,
Field,
RootModel,
field_validator,
model_validator,
)
class ShareGPTMessage(BaseModel):
    r"""A single message in ShareGPT format with enhanced validation.

    The sender role is stored as ``from_`` in Python (``from`` is a
    keyword) but serializes under the JSON key ``"from"`` via the alias.
    """

    from_: Literal["human", "gpt", "system", "tool"] = Field(
        alias="from", description="The role of the message sender"
    )
    value: str = Field(
        min_length=0,
        max_length=100000,
        description="The content of the message",
    )

    # populate_by_name allows construction with either `from_` or `from`;
    # extra="forbid" rejects unknown keys at validation time.
    model_config = {
        "populate_by_name": True,
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {"from": "human", "value": "What's the weather like today?"}
            ]
        },
    }
class ShareGPTConversation(RootModel):
    r"""A full conversation in ShareGPT format with validation.

    The root value is the ordered list of :class:`ShareGPTMessage`
    objects; iteration yields the messages directly.
    """

    root: List[ShareGPTMessage]

    @model_validator(mode='after')
    def validate_conversation_flow(self) -> 'ShareGPTConversation':
        r"""Validate the conversation follows logical message order.

        Raises:
            ValueError: If the conversation is empty, does not open with
                a system or human message, if a tool message does not
                follow a gpt message containing a ``<tool_call>`` marker,
                or if a gpt message does not follow a human/tool message.
        """
        messages = self.root
        if not messages:
            raise ValueError("Conversation cannot be empty")
        if messages[0].from_ not in ("system", "human"):
            raise ValueError(
                "Conversation must start with either system or human message"
            )
        # Validate message sequence: each message is checked against its
        # immediate predecessor only.
        for i in range(1, len(messages)):
            curr, prev = messages[i], messages[i - 1]
            if curr.from_ == "tool":
                # A tool result is only valid directly after an assistant
                # message that actually issued a tool call.
                if prev.from_ != "gpt" or "<tool_call>" not in prev.value:
                    raise ValueError(
                        f"Tool response at position {i} "
                        f"must follow an gpt message with a tool call"
                    )
            if curr.from_ == "gpt" and prev.from_ not in (
                "human",
                "tool",
            ):
                raise ValueError(
                    f"Assistant message at position {i} "
                    f"must follow a human or tool message"
                )
        return self

    def model_dump(self, **kwargs):
        # NOTE(review): overrides pydantic's model_dump to return the raw
        # message-object list (kwargs are ignored, messages are not
        # serialized to dicts) — confirm callers rely on this shape.
        return self.root

    def __iter__(self):
        # Iterate directly over the underlying message list.
        return iter(self.root)
class ToolCall(BaseModel):
    r"""Represents a single tool/function call with validation."""

    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool to call",
    )
    arguments: Dict[str, Any] = Field(
        description="The arguments to pass to the tool"
    )

    @field_validator('arguments')
    @classmethod
    def validate_arguments(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        r"""Validate argument structure and content.

        Raises:
            ValueError: If the arguments cannot be serialized to JSON.

        Returns:
            Dict[str, Any]: The unchanged, validated arguments.
        """
        # Try to serialize arguments to ensure they're JSON-compatible
        try:
            json.dumps(v)
        except (TypeError, ValueError):
            raise ValueError("Arguments must be JSON-serializable")
        return v

    # extra="forbid" rejects keys other than name/arguments.
    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "arguments": {"city": "London", "units": "celsius"},
                }
            ]
        },
    }
class ToolResponse(BaseModel):
    r"""Represents a tool/function response with validation. This is a
    base class and default implementation for tool responses, for the purpose
    of converting between different formats.
    """

    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool that was called",
    )
    content: Any = Field(
        description="The response content from the tool."
        " Must be JSON serializable literal or object"
    )

    @field_validator('content')
    @classmethod
    def validate_content(cls, v: Any) -> Any:
        r"""Validate response content structure.

        Raises:
            ValueError: If the content cannot be serialized to JSON.

        Returns:
            Any: The unchanged, validated content.
        """
        # Ensure content is JSON-serializable
        try:
            json.dumps(v)
        except (TypeError, ValueError):
            raise ValueError("Response content must be JSON-serializable")
        return v

    # extra="forbid" rejects keys other than name/content.
    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "content": {
                        "temperature": 20,
                        "conditions": "sunny",
                        "humidity": 65,
                    },
                }
            ]
        },
    }

View File

@@ -1,20 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .hermes import HermesFunctionFormatter
__all__ = [
'HermesFunctionFormatter',
]

View File

@@ -1,49 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar
from camel.messages.conversion import (
ToolCall,
ToolResponse,
)
# Covariant type variables bounding the concrete call/response model types
# a formatter implementation produces.
CallT = TypeVar('CallT', bound=ToolCall, covariant=True)
ResponseT = TypeVar('ResponseT', bound=ToolResponse, covariant=True)
class FunctionCallFormatter(ABC, Generic[CallT, ResponseT]):
    r"""Abstract base class for function calling formats.

    Concrete subclasses define how tool calls and tool responses are
    embedded in, and recovered from, plain message strings.
    """

    @abstractmethod
    def extract_tool_calls(self, message: str) -> List[CallT]:
        r"""Extract function call info from a message string.

        Args:
            message (str): Message text possibly containing tool calls.

        Returns:
            List[CallT]: All tool calls found in the message.
        """
        pass

    @abstractmethod
    def extract_tool_response(self, message: str) -> Optional[ResponseT]:
        r"""Extract function response info from a message string.

        Args:
            message (str): Message text possibly containing a tool
                response.

        Returns:
            Optional[ResponseT]: The parsed response, or None if absent.
        """
        pass

    @abstractmethod
    def format_tool_call(
        self, content: str, func_name: str, args: Dict[str, Any]
    ) -> str:
        r"""Format a function call into a message string.

        Args:
            content (str): Free-text content accompanying the call.
            func_name (str): Name of the function being invoked.
            args (Dict[str, Any]): Arguments for the function.

        Returns:
            str: The rendered message string.
        """
        pass

    @abstractmethod
    def format_tool_response(self, func_name: str, result: Any) -> str:
        r"""Format a function response into a message string.

        Args:
            func_name (str): Name of the function that produced result.
            result (Any): The function's return value.

        Returns:
            str: The rendered message string.
        """
        pass

View File

@@ -1,19 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .hermes_function_formatter import HermesFunctionFormatter
__all__ = [
'HermesFunctionFormatter',
]

View File

@@ -1,128 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import re
from typing import Any, Dict, List, Optional
from camel.messages.conversion import (
ToolCall,
ToolResponse,
)
from camel.messages.conversion.sharegpt.function_call_formatter import (
FunctionCallFormatter,
)
class HermesToolResponse(ToolResponse):
    r"""Represents a single tool/function response with validation.

    Inherits all fields and validation from :class:`ToolResponse`.
    """

    pass
class HermesToolCall(ToolCall):
    r"""Represents a single tool/function call with validation.

    Inherits all fields and validation from :class:`ToolCall`.
    """

    pass
class HermesFunctionFormatter(
    FunctionCallFormatter[HermesToolCall, HermesToolResponse]
):
    r"""Hermes-style function calling format implementation with
    validation.

    Tool calls are wrapped in ``<tool_call>...</tool_call>`` tags and
    tool responses in ``<tool_response>...</tool_response>`` tags, each
    containing a dict-literal payload.
    """

    def extract_tool_calls(self, message: str) -> List[HermesToolCall]:
        r"""Extracts all tool calls from the provided message string.

        Args:
            message (str): The input message string containing potential
                tool calls.

        Returns:
            List[HermesToolCall]: A list of parsed HermesToolCall objects.
        """
        pattern = r"<tool_call>\s*({.*?})\s*</tool_call>"
        parsed_calls: List[HermesToolCall] = []
        # findall returns the single capture group (the payload) for
        # every tagged region; DOTALL lets payloads span newlines.
        for raw_payload in re.findall(pattern, message, re.DOTALL):
            try:
                # Payloads are Python-repr style; swap quotes for JSON.
                payload = json.loads(raw_payload.replace("'", '"'))
                parsed_calls.append(HermesToolCall.model_validate(payload))
            except Exception as e:
                # Malformed calls are skipped rather than aborting the scan.
                print(f"Warning: Failed to parse tool call: {e}")
        return parsed_calls

    def extract_tool_response(
        self, message: str
    ) -> Optional[HermesToolResponse]:
        r"""Extracts a single tool response from the provided message
        string.

        Args:
            message (str): The input message string containing a potential
                tool response.

        Returns:
            Optional[HermesToolResponse]: A parsed HermesToolResponse
                object, or None if no valid response is found.
        """
        found = re.search(
            r"<tool_response>\s*({.*?})\s*</tool_response>",
            message,
            re.DOTALL,
        )
        if found is None:
            return None
        try:
            payload = json.loads(found.group(1).replace("'", '"'))
            return HermesToolResponse.model_validate(payload)
        except Exception as e:
            print(f"Warning: Failed to parse tool response: {e}")
            return None

    def format_tool_call(
        self, content: str, func_name: str, args: Dict[str, Any]
    ) -> str:
        r"""Formats a tool call message with the given content, function
        name, and arguments.

        Args:
            content (str): The content or message to be included in the
                tool call.
            func_name (str): The name of the function being called.
            args (Dict[str, Any]): A dictionary of arguments to be passed
                to the function.

        Returns:
            str: A formatted string representing the tool call in Hermes
                format.
        """
        call_payload = {"name": func_name, "arguments": args}
        # The dict is embedded via its Python repr (single quotes), which
        # the extract side undoes with the quote replacement above.
        return f"{content}\n<tool_call>\n{call_payload}\n</tool_call>"

    def format_tool_response(self, func_name: str, result: Any) -> str:
        r"""Formats a tool response message with the given function name
        and result.

        Args:
            func_name (str): The name of the function whose result is
                being returned.
            result (Any): The result to be included in the tool response.

        Returns:
            str: A formatted string representing the tool response in
                Hermes format.
        """
        response_payload = {"name": func_name, "content": result}
        return f"<tool_response>\n{response_payload}\n</tool_response>"

View File

@@ -1,163 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional
from camel.messages import (
BaseMessage,
HermesFunctionFormatter,
OpenAIAssistantMessage,
OpenAIMessage,
OpenAIToolMessageParam,
)
from camel.messages.conversion import (
ShareGPTMessage,
ToolCall,
ToolResponse,
)
from camel.messages.conversion.sharegpt.function_call_formatter import (
FunctionCallFormatter,
)
from camel.types import OpenAIBackendRole
@dataclass
class FunctionCallingMessage(BaseMessage):
    r"""Class for message objects used specifically for
    function-related messages.

    A single instance represents either a function *call* (``result`` is
    None) or a function *response* (``result`` is set).

    Args:
        func_name (Optional[str]): The name of the function used.
            (default: :obj:`None`)
        args (Optional[Dict]): The dictionary of arguments passed to the
            function. (default: :obj:`None`)
        result (Optional[Any]): The result of function execution.
            (default: :obj:`None`)
        tool_call_id (Optional[str]): The ID of the tool call, if available.
            (default: :obj:`None`)
    """

    func_name: Optional[str] = None
    args: Optional[Dict] = None
    result: Optional[Any] = None
    tool_call_id: Optional[str] = None

    def to_openai_message(
        self,
        role_at_backend: OpenAIBackendRole,
    ) -> OpenAIMessage:
        r"""Converts the message to an :obj:`OpenAIMessage` object.

        Args:
            role_at_backend (OpenAIBackendRole): The role of the message in
                OpenAI chat system.

        Returns:
            OpenAIMessage: The converted :obj:`OpenAIMessage` object.

        Raises:
            ValueError: If the role is neither ASSISTANT nor FUNCTION.
        """
        if role_at_backend == OpenAIBackendRole.ASSISTANT:
            return self.to_openai_assistant_message()
        elif role_at_backend == OpenAIBackendRole.FUNCTION:
            return self.to_openai_tool_message()
        else:
            raise ValueError(f"Unsupported role: {role_at_backend}.")

    def to_sharegpt(
        self,
        function_format: Optional[
            FunctionCallFormatter[ToolCall, ToolResponse]
        ] = None,
    ) -> ShareGPTMessage:
        r"""Convert FunctionCallingMessage to ShareGPT message.

        Args:
            function_format (FunctionCallFormatter[ToolCall, ToolResponse],
                optional): The function formatter to use. Defaults to None,
                in which case a :class:`HermesFunctionFormatter` is used.

        Returns:
            ShareGPTMessage: A ``gpt`` message for a function call, or a
                ``tool`` message for a function response.
        """
        if function_format is None:
            function_format = HermesFunctionFormatter()

        # The role of the message is an unreliable indicator of whether
        # it is a function call or response, so use result
        if self.result is None:
            # This is a function call
            # TODO: split the incoming types to be more specific
            # and remove the type ignores
            content = function_format.format_tool_call(
                self.content or "",  # type: ignore[arg-type]
                self.func_name,  # type: ignore[arg-type]
                self.args,  # type: ignore[arg-type]
            )
            return ShareGPTMessage(from_="gpt", value=content)  # type: ignore[call-arg]
        else:
            # This is a function response
            # TODO: Allow for more flexible setting of tool role,
            # optionally to be the same as assistant messages
            content = function_format.format_tool_response(
                self.func_name,  # type: ignore[arg-type]
                self.result,  # type: ignore[arg-type]
            )
            return ShareGPTMessage(from_="tool", value=content)  # type: ignore[call-arg]

    def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
        r"""Converts the message to an :obj:`OpenAIAssistantMessage` object.

        Returns:
            OpenAIAssistantMessage: The converted :obj:`OpenAIAssistantMessage`
                object, carrying a single ``tool_calls`` entry.

        Raises:
            ValueError: If ``func_name`` or ``args`` is missing.
        """
        if (not self.func_name) or (self.args is None):
            raise ValueError(
                "Invalid request for converting into assistant message"
                " due to missing function name or arguments."
            )

        return {
            "role": "assistant",
            "content": self.content or "",
            "tool_calls": [
                {
                    # "null" is used as a placeholder when no ID is known.
                    "id": self.tool_call_id or "null",
                    "type": "function",
                    "function": {
                        "name": self.func_name,
                        "arguments": json.dumps(self.args),
                    },
                }
            ],
        }

    def to_openai_tool_message(self) -> OpenAIToolMessageParam:
        r"""Converts the message to an :obj:`OpenAIToolMessageParam` object
        with the role being "tool".

        Returns:
            OpenAIToolMessageParam: The converted
                :obj:`OpenAIToolMessageParam` object with its role being
                "tool".

        Raises:
            ValueError: If ``func_name`` is missing.
        """
        if not self.func_name:
            raise ValueError(
                "Invalid request for converting into function message"
                " due to missing function name."
            )

        # Results are always stringified for the tool message payload.
        result_content = str(self.result)

        return {
            "role": "tool",
            "content": result_content,
            "tool_call_id": self.tool_call_id or "null",
        }

View File

@@ -1,68 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .anthropic_model import AnthropicModel
from .azure_openai_model import AzureOpenAIModel
from .base_model import BaseModelBackend
from .cohere_model import CohereModel
from .deepseek_model import DeepSeekModel
from .gemini_model import GeminiModel
from .groq_model import GroqModel
from .litellm_model import LiteLLMModel
from .mistral_model import MistralModel
from .model_factory import ModelFactory
from .model_manager import ModelManager, ModelProcessingError
from .nemotron_model import NemotronModel
from .nvidia_model import NvidiaModel
from .ollama_model import OllamaModel
from .openai_audio_models import OpenAIAudioModels
from .openai_compatible_model import OpenAICompatibleModel
from .openai_model import OpenAIModel
from .qwen_model import QwenModel
from .reka_model import RekaModel
from .samba_model import SambaModel
from .stub_model import StubModel
from .togetherai_model import TogetherAIModel
from .vllm_model import VLLMModel
from .yi_model import YiModel
from .zhipuai_model import ZhipuAIModel
from .fish_audio_model import FishAudioModel
__all__ = [
'BaseModelBackend',
'OpenAIModel',
'AzureOpenAIModel',
'AnthropicModel',
'MistralModel',
'GroqModel',
'StubModel',
'ZhipuAIModel',
'CohereModel',
'ModelFactory',
'ModelManager',
'LiteLLMModel',
'OpenAIAudioModels',
'NemotronModel',
'NvidiaModel',
'OllamaModel',
'VLLMModel',
'GeminiModel',
'OpenAICompatibleModel',
'RekaModel',
'SambaModel',
'TogetherAIModel',
'YiModel',
'QwenModel',
'ModelProcessingError',
'DeepSeekModel',
]

View File

@@ -1,167 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Dict, List, Optional, Union
from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig
from camel.messages import OpenAIMessage
from camel.models.base_model import BaseModelBackend
from camel.types import ChatCompletion, ModelType
from camel.utils import (
AnthropicTokenCounter,
BaseTokenCounter,
api_keys_required,
dependencies_required,
)
class AnthropicModel(BaseModelBackend):
    r"""Anthropic API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created, one of CLAUDE_* series.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into Anthropic.messages.create(). If
            :obj:`None`, :obj:`AnthropicConfig().as_dict()` will be used.
            (default::obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the Anthropic service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the Anthropic service.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`AnthropicTokenCounter`
            will be used. (default: :obj:`None`)
    """

    @dependencies_required('anthropic')
    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        from anthropic import Anthropic

        if model_config_dict is None:
            model_config_dict = AnthropicConfig().as_dict()
        # Fall back to environment variables when credentials are not
        # passed explicitly.
        api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        url = url or os.environ.get("ANTHROPIC_API_BASE_URL")
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )
        self.client = Anthropic(api_key=self._api_key, base_url=self._url)

    def _convert_response_from_anthropic_to_openai(self, response):
        r"""Map an Anthropic Messages response onto OpenAI's ChatCompletion.

        Only the first content block of the response is kept; ``id`` and
        ``created`` are left as :obj:`None`.
        """
        # openai ^1.0.0 format, reference openai/types/chat/chat_completion.py
        obj = ChatCompletion.construct(
            id=None,
            choices=[
                dict(
                    index=0,
                    message={
                        "role": "assistant",
                        "content": response.content[0].text,
                    },
                    finish_reason=response.stop_reason,
                )
            ],
            created=None,
            model=response.model,
            object="chat.completion",
        )
        return obj

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = AnthropicTokenCounter()
        return self._token_counter

    def count_tokens_from_prompt(self, prompt: str) -> int:
        r"""Count the number of tokens from a prompt.

        Args:
            prompt (str): The prompt string.

        Returns:
            int: The number of tokens in the prompt.
        """
        # NOTE(review): `client.count_tokens` exists only in older anthropic
        # SDK versions — confirm against the pinned SDK release.
        return self.client.count_tokens(prompt)

    @api_keys_required("ANTHROPIC_API_KEY")
    def run(
        self,
        messages: List[OpenAIMessage],
    ):
        r"""Run inference of Anthropic chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            ChatCompletion: Response in the OpenAI API format.
        """
        from anthropic import NOT_GIVEN

        # Anthropic takes the system prompt as a separate argument; note
        # that pop(0) mutates the caller's `messages` list.
        if messages[0]["role"] == "system":
            sys_msg = str(messages.pop(0)["content"])
        else:
            sys_msg = NOT_GIVEN  # type: ignore[assignment]
        response = self.client.messages.create(
            model=self.model_type,
            system=sys_msg,
            messages=messages,  # type: ignore[arg-type]
            **self.model_config_dict,
        )

        # format response to openai format
        response = self._convert_response_from_anthropic_to_openai(response)

        return response

    def check_model_config(self):
        r"""Check whether the model configuration is valid for anthropic
        model backends.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to OpenAI API, or it does not contain
                :obj:`model_path` or :obj:`server_url`.
        """
        for param in self.model_config_dict:
            if param not in ANTHROPIC_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into Anthropic model backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get("stream", False)

View File

@@ -1,155 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Dict, List, Optional, Union
from openai import AzureOpenAI, Stream
from camel.configs import OPENAI_API_PARAMS, ChatGPTConfig
from camel.messages import OpenAIMessage
from camel.models.base_model import BaseModelBackend
from camel.types import (
ChatCompletion,
ChatCompletionChunk,
ModelType,
)
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
class AzureOpenAIModel(BaseModelBackend):
    r"""Azure OpenAI API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created, one of GPT_* series.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into:obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`ChatGPTConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the OpenAI service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the OpenAI service.
            (default: :obj:`None`)
        api_version (Optional[str], optional): The api version for the model.
            (default: :obj:`None`)
        azure_deployment_name (Optional[str], optional): The deployment name
            you chose when you deployed an azure model. (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter`
            will be used. (default: :obj:`None`)

    References:
        https://learn.microsoft.com/en-us/azure/ai-services/openai/
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
        api_version: Optional[str] = None,
        azure_deployment_name: Optional[str] = None,
    ) -> None:
        if model_config_dict is None:
            model_config_dict = ChatGPTConfig().as_dict()
        # Fall back to environment variables when settings are not passed
        # explicitly; api_version and deployment name are mandatory.
        api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
        url = url or os.environ.get("AZURE_OPENAI_BASE_URL")
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )

        self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
        self.azure_deployment_name = azure_deployment_name or os.environ.get(
            "AZURE_DEPLOYMENT_NAME"
        )
        if self.api_version is None:
            raise ValueError(
                "Must provide either the `api_version` argument "
                "or `AZURE_API_VERSION` environment variable."
            )
        if self.azure_deployment_name is None:
            raise ValueError(
                "Must provide either the `azure_deployment_name` argument "
                "or `AZURE_DEPLOYMENT_NAME` environment variable."
            )

        # Client is created with a fixed 60s timeout and 3 retries.
        self._client = AzureOpenAI(
            azure_endpoint=str(self._url),
            azure_deployment=self.azure_deployment_name,
            api_version=self.api_version,
            api_key=self._api_key,
            timeout=60,
            max_retries=3,
        )

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(self.model_type)
        return self._token_counter

    @api_keys_required("AZURE_OPENAI_API_KEY", "AZURE_API_VERSION")
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of Azure OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        # Azure routes by deployment name, not by the model type itself.
        response = self._client.chat.completions.create(
            messages=messages,
            model=self.azure_deployment_name,  # type:ignore[arg-type]
            **self.model_config_dict,
        )
        return response

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to Azure OpenAI API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to Azure OpenAI API.
        """
        for param in self.model_config_dict:
            if param not in OPENAI_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into Azure OpenAI model backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode,
        which sends partial results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get("stream", False)

View File

@@ -1,140 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from openai import Stream
from camel.messages import OpenAIMessage
from camel.types import (
ChatCompletion,
ChatCompletionChunk,
ModelType,
UnifiedModelType,
)
from camel.utils import BaseTokenCounter
class BaseModelBackend(ABC):
    r"""Abstract base for every model backend.

    A concrete implementation may talk to the OpenAI API, a locally hosted
    LLM, a stub used in unit tests, etc.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A config
            dictionary. (default: :obj:`{}`)
        api_key (Optional[str], optional): The API key for authenticating
            with the model service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the model service.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token
            counter to use for the model. If not provided,
            :obj:`OpenAITokenCounter` will be used. (default: :obj:`None`)
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        self.model_type: UnifiedModelType = UnifiedModelType(model_type)
        # Materialise the empty config here instead of in the signature to
        # avoid a shared mutable default.
        self.model_config_dict = (
            {} if model_config_dict is None else model_config_dict
        )
        self._api_key = api_key
        self._url = url
        self._token_counter = token_counter
        # Let the subclass reject unsupported config keys immediately.
        self.check_model_config()

    @property
    @abstractmethod
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        ...

    @abstractmethod
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs the query to the backend model.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat
                history in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        ...

    @abstractmethod
    def check_model_config(self):
        r"""Check whether the input model configuration contains unexpected
        arguments.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected argument for this model class.
        """
        ...

    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
        r"""Count the number of tokens in the messages using the specific
        tokenizer.

        Args:
            messages (List[Dict]): Message list with the chat history in
                OpenAI API format.

        Returns:
            int: Number of tokens in the messages.
        """
        return self.token_counter.count_tokens_from_messages(messages)

    @property
    def token_limit(self) -> int:
        r"""Returns the maximum token limit for a given model.

        The value configured under ``max_tokens`` wins; otherwise the
        model type's default token limit applies.

        Returns:
            int: The maximum token limit for the given model.
        """
        configured_limit = self.model_config_dict.get("max_tokens")
        return configured_limit or self.model_type.token_limit

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends
        partial results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        # Non-streaming by default; subclasses override as needed.
        return False

View File

@@ -1,282 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import ast
import json
import logging
import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
if TYPE_CHECKING:
from cohere.types import ChatMessageV2, ChatResponse
from camel.configs import COHERE_API_PARAMS, CohereConfig
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ModelType
from camel.utils import (
BaseTokenCounter,
OpenAITokenCounter,
api_keys_required,
)
# Optionally enable AgentOps LLM-event tracking: the import is attempted
# only when an API key is configured; otherwise (or when the package is
# missing) `LLMEvent` is left as None and tracking is skipped at call
# sites.
try:
    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import LLMEvent, record
    else:
        raise ImportError
except (ImportError, AttributeError):
    LLMEvent = None
class CohereModel(BaseModelBackend):
    r"""Cohere API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A config
            dictionary. If :obj:`None`, :obj:`CohereConfig().as_dict()` is
            used. (default: :obj:`None`)
        api_key (Optional[str], optional): API key for the Cohere service.
            Falls back to the ``COHERE_API_KEY`` environment variable.
            (default: :obj:`None`)
        url (Optional[str], optional): Url for the Cohere service. Falls
            back to ``COHERE_API_BASE_URL``. (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter
            to use for the model. If not provided,
            :obj:`OpenAITokenCounter` will be used. (default: :obj:`None`)
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ):
        # Imported lazily so the cohere SDK is only required when this
        # backend is actually instantiated.
        import cohere

        if model_config_dict is None:
            model_config_dict = CohereConfig().as_dict()
        api_key = api_key or os.environ.get("COHERE_API_KEY")
        url = url or os.environ.get("COHERE_API_BASE_URL")
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )
        # NOTE(review): `url` is resolved above but never forwarded to
        # ClientV2, so a custom base url appears to have no effect here --
        # confirm against the cohere SDK.
        self._client = cohere.ClientV2(api_key=self._api_key)

    def _to_openai_response(self, response: 'ChatResponse') -> ChatCompletion:
        r"""Convert a Cohere ``ChatResponse`` into an OpenAI-style
        :obj:`ChatCompletion`.

        Args:
            response (ChatResponse): Raw response from the Cohere client.

        Returns:
            ChatCompletion: The response re-packaged in OpenAI
                chat-completion format.
        """
        # Token usage is optional in the Cohere response; default to an
        # empty dict when it is absent.
        if response.usage and response.usage.tokens:
            input_tokens = response.usage.tokens.input_tokens or 0
            output_tokens = response.usage.tokens.output_tokens or 0
            usage = {
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            }
        else:
            usage = {}

        tool_calls = response.message.tool_calls
        choices = []
        if tool_calls:
            # NOTE(review): each Cohere tool call becomes its own choice
            # carrying a single-element `tool_calls` list, rather than one
            # choice aggregating all tool calls -- confirm this is the
            # intended mapping for multi-tool-call responses.
            for tool_call in tool_calls:
                openai_tool_calls = [
                    dict(
                        id=tool_call.id,
                        function={
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments,
                        }
                        if tool_call.function
                        else {},
                        type=tool_call.type,
                    )
                ]
                choice = dict(
                    index=None,
                    message={
                        "role": "assistant",
                        # `tool_plan` is the plan text Cohere emits
                        # alongside tool calls.
                        "content": response.message.tool_plan,
                        "tool_calls": openai_tool_calls,
                    },
                    finish_reason=response.finish_reason
                    if response.finish_reason
                    else None,
                )
                choices.append(choice)
        else:
            # Plain text response: no tool calls to translate.
            openai_tool_calls = None
            choice = dict(
                index=None,
                message={
                    "role": "assistant",
                    "content": response.message.content[0].text,  # type: ignore[union-attr,index]
                    "tool_calls": openai_tool_calls,
                },
                finish_reason=response.finish_reason
                if response.finish_reason
                else None,
            )
            choices.append(choice)

        obj = ChatCompletion.construct(
            id=response.id,
            choices=choices,
            created=None,
            model=self.model_type,
            object="chat.completion",
            usage=usage,
        )
        return obj

    def _to_cohere_chatmessage(
        self, messages: List[OpenAIMessage]
    ) -> List["ChatMessageV2"]:
        r"""Translate OpenAI-format messages into Cohere ``ChatMessageV2``
        objects.

        Args:
            messages (List[OpenAIMessage]): Chat history in OpenAI API
                format.

        Returns:
            List[ChatMessageV2]: The converted Cohere v2 messages.

        Raises:
            ValueError: If a message role is not one of ``user``, ``tool``,
                ``function``, ``assistant`` or ``system``.
        """
        from cohere.types import ToolCallV2Function
        from cohere.types.chat_message_v2 import (
            AssistantChatMessageV2,
            SystemChatMessageV2,
            ToolCallV2,
            ToolChatMessageV2,
            UserChatMessageV2,
        )

        # Cohere tool-result messages must reference the id of the
        # originating tool call; the id generated for an assistant tool
        # call below is carried over to the following tool message.
        tool_call_id = None
        new_messages = []
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")
            function_call = msg.get("function_call")

            if role == "user":
                new_message = UserChatMessageV2(role="user", content=content)  # type: ignore[arg-type]
            elif role in {"tool", "function"}:
                new_message = ToolChatMessageV2(
                    role="tool",
                    tool_call_id=tool_call_id,  # type: ignore[arg-type]
                    content=content,  # type: ignore[assignment,arg-type]
                )
            elif role == "assistant":
                if not function_call:
                    new_message = AssistantChatMessageV2(  # type: ignore[assignment]
                        role="assistant",
                        content=content,  # type: ignore[arg-type]
                    )
                else:
                    # OpenAI serializes tool arguments as a Python-literal
                    # string; re-encode them as JSON for Cohere.
                    arguments = function_call.get("arguments")  # type: ignore[attr-defined]
                    arguments_dict = ast.literal_eval(arguments)
                    arguments_json = json.dumps(arguments_dict)

                    # A fresh id links this call to the next tool message.
                    assis_tool_call_id = str(uuid.uuid4())
                    tool_call_id = assis_tool_call_id

                    new_message = AssistantChatMessageV2(  # type: ignore[assignment]
                        role="assistant",
                        tool_calls=[
                            ToolCallV2(
                                id=assis_tool_call_id,
                                type="function",
                                function=ToolCallV2Function(
                                    name=function_call.get("name"),  # type: ignore[attr-defined]
                                    arguments=arguments_json,  # type: ignore[attr-defined]
                                ),
                            )
                        ],
                        content=content,  # type: ignore[arg-type]
                    )
            elif role == "system":
                new_message = SystemChatMessageV2(  # type: ignore[assignment]
                    role="system",
                    content=content,  # type: ignore[arg-type]
                )
            else:
                raise ValueError(f"Unsupported message role: {role}")

            new_messages.append(new_message)
        return new_messages  # type: ignore[return-value]

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        # Lazily created; GPT_4O_MINI tokenization serves as an
        # approximation since no Cohere-specific counter exists here.
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(
                model=ModelType.GPT_4O_MINI
            )
        return self._token_counter

    @api_keys_required("COHERE_API_KEY")
    def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:
        r"""Runs inference of Cohere chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            ChatCompletion.

        Raises:
            ApiError: Propagated after logging when the Cohere API rejects
                the request.
        """
        from cohere.core.api_error import ApiError

        cohere_messages = self._to_cohere_chatmessage(messages)

        try:
            response = self._client.chat(
                messages=cohere_messages,
                model=self.model_type,
                **self.model_config_dict,
            )
        except ApiError as e:
            # Log the API failure details, then propagate to the caller.
            logging.error(f"Cohere API Error: {e.status_code}")
            logging.error(f"Error body: {e.body}")
            raise
        except Exception as e:
            logging.error(f"Unexpected error when calling Cohere API: {e!s}")
            raise

        openai_response = self._to_openai_response(response)

        # Add AgentOps LLM Event tracking
        if LLMEvent:
            llm_event = LLMEvent(
                thread_id=openai_response.id,
                prompt=" ".join(
                    [message.get("content") for message in messages]  # type: ignore[misc]
                ),
                prompt_tokens=openai_response.usage.prompt_tokens,  # type: ignore[union-attr]
                completion=openai_response.choices[0].message.content,
                completion_tokens=openai_response.usage.completion_tokens,  # type: ignore[union-attr]
                model=self.model_type,
            )
            record(llm_event)

        return openai_response

    def check_model_config(self):
        r"""Check whether the model configuration contains any unexpected
        arguments to Cohere API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to Cohere API.
        """
        for param in self.model_config_dict:
            if param not in COHERE_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into Cohere model backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time. Currently it's not supported.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return False

View File

@@ -1,225 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Dict, List, Optional, Union
from openai import OpenAI, Stream
from camel.configs import DEEPSEEK_API_PARAMS, DeepSeekConfig
from camel.logger import get_logger
from camel.messages import OpenAIMessage
from camel.models.base_model import BaseModelBackend
from camel.types import (
ChatCompletion,
ChatCompletionChunk,
ModelType,
)
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
from retry import retry
import json
logger = get_logger(__name__)
class DeepSeekModel(BaseModelBackend):
    r"""DeepSeek API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into:obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`DeepSeekConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the DeepSeek service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the DeepSeek service.
            (default: :obj:`https://api.deepseek.com`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter`
            will be used. (default: :obj:`None`)

    References:
        https://api-docs.deepseek.com/
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        if model_config_dict is None:
            model_config_dict = DeepSeekConfig().as_dict()
        # Environment variables provide fallbacks for credentials/endpoint.
        api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
        url = url or os.environ.get(
            "DEEPSEEK_API_BASE_URL",
            "https://api.deepseek.com",
        )
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )
        # DeepSeek exposes an OpenAI-compatible API, so the OpenAI client
        # is reused with a custom base url.
        self._client = OpenAI(
            timeout=180,
            max_retries=3,
            api_key=self._api_key,
            base_url=self._url,
        )

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        # Lazily created; GPT_4O_MINI tokenization serves as an
        # approximation since no DeepSeek-specific counter exists here.
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(
                model=ModelType.GPT_4O_MINI
            )
        return self._token_counter

    # Retries transient parse/validation failures with a 10-second delay
    # between attempts (decorator from the third-party `retry` package).
    @retry((ValueError, TypeError, json.decoder.JSONDecodeError), delay=10, logger=logger)
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of DeepSeek chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        # deepseek reasoner has limitations
        # reference: https://api-docs.deepseek.com/guides/reasoning_model#api-parameters
        if self.model_type in [
            ModelType.DEEPSEEK_REASONER,
        ]:
            import re

            logger.warning(
                "You are using a DeepSeek Reasoner model, "
                "which has certain limitations, reference: "
                "`https://api-docs.deepseek.com/guides/reasoning_model#api-parameters`"
            )
            # Check and remove unsupported parameters and reset the fixed
            # parameters
            # NOTE(review): this mutates `self.model_config_dict` in place,
            # so the removal persists for all subsequent `run` calls on
            # this instance -- confirm intended.
            unsupported_keys = [
                "temperature",
                "top_p",
                "presence_penalty",
                "frequency_penalty",
                "logprobs",
                "top_logprobs",
                "tools",
            ]
            for key in unsupported_keys:
                if key in self.model_config_dict:
                    del self.model_config_dict[key]
            # Remove thinking content from messages before sending to API
            # This ensures only the final response is sent, excluding
            # intermediate thought processes
            messages = [
                {  # type: ignore[misc]
                    **msg,
                    'content': re.sub(
                        r'<think>.*?</think>',
                        '',
                        msg['content'],  # type: ignore[arg-type]
                        flags=re.DOTALL,
                    ).strip(),
                }
                for msg in messages
            ]

        response = self._client.chat.completions.create(
            messages=messages,
            model=self.model_type,
            **self.model_config_dict,
        )

        # Handle reasoning content with <think> tags at the beginning,
        # opted in via the GET_REASONING_CONTENT environment variable.
        if (
            self.model_type
            in [
                ModelType.DEEPSEEK_REASONER,
            ]
            and os.environ.get("GET_REASONING_CONTENT", "false").lower()
            == "true"
        ):
            reasoning_content = response.choices[0].message.reasoning_content
            combined_content = (
                f"<think>\n{reasoning_content}\n</think>\n"
                if reasoning_content
                else ""
            ) + response.choices[0].message.content
            # Re-pack the response so downstream consumers see a single
            # `content` string that embeds the reasoning in <think> tags.
            # NOTE(review): if `message.content` is None this concatenation
            # raises TypeError (then retried by the decorator) -- confirm.
            response = ChatCompletion.construct(
                id=response.id,
                choices=[
                    dict(
                        index=response.choices[0].index,
                        message={
                            "role": response.choices[0].message.role,
                            "content": combined_content,
                            "tool_calls": None,
                        },
                        finish_reason=response.choices[0].finish_reason
                        if response.choices[0].finish_reason
                        else None,
                    )
                ],
                created=response.created,
                model=response.model,
                object="chat.completion",
                usage=response.usage,
            )

        return response

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to DeepSeek API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to DeepSeek API.
        """
        for param in self.model_config_dict:
            if param not in DEEPSEEK_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into DeepSeek model backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get("stream", False)

View File

@@ -1,147 +0,0 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Optional
class FishAudioModel:
    r"""Provides access to FishAudio's Text-to-Speech (TTS) and
    Speech-to-Text (STT) models.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        r"""Initialize an instance of FishAudioModel.

        Args:
            api_key (Optional[str]): API key for FishAudio service. If not
                provided, the environment variable `FISHAUDIO_API_KEY` will be
                used.
            url (Optional[str]): Base URL for FishAudio API. If not provided,
                the environment variable `FISHAUDIO_API_BASE_URL` will be
                used, defaulting to ``https://api.fish.audio``.
        """
        # Imported lazily so the SDK is only required when this model is
        # actually instantiated.
        from fish_audio_sdk import Session

        self._api_key = api_key or os.environ.get("FISHAUDIO_API_KEY")
        self._url = url or os.environ.get(
            "FISHAUDIO_API_BASE_URL", "https://api.fish.audio"
        )
        self.session = Session(apikey=self._api_key, base_url=self._url)

    def text_to_speech(
        self,
        input: str,
        storage_path: str,
        reference_id: Optional[str] = None,
        reference_audio: Optional[str] = None,
        reference_audio_text: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Convert text to speech and save the output to a file.

        Args:
            input (str): The text to convert to speech.
            storage_path (str): The file path where the resulting speech will
                be saved.
            reference_id (Optional[str]): An optional reference ID to
                associate with the request. (default: :obj:`None`)
            reference_audio (Optional[str]): Path to an audio file for
                reference speech. (default: :obj:`None`)
            reference_audio_text (Optional[str]): Text for the reference
                audio; required when `reference_audio` is given.
                (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the TTS request.

        Returns:
            Any: Always ``None``; the audio is written to `storage_path`
                as a side effect.

        Raises:
            FileNotFoundError: If the reference audio file cannot be found.
            ValueError: If `reference_audio` is given without
                `reference_audio_text`.
        """
        from fish_audio_sdk import ReferenceAudio, TTSRequest

        # Make sure the destination directory exists before writing.
        directory = os.path.dirname(storage_path)
        if directory and not os.path.exists(directory):
            os.makedirs(directory)

        if not reference_audio:
            # Plain TTS: stream the synthesized chunks straight to disk.
            with open(f"{storage_path}", "wb") as f:
                for chunk in self.session.tts(
                    TTSRequest(reference_id=reference_id, text=input, **kwargs)
                ):
                    f.write(chunk)
        else:
            if not os.path.exists(reference_audio):
                raise FileNotFoundError(
                    f"Reference audio file not found: {reference_audio}"
                )
            if not reference_audio_text:
                raise ValueError("reference_audio_text should be provided")
            # Voice-cloning TTS: attach the reference audio bytes and their
            # transcript to the request.
            with open(f"{reference_audio}", "rb") as audio_file:
                with open(f"{storage_path}", "wb") as f:
                    for chunk in self.session.tts(
                        TTSRequest(
                            text=input,
                            references=[
                                ReferenceAudio(
                                    audio=audio_file.read(),
                                    text=reference_audio_text,
                                )
                            ],
                            **kwargs,
                        )
                    ):
                        f.write(chunk)

    def speech_to_text(
        self,
        audio_file_path: str,
        language: Optional[str] = None,
        ignore_timestamps: Optional[bool] = None,
        **kwargs: Any,
    ) -> str:
        r"""Convert speech to text from an audio file.

        Args:
            audio_file_path (str): The path to the audio file to transcribe.
            language (Optional[str]): The language of the audio. (default:
                :obj:`None`)
            ignore_timestamps (Optional[bool]): Whether to ignore timestamps.
                (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the STT request.

        Returns:
            str: The transcribed text from the audio.

        Raises:
            FileNotFoundError: If the audio file cannot be found.
        """
        from fish_audio_sdk import ASRRequest

        if not os.path.exists(audio_file_path):
            raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

        with open(f"{audio_file_path}", "rb") as audio_file:
            audio_data = audio_file.read()
            response = self.session.asr(
                ASRRequest(
                    audio=audio_data,
                    language=language,
                    ignore_timestamps=ignore_timestamps,
                    **kwargs,
                )
            )

        return response.text

Some files were not shown because too many files have changed in this diff Show More