mirror of
https://github.com/camel-ai/owl.git
synced 2026-03-22 14:07:17 +08:00
refactor: Update with camel version 0.2.23 (#122)
This commit is contained in:
29
.pre-commit-config.yaml
Normal file
29
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: 'v0.7.4'
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
|
||||
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||
- id: ruff-format
|
||||
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy
|
||||
name: Check mypy
|
||||
entry: mypy --namespace-packages -p owl
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-license
|
||||
name: Check License
|
||||
entry: python licenses/update_license.py . licenses/license_template.txt
|
||||
language: system
|
||||
types: [python]
|
||||
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||
43
README.md
43
README.md
@@ -102,36 +102,31 @@ https://private-user-images.githubusercontent.com/55657767/420212194-e813fc05-13
|
||||
|
||||
# 🛠️ Installation
|
||||
|
||||
## **Clone the Github repository**
|
||||
|
||||
```bash
|
||||
# Clone github repo
|
||||
git clone https://github.com/camel-ai/owl.git
|
||||
|
||||
# Change directory into project directory
|
||||
cd owl
|
||||
```
|
||||
|
||||
## **Set up Environment**
|
||||
# Install uv if you don't have it already
|
||||
pip install uv
|
||||
|
||||
Using Conda (recommended):
|
||||
```bash
|
||||
conda create -n owl python=3.11
|
||||
conda activate owl
|
||||
```
|
||||
# Create a virtual environment and install dependencies
|
||||
# We support using Python 3.10, 3.11, 3.12
|
||||
uv venv .venv --python=3.10
|
||||
|
||||
Using venv (alternative):
|
||||
```bash
|
||||
python -m venv owl_env
|
||||
# On Windows
|
||||
owl_env\Scripts\activate
|
||||
# On Unix or MacOS
|
||||
source owl_env/bin/activate
|
||||
```
|
||||
# Activate the virtual environment
|
||||
# For macOS/Linux
|
||||
source .venv/bin/activate
|
||||
# For Windows
|
||||
.venv\Scripts\activate
|
||||
|
||||
# Install CAMEL with all dependencies
|
||||
uv pip install -e .
|
||||
|
||||
## **Install Dependencies**
|
||||
|
||||
```bash
|
||||
python -m pip install -r requirements.txt
|
||||
playwright install
|
||||
# Exit the virtual environment when done
|
||||
deactivate
|
||||
```
|
||||
|
||||
## **Setup Environment Variables**
|
||||
@@ -210,7 +205,7 @@ question = "Task description here."
|
||||
society = construct_society(question)
|
||||
answer, chat_history, token_count = run_society(society)
|
||||
|
||||
print(f"Answer: {answer}")
|
||||
print(f"\033[94mAnswer: {answer}\033[0m")
|
||||
```
|
||||
|
||||
For uploading files, simply provide the file path along with your question:
|
||||
@@ -221,7 +216,7 @@ question = "What is in the given DOCX file? Here is the file path: tmp/example.d
|
||||
|
||||
society = construct_society(question)
|
||||
answer, chat_history, token_count = run_society(society)
|
||||
print(f"Answer: {answer}")
|
||||
print(f"\033[94mAnswer: {answer}\033[0m")
|
||||
```
|
||||
|
||||
OWL will then automatically invoke document-related tools to process the file and extract the answer.
|
||||
|
||||
37
README_zh.md
37
README_zh.md
@@ -104,31 +104,30 @@ https://private-user-images.githubusercontent.com/55657767/420212194-e813fc05-13
|
||||
## **克隆 Github 仓库**
|
||||
|
||||
```bash
|
||||
# 克隆 GitHub 仓库
|
||||
git clone https://github.com/camel-ai/owl.git
|
||||
|
||||
# 进入项目目录
|
||||
cd owl
|
||||
```
|
||||
|
||||
## **设置环境**
|
||||
# 如果你还没有安装 uv,请先安装
|
||||
pip install uv
|
||||
|
||||
使用 Conda(推荐):
|
||||
```bash
|
||||
conda create -n owl python=3.11
|
||||
conda activate owl
|
||||
```
|
||||
# 创建虚拟环境并安装依赖
|
||||
# 我们支持使用 Python 3.10、3.11、3.12
|
||||
uv venv .venv --python=3.10
|
||||
|
||||
使用 venv(备用):
|
||||
```bash
|
||||
python -m venv owl_env
|
||||
# Windows 系统
|
||||
owl_env\Scripts\activate
|
||||
# Unix 或 MacOS 系统
|
||||
source owl_env/bin/activate
|
||||
```
|
||||
# 激活虚拟环境
|
||||
# 对于 macOS/Linux
|
||||
source .venv/bin/activate
|
||||
# 对于 Windows
|
||||
.venv\Scripts\activate
|
||||
|
||||
## **安装依赖**
|
||||
# 安装 CAMEL 及其所有依赖
|
||||
uv pip install -e .
|
||||
|
||||
```bash
|
||||
python -m pip install -r requirements.txt
|
||||
# 完成后退出虚拟环境
|
||||
deactivate
|
||||
```
|
||||
|
||||
## **设置环境变量**
|
||||
@@ -201,7 +200,7 @@ question = "Task description here."
|
||||
society = construct_society(question)
|
||||
answer, chat_history, token_count = run_society(society)
|
||||
|
||||
print(f"Answer: {answer}")
|
||||
print(f"\033[94mAnswer: {answer}\033[0m")
|
||||
```
|
||||
|
||||
上传文件时,只需提供文件路径和问题:
|
||||
|
||||
@@ -39,43 +39,37 @@ def update_license_in_file(
|
||||
start_line_start_with: str,
|
||||
end_line_start_with: str,
|
||||
) -> bool:
|
||||
with open(
|
||||
file_path, 'r', encoding='utf-8'
|
||||
) as f: # for windows compatibility
|
||||
with open(file_path, "r", encoding="utf-8") as f: # for windows compatibility
|
||||
content = f.read()
|
||||
|
||||
with open(license_template_path, 'r', encoding='utf-8') as f:
|
||||
with open(license_template_path, "r", encoding="utf-8") as f:
|
||||
new_license = f.read().strip()
|
||||
|
||||
maybe_existing_licenses = re.findall(
|
||||
r'^#.*?(?=\n)', content, re.MULTILINE | re.DOTALL
|
||||
r"^#.*?(?=\n)", content, re.MULTILINE | re.DOTALL
|
||||
)
|
||||
start_index = fine_license_start_line(
|
||||
maybe_existing_licenses, start_line_start_with
|
||||
)
|
||||
end_index = find_license_end_line(
|
||||
maybe_existing_licenses, end_line_start_with
|
||||
)
|
||||
end_index = find_license_end_line(maybe_existing_licenses, end_line_start_with)
|
||||
if start_index is not None and end_index is not None:
|
||||
maybe_existing_licenses = maybe_existing_licenses[
|
||||
start_index : end_index + 1
|
||||
]
|
||||
maybe_existing_licenses = maybe_existing_licenses[start_index : end_index + 1]
|
||||
else:
|
||||
maybe_existing_licenses = None
|
||||
if maybe_existing_licenses:
|
||||
maybe_old_licenses = '\n'.join(maybe_existing_licenses)
|
||||
maybe_old_licenses = "\n".join(maybe_existing_licenses)
|
||||
if maybe_old_licenses.strip() != new_license.strip():
|
||||
replaced_content = content.replace(maybe_old_licenses, new_license)
|
||||
with open(file_path, 'w') as f:
|
||||
with open(file_path, "w") as f:
|
||||
f.write(replaced_content)
|
||||
print(f'Replaced license in {file_path}')
|
||||
print(f"Replaced license in {file_path}")
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(new_license + '\n' + content)
|
||||
print(f'Added license to {file_path}')
|
||||
with open(file_path, "w") as f:
|
||||
f.write(new_license + "\n" + content)
|
||||
print(f"Added license to {file_path}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -87,16 +81,16 @@ def update_license_in_directory(
|
||||
) -> None:
|
||||
# Check if directory exists
|
||||
if not os.path.isdir(directory_path):
|
||||
raise NotADirectoryError(f'{directory_path} is not a directory')
|
||||
raise NotADirectoryError(f"{directory_path} is not a directory")
|
||||
# Check if license template exists
|
||||
if not os.path.isfile(license_template_path):
|
||||
raise FileNotFoundError(f'{license_template_path} not found')
|
||||
raise FileNotFoundError(f"{license_template_path} not found")
|
||||
|
||||
file_count = 0
|
||||
for py_files in Path(directory_path).rglob("*.py"):
|
||||
if py_files.name.startswith('.'):
|
||||
if py_files.name.startswith("."):
|
||||
continue
|
||||
if any(part.startswith('.') for part in py_files.parts):
|
||||
if any(part.startswith(".") for part in py_files.parts):
|
||||
continue
|
||||
if update_license_in_file(
|
||||
py_files,
|
||||
@@ -106,10 +100,10 @@ def update_license_in_directory(
|
||||
):
|
||||
file_count += 1
|
||||
|
||||
print(f'License updated in {file_count} files')
|
||||
print(f"License updated in {file_count} files")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 3:
|
||||
print(
|
||||
"Usage from command line: "
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from camel.logger import disable_logging, enable_logging, set_log_level
|
||||
|
||||
__version__ = '0.2.11'
|
||||
|
||||
__all__ = [
|
||||
'__version__',
|
||||
'camel',
|
||||
'disable_logging',
|
||||
'enable_logging',
|
||||
'set_log_level',
|
||||
]
|
||||
@@ -1,44 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .base import BaseAgent
|
||||
from .chat_agent import ChatAgent
|
||||
from .critic_agent import CriticAgent
|
||||
from .embodied_agent import EmbodiedAgent
|
||||
from .knowledge_graph_agent import KnowledgeGraphAgent
|
||||
from .role_assignment_agent import RoleAssignmentAgent
|
||||
from .search_agent import SearchAgent
|
||||
from .task_agent import (
|
||||
TaskCreationAgent,
|
||||
TaskPlannerAgent,
|
||||
TaskPrioritizationAgent,
|
||||
TaskSpecifyAgent,
|
||||
)
|
||||
from .tool_agents.base import BaseToolAgent
|
||||
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
|
||||
|
||||
__all__ = [
|
||||
'BaseAgent',
|
||||
'ChatAgent',
|
||||
'TaskSpecifyAgent',
|
||||
'TaskPlannerAgent',
|
||||
'TaskCreationAgent',
|
||||
'TaskPrioritizationAgent',
|
||||
'CriticAgent',
|
||||
'BaseToolAgent',
|
||||
'HuggingFaceToolAgent',
|
||||
'EmbodiedAgent',
|
||||
'RoleAssignmentAgent',
|
||||
'SearchAgent',
|
||||
'KnowledgeGraphAgent',
|
||||
]
|
||||
@@ -1,29 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
r"""An abstract base class for all CAMEL agents."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self, *args: Any, **kwargs: Any) -> Any:
|
||||
r"""Resets the agent to its initial state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step(self, *args: Any, **kwargs: Any) -> Any:
|
||||
r"""Performs a single step of the agent."""
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,202 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.memories import AgentMemory
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.responses import ChatAgentResponse
|
||||
from camel.utils import get_first_int, print_text_animated
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="CriticAgent")
|
||||
class CriticAgent(ChatAgent):
|
||||
r"""A class for the critic agent that assists in selecting an option.
|
||||
|
||||
Args:
|
||||
system_message (BaseMessage): The system message for the critic
|
||||
agent.
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
message_window_size (int, optional): The maximum number of previous
|
||||
messages to include in the context window. If `None`, no windowing
|
||||
is performed. (default: :obj:`6`)
|
||||
retry_attempts (int, optional): The number of retry attempts if the
|
||||
critic fails to return a valid option. (default: :obj:`2`)
|
||||
verbose (bool, optional): Whether to print the critic's messages.
|
||||
logger_color (Any): The color of the menu options displayed to the
|
||||
user. (default: :obj:`Fore.MAGENTA`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_message: BaseMessage,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
memory: Optional[AgentMemory] = None,
|
||||
message_window_size: int = 6,
|
||||
retry_attempts: int = 2,
|
||||
verbose: bool = False,
|
||||
logger_color: Any = Fore.MAGENTA,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
system_message,
|
||||
model=model,
|
||||
memory=memory,
|
||||
message_window_size=message_window_size,
|
||||
)
|
||||
self.options_dict: Dict[str, str] = dict()
|
||||
self.retry_attempts = retry_attempts
|
||||
self.verbose = verbose
|
||||
self.logger_color = logger_color
|
||||
|
||||
def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
|
||||
r"""Flattens the options to the critic.
|
||||
|
||||
Args:
|
||||
messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
|
||||
|
||||
Returns:
|
||||
str: A string containing the flattened options to the critic.
|
||||
"""
|
||||
options = [message.content for message in messages]
|
||||
flatten_options = (
|
||||
f"> Proposals from "
|
||||
f"{messages[0].role_name} ({messages[0].role_type}). "
|
||||
"Please choose an option:\n"
|
||||
)
|
||||
for index, option in enumerate(options):
|
||||
flatten_options += f"Option {index + 1}:\n{option}\n\n"
|
||||
self.options_dict[str(index + 1)] = option
|
||||
format = (
|
||||
f"Please first enter your choice ([1-{len(self.options_dict)}]) "
|
||||
"and then your explanation and comparison: "
|
||||
)
|
||||
return flatten_options + format
|
||||
|
||||
def get_option(self, input_message: BaseMessage) -> str:
|
||||
r"""Gets the option selected by the critic.
|
||||
|
||||
Args:
|
||||
input_message (BaseMessage): A `BaseMessage` object representing
|
||||
the input message.
|
||||
|
||||
Returns:
|
||||
str: The option selected by the critic.
|
||||
"""
|
||||
# TODO: Add support for editing options by the critic.
|
||||
msg_content = input_message.content
|
||||
i = 0
|
||||
while i < self.retry_attempts:
|
||||
critic_response = self.step(input_message)
|
||||
|
||||
if critic_response.msgs is None or len(critic_response.msgs) == 0:
|
||||
raise RuntimeError("Got None critic messages.")
|
||||
if critic_response.terminated:
|
||||
raise RuntimeError("Critic step failed.")
|
||||
|
||||
critic_msg = critic_response.msg
|
||||
if self.verbose:
|
||||
print_text_animated(
|
||||
self.logger_color + "\n> Critic response: "
|
||||
f"\x1b[3m{critic_msg.content}\x1b[0m\n"
|
||||
)
|
||||
choice = self.parse_critic(critic_msg)
|
||||
|
||||
if choice in self.options_dict:
|
||||
return self.options_dict[choice]
|
||||
else:
|
||||
input_message = BaseMessage(
|
||||
role_name=input_message.role_name,
|
||||
role_type=input_message.role_type,
|
||||
meta_dict=input_message.meta_dict,
|
||||
content="> Invalid choice. Please choose again.\n"
|
||||
+ msg_content,
|
||||
)
|
||||
i += 1
|
||||
warnings.warn(
|
||||
"Critic failed to get a valid option. "
|
||||
f"After {self.retry_attempts} attempts. "
|
||||
"Returning a random option."
|
||||
)
|
||||
return random.choice(list(self.options_dict.values()))
|
||||
|
||||
def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
|
||||
r"""Parses the critic's message and extracts the choice.
|
||||
|
||||
Args:
|
||||
critic_msg (BaseMessage): A `BaseMessage` object representing the
|
||||
critic's response.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The critic's choice as a string, or None if the
|
||||
message could not be parsed.
|
||||
"""
|
||||
choice = str(get_first_int(critic_msg.content))
|
||||
return choice
|
||||
|
||||
def reduce_step(
|
||||
self,
|
||||
input_messages: Sequence[BaseMessage],
|
||||
) -> ChatAgentResponse:
|
||||
r"""Performs one step of the conversation by flattening options to the
|
||||
critic, getting the option, and parsing the choice.
|
||||
|
||||
Args:
|
||||
input_messages (Sequence[BaseMessage]): A list of BaseMessage
|
||||
objects.
|
||||
|
||||
Returns:
|
||||
ChatAgentResponse: A `ChatAgentResponse` object includes the
|
||||
critic's choice.
|
||||
"""
|
||||
meta_chat_message = BaseMessage(
|
||||
role_name=input_messages[0].role_name,
|
||||
role_type=input_messages[0].role_type,
|
||||
meta_dict=input_messages[0].meta_dict,
|
||||
content="",
|
||||
)
|
||||
|
||||
flatten_options = self.flatten_options(input_messages)
|
||||
if self.verbose:
|
||||
print_text_animated(
|
||||
self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
|
||||
)
|
||||
input_msg = meta_chat_message.create_new_instance(flatten_options)
|
||||
|
||||
option = self.get_option(input_msg)
|
||||
output_msg = meta_chat_message.create_new_instance(option)
|
||||
|
||||
# TODO: The return `info` can be improved.
|
||||
return ChatAgentResponse(
|
||||
msgs=[output_msg],
|
||||
terminated=False,
|
||||
info={},
|
||||
)
|
||||
@@ -1,303 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.types import RoleType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="DeductiveReasonerAgent")
|
||||
class DeductiveReasonerAgent(ChatAgent):
|
||||
r"""An agent responsible for deductive reasoning. Model of deductive
|
||||
reasoning:
|
||||
- L: A ⊕ C -> q * B
|
||||
- A represents the known starting state.
|
||||
- B represents the known target state.
|
||||
- C represents the conditions required to transition from A to B.
|
||||
- Q represents the quality or effectiveness of the transition from
|
||||
A to B.
|
||||
- L represents the path or process from A to B.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
system_message = BaseMessage(
|
||||
role_name="Insight Agent",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You assign roles based on tasks.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def deduce_conditions_and_quality(
|
||||
self,
|
||||
starting_state: str,
|
||||
target_state: str,
|
||||
role_descriptions_dict: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Union[List[str], Dict[str, str]]]:
|
||||
r"""Derives the conditions and quality from the starting state and the
|
||||
target state based on the model of the deductive reasoning and the
|
||||
knowledge base. It can optionally consider the roles involved in the
|
||||
scenario, which allows tailoring the output more closely to the AI
|
||||
agent's environment.
|
||||
|
||||
Args:
|
||||
starting_state (str): The initial or starting state from which
|
||||
conditions are deduced.
|
||||
target_state (str): The target state of the task.
|
||||
role_descriptions_dict (Optional[Dict[str, str]], optional): The
|
||||
descriptions of the roles. (default: :obj:`None`)
|
||||
role_descriptions_dict (Optional[Dict[str, str]], optional): A
|
||||
dictionary describing the roles involved in the scenario. This
|
||||
is optional and can be used to provide a context for the
|
||||
CAMEL's role-playing, enabling the generation of more relevant
|
||||
and tailored conditions and quality assessments. This could be
|
||||
generated using a `RoleAssignmentAgent()` or defined manually
|
||||
by the user.
|
||||
|
||||
Returns:
|
||||
Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
|
||||
extracted data from the message. The dictionary contains three
|
||||
keys:
|
||||
- 'conditions': A list where each key is a condition ID and
|
||||
each value is the corresponding condition text.
|
||||
- 'labels': A list of label strings extracted from the message.
|
||||
- 'quality': A string of quality assessment strings extracted
|
||||
from the message.
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
deduce_prompt = """You are a deductive reasoner. You are tasked to
|
||||
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
|
||||
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
|
||||
CONTENT to help you complete the TASK.
|
||||
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
|
||||
fill in the BLANKs, and DO NOT alter or modify any other part of the template
|
||||
|
||||
===== MODELING OF DEDUCTIVE REASONING =====
|
||||
You are tasked with understanding a mathematical model based on the components
|
||||
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
|
||||
- $A$ represents the known starting state.
|
||||
- $B$ represents the known target state.
|
||||
- $C$ represents the conditions required to transition from $A$ to $B$.
|
||||
- $Q$ represents the quality or effectiveness of the transition from $A$ to
|
||||
$B$.
|
||||
- $L$ represents the path or process from $A$ to $B$.
|
||||
|
||||
===== THOUGHT OF DEDUCTIVE REASONING =====
|
||||
1. Define the Parameters of A and B:
|
||||
- Characterization: Before delving into transitions, thoroughly understand
|
||||
the nature and boundaries of both $A$ and $B$. This includes the type,
|
||||
properties, constraints, and possible interactions between the two.
|
||||
- Contrast and Compare: Highlight the similarities and differences between
|
||||
$A$ and $B$. This comparative analysis will give an insight into what
|
||||
needs changing and what remains constant.
|
||||
2. Historical & Empirical Analysis:
|
||||
- Previous Transitions according to the Knowledge Base of GPT: (if
|
||||
applicable) Extract conditions and patterns from the historical instances
|
||||
where a similar transition from a state comparable to $A$ moved towards
|
||||
$B$.
|
||||
- Scientific Principles: (if applicable) Consider the underlying
|
||||
scientific principles governing or related to the states and their
|
||||
transition. For example, if $A$ and $B$ are physical states, laws of
|
||||
physics might apply.
|
||||
3. Logical Deduction of Conditions ($C$):
|
||||
- Direct Path Analysis: What are the immediate and direct conditions
|
||||
required to move from $A$ to $B$?
|
||||
- Intermediate States: Are there states between $A$ and $B$ that must be
|
||||
traversed or can be used to make the transition smoother or more
|
||||
efficient? If yes, what is the content?
|
||||
- Constraints & Limitations: Identify potential barriers or restrictions
|
||||
in moving from $A$ to $B$. These can be external (e.g., environmental
|
||||
factors) or internal (properties of $A$ or $B$).
|
||||
- Resource and Information Analysis: What resources and information are
|
||||
required for the transition? This could be time, entity, factor, code
|
||||
language, software platform, unknowns, etc.
|
||||
- External Influences: Consider socio-economic, political, or
|
||||
environmental factors (if applicable) that could influence the transition
|
||||
conditions.
|
||||
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
|
||||
no matter how unconventional they might seem. Utilize analogies,
|
||||
metaphors, or brainstorming techniques to envision possible conditions or
|
||||
paths from $A$ to $B$.
|
||||
- The conditions $C$ should be multiple but in one sentence. And each
|
||||
condition should be concerned with one aspect/entity.
|
||||
4. Entity/Label Recognition of Conditions ($C$):
|
||||
- Identify and categorize entities of Conditions ($C$) such as the names,
|
||||
locations, dates, specific technical terms or contextual parameters that
|
||||
might be associated with events, innovations post-2022.
|
||||
- The output of the entities/labels will be used as tags or labels for
|
||||
semantic similarity searches. The entities/labels may be the words, or
|
||||
phrases, each of them should contain valuable, high information entropy
|
||||
information, and should be independent.
|
||||
- Ensure that the identified entities are formatted in a manner suitable
|
||||
for database indexing and retrieval. Organize the entities into
|
||||
categories, and combine the category with its instance into a continuous
|
||||
phrase, without using colons or other separators.
|
||||
- Format these entities for database indexing: output the category rather
|
||||
than its instance/content into a continuous phrase. For example, instead
|
||||
of "Jan. 02", identify it as "Event time".
|
||||
5. Quality Assessment ($Q$):
|
||||
- Efficiency: How efficient is the transition from $A$ to $B$, which
|
||||
measures the resources used versus the desired outcome?
|
||||
- Effectiveness: Did the transition achieve the desired outcome or was the
|
||||
target state achieved as intended?
|
||||
- Safety & Risks: Assess any risks associated with the transition and the
|
||||
measures to mitigate them.
|
||||
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
|
||||
and adjust the quality of transition, making it more adaptive.
|
||||
6. Iterative Evaluation:
|
||||
- Test & Refine: Based on the initially deduced conditions and assessed
|
||||
quality, iterate the process to refine and optimize the transition. This
|
||||
might involve tweaking conditions, employing different paths, or changing
|
||||
resources.
|
||||
- Feedback Integration: Use feedback to make improvements and increase the
|
||||
quality of the transition.
|
||||
7. Real-world scenarios often present challenges that may not be captured by
|
||||
models and frameworks. While using the model, maintain an adaptive mindset:
|
||||
- Scenario Exploration: Continuously imagine various possible scenarios,
|
||||
both positive and negative, to prepare for unexpected events.
|
||||
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
|
||||
process ($L$) if unforeseen challenges arise.
|
||||
- Feedback Integration: Rapidly integrate feedback from actual
|
||||
implementations to adjust the model's application, ensuring relevancy and
|
||||
effectiveness.
|
||||
|
||||
===== TASK =====
|
||||
Given the starting state $A$ and the target state $B$, assuming that a path
|
||||
$L$ always exists between $A$ and $B$, how can one deduce or identify the
|
||||
necessary conditions $C$ and the quality $Q$ of the transition?
|
||||
|
||||
===== STARTING STATE $A$ =====
|
||||
{starting_state}
|
||||
|
||||
===== TARGET STATE $B$ =====
|
||||
{target_state}
|
||||
|
||||
{role_with_description_prompt}
|
||||
===== ANSWER TEMPLATE =====
|
||||
- Characterization and comparison of $A$ and $B$:\n<BLANK>
|
||||
- Historical & Empirical Analysis:\n<BLANK>/None
|
||||
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
|
||||
condition <NUM>:
|
||||
<BLANK>.
|
||||
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
|
||||
square brackets)
|
||||
- Quality Assessment ($Q$) (do not use symbols):
|
||||
<BLANK>.
|
||||
- Iterative Evaluation:\n<BLANK>/None"""
|
||||
|
||||
if role_descriptions_dict is not None:
|
||||
role_names = role_descriptions_dict.keys()
|
||||
role_with_description_prompt = (
|
||||
"===== ROLES WITH DESCRIPTIONS =====\n"
|
||||
+ "\n".join(
|
||||
f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
|
||||
for role_name in role_names
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
else:
|
||||
role_with_description_prompt = ""
|
||||
deduce_prompt = TextPrompt(deduce_prompt)
|
||||
|
||||
deduce = deduce_prompt.format(
|
||||
starting_state=starting_state,
|
||||
target_state=target_state,
|
||||
role_with_description_prompt=role_with_description_prompt,
|
||||
)
|
||||
|
||||
conditions_and_quality_generation_msg = BaseMessage.make_user_message(
|
||||
role_name="Deductive Reasoner", content=deduce
|
||||
)
|
||||
|
||||
response = self.step(
|
||||
input_message=conditions_and_quality_generation_msg
|
||||
)
|
||||
|
||||
if response.terminated:
|
||||
raise RuntimeError(
|
||||
"Deduction failed. Error:\n" + f"{response.info}"
|
||||
)
|
||||
msg: BaseMessage = response.msg
|
||||
logger.info(f"Message content:\n{msg.content}")
|
||||
|
||||
# Extract the conditions from the message
|
||||
conditions_dict = {
|
||||
f"condition {i}": cdt.replace("<", "")
|
||||
.replace(">", "")
|
||||
.strip()
|
||||
.strip('\n')
|
||||
for i, cdt in re.findall(
|
||||
r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
}
|
||||
|
||||
# Extract the labels from the message
|
||||
labels = [
|
||||
label.strip().strip('\n').strip("\"'")
|
||||
for label in re.findall(
|
||||
r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)[0].split(",")
|
||||
]
|
||||
|
||||
# Extract the quality from the message
|
||||
quality = next(
|
||||
q.strip().strip('\n')
|
||||
for q in re.findall(
|
||||
r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
|
||||
r"\n(.+?)- Iterative",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert them into JSON format
|
||||
conditions_and_quality_json: Dict[
|
||||
str, Union[List[str], Dict[str, str]]
|
||||
] = {}
|
||||
conditions_and_quality_json["conditions"] = conditions_dict
|
||||
conditions_and_quality_json["labels"] = labels
|
||||
conditions_and_quality_json["evaluate_quality"] = quality
|
||||
|
||||
return conditions_and_quality_json
|
||||
@@ -1,201 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.agents.tool_agents.base import BaseToolAgent
|
||||
from camel.interpreters import (
|
||||
BaseInterpreter,
|
||||
InternalPythonInterpreter,
|
||||
SubprocessInterpreter,
|
||||
)
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.responses import ChatAgentResponse
|
||||
from camel.utils import print_text_animated
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="EmbodiedAgent")
|
||||
class EmbodiedAgent(ChatAgent):
|
||||
r"""Class for managing conversations of CAMEL Embodied Agents.
|
||||
|
||||
Args:
|
||||
system_message (BaseMessage): The system message for the chat agent.
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
message_window_size (int, optional): The maximum number of previous
|
||||
messages to include in the context window. If `None`, no windowing
|
||||
is performed. (default: :obj:`None`)
|
||||
tool_agents (List[BaseToolAgent], optional): The tools agents to use in
|
||||
the embodied agent. (default: :obj:`None`)
|
||||
code_interpreter (BaseInterpreter, optional): The code interpreter to
|
||||
execute codes. If `code_interpreter` and `tool_agent` are both
|
||||
`None`, default to `SubProcessInterpreter`. If `code_interpreter`
|
||||
is `None` and `tool_agents` is not `None`, default to
|
||||
`InternalPythonInterpreter`. (default: :obj:`None`)
|
||||
verbose (bool, optional): Whether to print the critic's messages.
|
||||
logger_color (Any): The color of the logger displayed to the user.
|
||||
(default: :obj:`Fore.MAGENTA`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_message: BaseMessage,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
message_window_size: Optional[int] = None,
|
||||
tool_agents: Optional[List[BaseToolAgent]] = None,
|
||||
code_interpreter: Optional[BaseInterpreter] = None,
|
||||
verbose: bool = False,
|
||||
logger_color: Any = Fore.MAGENTA,
|
||||
) -> None:
|
||||
self.tool_agents = tool_agents
|
||||
self.code_interpreter: BaseInterpreter
|
||||
if code_interpreter is not None:
|
||||
self.code_interpreter = code_interpreter
|
||||
elif self.tool_agents:
|
||||
self.code_interpreter = InternalPythonInterpreter()
|
||||
else:
|
||||
self.code_interpreter = SubprocessInterpreter()
|
||||
|
||||
if self.tool_agents:
|
||||
system_message = self._set_tool_agents(system_message)
|
||||
self.verbose = verbose
|
||||
self.logger_color = logger_color
|
||||
super().__init__(
|
||||
system_message=system_message,
|
||||
model=model,
|
||||
message_window_size=message_window_size,
|
||||
)
|
||||
|
||||
def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
|
||||
action_space_prompt = self._get_tool_agents_prompt()
|
||||
result_message = system_message.create_new_instance(
|
||||
content=system_message.content.format(
|
||||
action_space=action_space_prompt
|
||||
)
|
||||
)
|
||||
if self.tool_agents is not None:
|
||||
self.code_interpreter.update_action_space(
|
||||
{tool.name: tool for tool in self.tool_agents}
|
||||
)
|
||||
return result_message
|
||||
|
||||
def _get_tool_agents_prompt(self) -> str:
|
||||
r"""Returns the action space prompt.
|
||||
|
||||
Returns:
|
||||
str: The action space prompt.
|
||||
"""
|
||||
if self.tool_agents is not None:
|
||||
return "\n".join(
|
||||
[
|
||||
f"*** {tool.name} ***:\n {tool.description}"
|
||||
for tool in self.tool_agents
|
||||
]
|
||||
)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def get_tool_agent_names(self) -> List[str]:
|
||||
r"""Returns the names of tool agents.
|
||||
|
||||
Returns:
|
||||
List[str]: The names of tool agents.
|
||||
"""
|
||||
if self.tool_agents is not None:
|
||||
return [tool.name for tool in self.tool_agents]
|
||||
else:
|
||||
return []
|
||||
|
||||
# ruff: noqa: E501
|
||||
def step(self, input_message: BaseMessage) -> ChatAgentResponse: # type: ignore[override]
|
||||
r"""Performs a step in the conversation.
|
||||
|
||||
Args:
|
||||
input_message (BaseMessage): The input message.
|
||||
|
||||
Returns:
|
||||
ChatAgentResponse: A struct containing the output messages,
|
||||
a boolean indicating whether the chat session has terminated,
|
||||
and information about the chat session.
|
||||
"""
|
||||
response = super().step(input_message)
|
||||
|
||||
if response.msgs is None or len(response.msgs) == 0:
|
||||
raise RuntimeError("Got None output messages.")
|
||||
if response.terminated:
|
||||
raise RuntimeError(f"{self.__class__.__name__} step failed.")
|
||||
|
||||
# NOTE: Only single output messages are supported
|
||||
explanations, codes = response.msg.extract_text_and_code_prompts()
|
||||
|
||||
if self.verbose:
|
||||
for explanation, code in zip(explanations, codes):
|
||||
print_text_animated(
|
||||
self.logger_color + f"> Explanation:\n{explanation}"
|
||||
)
|
||||
print_text_animated(self.logger_color + f"> Code:\n{code}")
|
||||
|
||||
if len(explanations) > len(codes):
|
||||
print_text_animated(
|
||||
self.logger_color + f"> Explanation:\n{explanations[-1]}"
|
||||
)
|
||||
|
||||
content = response.msg.content
|
||||
|
||||
if codes is not None:
|
||||
try:
|
||||
content = "\n> Executed Results:\n"
|
||||
for block_idx, code in enumerate(codes):
|
||||
executed_output = self.code_interpreter.run(
|
||||
code, code.code_type
|
||||
)
|
||||
content += (
|
||||
f"Executing code block {block_idx}: {{\n"
|
||||
+ executed_output
|
||||
+ "}\n"
|
||||
)
|
||||
except InterruptedError as e:
|
||||
content = (
|
||||
f"\n> Running code fail: {e}\n"
|
||||
"Please regenerate the code."
|
||||
)
|
||||
|
||||
# TODO: Handle errors
|
||||
content = input_message.content + f"\n> Embodied Actions:\n{content}"
|
||||
message = BaseMessage(
|
||||
input_message.role_name,
|
||||
input_message.role_type,
|
||||
input_message.meta_dict,
|
||||
content,
|
||||
)
|
||||
return ChatAgentResponse(
|
||||
msgs=[message],
|
||||
terminated=response.terminated,
|
||||
info=response.info,
|
||||
)
|
||||
@@ -1,259 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured.documents.elements import Element
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.storages.graph_storages.graph_element import (
|
||||
GraphElement,
|
||||
Node,
|
||||
Relationship,
|
||||
)
|
||||
from camel.types import RoleType
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
text_prompt = """
|
||||
You are tasked with extracting nodes and relationships from given content and
|
||||
structures them into Node and Relationship objects. Here's the outline of what
|
||||
you needs to do:
|
||||
|
||||
Content Extraction:
|
||||
You should be able to process input content and identify entities mentioned
|
||||
within it.
|
||||
Entities can be any noun phrases or concepts that represent distinct entities
|
||||
in the context of the given content.
|
||||
|
||||
Node Extraction:
|
||||
For each identified entity, you should create a Node object.
|
||||
Each Node object should have a unique identifier (id) and a type (type).
|
||||
Additional properties associated with the node can also be extracted and
|
||||
stored.
|
||||
|
||||
Relationship Extraction:
|
||||
You should identify relationships between entities mentioned in the content.
|
||||
For each relationship, create a Relationship object.
|
||||
A Relationship object should have a subject (subj) and an object (obj) which
|
||||
are Node objects representing the entities involved in the relationship.
|
||||
Each relationship should also have a type (type), and additional properties if
|
||||
applicable.
|
||||
|
||||
Output Formatting:
|
||||
The extracted nodes and relationships should be formatted as instances of the
|
||||
provided Node and Relationship classes.
|
||||
Ensure that the extracted data adheres to the structure defined by the classes.
|
||||
Output the structured data in a format that can be easily validated against
|
||||
the provided code.
|
||||
|
||||
Instructions for you:
|
||||
Read the provided content thoroughly.
|
||||
Identify distinct entities mentioned in the content and categorize them as
|
||||
nodes.
|
||||
Determine relationships between these entities and represent them as directed
|
||||
relationships.
|
||||
Provide the extracted nodes and relationships in the specified format below.
|
||||
Example for you:
|
||||
|
||||
Example Content:
|
||||
"John works at XYZ Corporation. He is a software engineer. The company is
|
||||
located in New York City."
|
||||
|
||||
Expected Output:
|
||||
|
||||
Nodes:
|
||||
|
||||
Node(id='John', type='Person')
|
||||
Node(id='XYZ Corporation', type='Organization')
|
||||
Node(id='New York City', type='Location')
|
||||
|
||||
Relationships:
|
||||
|
||||
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
|
||||
Corporation', type='Organization'), type='WorksAt')
|
||||
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
|
||||
type='Location'), type='ResidesIn')
|
||||
|
||||
===== TASK =====
|
||||
Please extracts nodes and relationships from given content and structures them
|
||||
into Node and Relationship objects.
|
||||
|
||||
{task}
|
||||
"""
|
||||
|
||||
|
||||
@track_agent(name="KnowledgeGraphAgent")
|
||||
class KnowledgeGraphAgent(ChatAgent):
|
||||
r"""An agent that can extract node and relationship information for
|
||||
different entities from given `Element` content.
|
||||
|
||||
Attributes:
|
||||
task_prompt (TextPrompt): A prompt for the agent to extract node and
|
||||
relationship information for different entities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
r"""Initialize the `KnowledgeGraphAgent`.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
"""
|
||||
system_message = BaseMessage(
|
||||
role_name="Graphify",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="Your mission is to transform unstructured content "
|
||||
"into structured graph data. Extract nodes and relationships with "
|
||||
"precision, and let the connections unfold. Your graphs will "
|
||||
"illuminate the hidden connections within the chaos of "
|
||||
"information.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def run(
|
||||
self,
|
||||
element: "Element",
|
||||
parse_graph_elements: bool = False,
|
||||
) -> Union[str, GraphElement]:
|
||||
r"""Run the agent to extract node and relationship information.
|
||||
|
||||
Args:
|
||||
element (Element): The input element.
|
||||
parse_graph_elements (bool, optional): Whether to parse into
|
||||
`GraphElement`. Defaults to `False`.
|
||||
|
||||
Returns:
|
||||
Union[str, GraphElement]: The extracted node and relationship
|
||||
information. If `parse_graph_elements` is `True` then return
|
||||
`GraphElement`, else return `str`.
|
||||
"""
|
||||
self.reset()
|
||||
self.element = element
|
||||
|
||||
knowledge_graph_prompt = TextPrompt(text_prompt)
|
||||
knowledge_graph_generation = knowledge_graph_prompt.format(
|
||||
task=str(element)
|
||||
)
|
||||
|
||||
knowledge_graph_generation_msg = BaseMessage.make_user_message(
|
||||
role_name="Graphify", content=knowledge_graph_generation
|
||||
)
|
||||
|
||||
response = self.step(input_message=knowledge_graph_generation_msg)
|
||||
|
||||
content = response.msg.content
|
||||
|
||||
if parse_graph_elements:
|
||||
content = self._parse_graph_elements(content)
|
||||
|
||||
return content
|
||||
|
||||
def _validate_node(self, node: Node) -> bool:
|
||||
r"""Validate if the object is a valid Node.
|
||||
|
||||
Args:
|
||||
node (Node): Object to be validated.
|
||||
|
||||
Returns:
|
||||
bool: True if the object is a valid Node, False otherwise.
|
||||
"""
|
||||
return (
|
||||
isinstance(node, Node)
|
||||
and isinstance(node.id, (str, int))
|
||||
and isinstance(node.type, str)
|
||||
)
|
||||
|
||||
def _validate_relationship(self, relationship: Relationship) -> bool:
|
||||
r"""Validate if the object is a valid Relationship.
|
||||
|
||||
Args:
|
||||
relationship (Relationship): Object to be validated.
|
||||
|
||||
Returns:
|
||||
bool: True if the object is a valid Relationship, False otherwise.
|
||||
"""
|
||||
return (
|
||||
isinstance(relationship, Relationship)
|
||||
and self._validate_node(relationship.subj)
|
||||
and self._validate_node(relationship.obj)
|
||||
and isinstance(relationship.type, str)
|
||||
)
|
||||
|
||||
def _parse_graph_elements(self, input_string: str) -> GraphElement:
|
||||
r"""Parses graph elements from given content.
|
||||
|
||||
Args:
|
||||
input_string (str): The input content.
|
||||
|
||||
Returns:
|
||||
GraphElement: The parsed graph elements.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Regular expressions to extract nodes and relationships
|
||||
node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
|
||||
rel_pattern = (
|
||||
r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
|
||||
r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
|
||||
)
|
||||
|
||||
nodes = {}
|
||||
relationships = []
|
||||
|
||||
# Extract nodes
|
||||
for match in re.finditer(node_pattern, input_string):
|
||||
id, type = match.groups()
|
||||
properties = {'source': 'agent_created'}
|
||||
if id not in nodes:
|
||||
node = Node(id=id, type=type, properties=properties)
|
||||
if self._validate_node(node):
|
||||
nodes[id] = node
|
||||
|
||||
# Extract relationships
|
||||
for match in re.finditer(rel_pattern, input_string):
|
||||
subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
|
||||
properties = {'source': 'agent_created'}
|
||||
if subj_id in nodes and obj_id in nodes:
|
||||
subj = nodes[subj_id]
|
||||
obj = nodes[obj_id]
|
||||
relationship = Relationship(
|
||||
subj=subj, obj=obj, type=rel_type, properties=properties
|
||||
)
|
||||
if self._validate_relationship(relationship):
|
||||
relationships.append(relationship)
|
||||
|
||||
return GraphElement(
|
||||
nodes=list(nodes.values()),
|
||||
relationships=relationships,
|
||||
source=self.element,
|
||||
)
|
||||
@@ -1,141 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import re
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.types import RoleType
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="RoleAssignmentAgent")
|
||||
class RoleAssignmentAgent(ChatAgent):
|
||||
r"""An agent that generates role names based on the task prompt.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
|
||||
Attributes:
|
||||
role_assignment_prompt (TextPrompt): A prompt for the agent to generate
|
||||
role names.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
system_message = BaseMessage(
|
||||
role_name="Role Assigner",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You assign roles based on tasks.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task_prompt: Union[str, TextPrompt],
|
||||
num_roles: int = 2,
|
||||
) -> Dict[str, str]:
|
||||
r"""Generate role names based on the input task prompt.
|
||||
|
||||
Args:
|
||||
task_prompt (Union[str, TextPrompt]): The prompt
|
||||
for the task based on which the roles are to be generated.
|
||||
num_roles (int, optional): The number of roles to generate.
|
||||
(default: :obj:`2`)
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary mapping role names to their
|
||||
descriptions.
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
|
||||
f"Domain expert {i + 1}: <BLANK>\n"
|
||||
f"Associated competencies, characteristics, duties "
|
||||
f"and workflows: <BLANK>. End."
|
||||
for i in range(num_roles or 0)
|
||||
)
|
||||
role_assignment_generation_prompt = TextPrompt(
|
||||
"You are a role assignment agent, and you're in charge of "
|
||||
+ "recruiting {num_roles} experts for the following task."
|
||||
+ "\n==== TASK =====\n {task}\n\n"
|
||||
+ "Identify the domain experts you'd recruit and detail their "
|
||||
+ "associated competencies, characteristics, duties and workflows "
|
||||
+ "to complete the task.\n "
|
||||
+ "Your answer MUST adhere to the format of ANSWER PROMPT, and "
|
||||
+ "ONLY answer the BLANKs.\n"
|
||||
+ expert_prompt
|
||||
)
|
||||
role_assignment_generation = role_assignment_generation_prompt.format(
|
||||
num_roles=num_roles, task=task_prompt
|
||||
)
|
||||
|
||||
role_assignment_generation_msg = BaseMessage.make_user_message(
|
||||
role_name="Role Assigner", content=role_assignment_generation
|
||||
)
|
||||
|
||||
response = self.step(input_message=role_assignment_generation_msg)
|
||||
|
||||
msg = response.msg # type: BaseMessage
|
||||
terminated = response.terminated
|
||||
|
||||
# Distribute the output completions into role names and descriptions
|
||||
role_names = [
|
||||
desc.replace("<|", "").replace("|>", "")
|
||||
for desc in re.findall(
|
||||
r"Domain expert \d: (.+?)\nAssociated competencies,",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
]
|
||||
role_descriptions = [
|
||||
desc.replace("<|", "").replace("|>", "")
|
||||
for desc in re.findall(
|
||||
r"Associated competencies, characteristics, "
|
||||
r"duties and workflows: (.+?) End.",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
]
|
||||
|
||||
if len(role_names) != num_roles or len(role_descriptions) != num_roles:
|
||||
raise RuntimeError(
|
||||
"Got None or insufficient information of roles."
|
||||
)
|
||||
if terminated:
|
||||
raise RuntimeError("Role assignment failed.")
|
||||
|
||||
role_descriptions_dict = {
|
||||
role_name: description
|
||||
for role_name, description in zip(role_names, role_descriptions)
|
||||
}
|
||||
|
||||
return role_descriptions_dict
|
||||
@@ -1,133 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Optional
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.types import RoleType
|
||||
from camel.utils import create_chunks
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="SearchAgent")
|
||||
class SearchAgent(ChatAgent):
|
||||
r"""An agent that summarizes text based on a query and evaluates the
|
||||
relevance of an answer.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
system_message = BaseMessage(
|
||||
role_name="Assistant",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You are a helpful assistant.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def summarize_text(self, text: str, query: str) -> str:
|
||||
r"""Summarize the information from the text, base on the query.
|
||||
|
||||
Args:
|
||||
text (str): Text to summarize.
|
||||
query (str): What information you want.
|
||||
|
||||
Returns:
|
||||
str: Strings with information.
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
summary_prompt = TextPrompt(
|
||||
'''Gather information from this text that relative to the
|
||||
question, but do not directly answer the question.\nquestion:
|
||||
{query}\ntext '''
|
||||
)
|
||||
summary_prompt = summary_prompt.format(query=query)
|
||||
# Max length of each chunk
|
||||
max_len = 3000
|
||||
results = ""
|
||||
chunks = create_chunks(text, max_len)
|
||||
# Summarize
|
||||
for i, chunk in enumerate(chunks, start=1):
|
||||
prompt = summary_prompt + str(i) + ": " + chunk
|
||||
user_msg = BaseMessage.make_user_message(
|
||||
role_name="User",
|
||||
content=prompt,
|
||||
)
|
||||
result = self.step(user_msg).msg.content
|
||||
results += result + "\n"
|
||||
|
||||
# Final summarization
|
||||
final_prompt = TextPrompt(
|
||||
'''Here are some summarized texts which split from one text. Using
|
||||
the information to answer the question. If can't find the answer,
|
||||
you must answer "I can not find the answer to the query" and
|
||||
explain why.\n Query:\n{query}.\n\nText:\n'''
|
||||
)
|
||||
final_prompt = final_prompt.format(query=query)
|
||||
prompt = final_prompt + results
|
||||
|
||||
user_msg = BaseMessage.make_user_message(
|
||||
role_name="User",
|
||||
content=prompt,
|
||||
)
|
||||
response = self.step(user_msg).msg.content
|
||||
|
||||
return response
|
||||
|
||||
def continue_search(self, query: str, answer: str) -> bool:
|
||||
r"""Ask whether to continue search or not based on the provided answer.
|
||||
|
||||
Args:
|
||||
query (str): The question.
|
||||
answer (str): The answer to the question.
|
||||
|
||||
Returns:
|
||||
bool: `True` if the user want to continue search, `False`
|
||||
otherwise.
|
||||
"""
|
||||
prompt = TextPrompt(
|
||||
"Do you think the ANSWER can answer the QUERY? "
|
||||
"Use only 'yes' or 'no' to answer.\n"
|
||||
"===== QUERY =====\n{query}\n\n"
|
||||
"===== ANSWER =====\n{answer}"
|
||||
)
|
||||
prompt = prompt.format(query=query, answer=answer)
|
||||
user_msg = BaseMessage.make_user_message(
|
||||
role_name="User",
|
||||
content=prompt,
|
||||
)
|
||||
response = self.step(user_msg).msg.content
|
||||
if "yes" in str(response).lower():
|
||||
return False
|
||||
return True
|
||||
@@ -1,410 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import PromptTemplateGenerator, TextPrompt
|
||||
from camel.types import RoleType, TaskType
|
||||
from camel.utils import get_task_list
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="TaskSpecifyAgent")
|
||||
class TaskSpecifyAgent(ChatAgent):
|
||||
r"""An agent that specifies a given task prompt by prompting the user to
|
||||
provide more details.
|
||||
|
||||
Attributes:
|
||||
DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
|
||||
task_specify_prompt (TextPrompt): The prompt for specifying the task.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
task_type (TaskType, optional): The type of task for which to generate
|
||||
a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
|
||||
task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
|
||||
specifying the task. (default: :obj:`None`)
|
||||
word_limit (int, optional): The word limit for the task prompt.
|
||||
(default: :obj:`50`)
|
||||
output_language (str, optional): The language to be output by the
|
||||
agent. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
DEFAULT_WORD_LIMIT = 50
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
task_type: TaskType = TaskType.AI_SOCIETY,
|
||||
task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
|
||||
word_limit: int = DEFAULT_WORD_LIMIT,
|
||||
output_language: Optional[str] = None,
|
||||
) -> None:
|
||||
self.task_specify_prompt: Union[str, TextPrompt]
|
||||
if task_specify_prompt is None:
|
||||
task_specify_prompt_template = (
|
||||
PromptTemplateGenerator().get_task_specify_prompt(task_type)
|
||||
)
|
||||
|
||||
self.task_specify_prompt = task_specify_prompt_template.format(
|
||||
word_limit=word_limit
|
||||
)
|
||||
else:
|
||||
self.task_specify_prompt = TextPrompt(task_specify_prompt)
|
||||
|
||||
system_message = BaseMessage(
|
||||
role_name="Task Specifier",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You can make a task more specific.",
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
system_message,
|
||||
model=model,
|
||||
output_language=output_language,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task_prompt: Union[str, TextPrompt],
|
||||
meta_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> TextPrompt:
|
||||
r"""Specify the given task prompt by providing more details.
|
||||
|
||||
Args:
|
||||
task_prompt (Union[str, TextPrompt]): The original task
|
||||
prompt.
|
||||
meta_dict (Dict[str, Any], optional): A dictionary containing
|
||||
additional information to include in the prompt.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
TextPrompt: The specified task prompt.
|
||||
"""
|
||||
self.reset()
|
||||
task_specify_prompt = self.task_specify_prompt.format(task=task_prompt)
|
||||
|
||||
if meta_dict is not None:
|
||||
task_specify_prompt = task_specify_prompt.format(**meta_dict)
|
||||
task_msg = BaseMessage.make_user_message(
|
||||
role_name="Task Specifier", content=task_specify_prompt
|
||||
)
|
||||
specifier_response = self.step(task_msg)
|
||||
|
||||
if specifier_response.terminated:
|
||||
raise RuntimeError("Task specification failed.")
|
||||
if len(specifier_response.msgs) == 0:
|
||||
raise RuntimeError("Got no specification message.")
|
||||
|
||||
specified_task_msg = specifier_response.msgs[0]
|
||||
|
||||
return TextPrompt(specified_task_msg.content)
|
||||
|
||||
|
||||
@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
    r"""An agent that helps divide a task into subtasks based on the input
    task prompt.

    Attributes:
        task_planner_prompt (TextPrompt): A prompt for the agent to divide
            the task into subtasks.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_planner_prompt = TextPrompt(
            "Divide this task into subtasks: {task}. Be concise."
        )
        super().__init__(
            BaseMessage(
                role_name="Task Planner",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You are a helpful task planner.",
            ),
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
    ) -> TextPrompt:
        r"""Generate subtasks based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task to
                be divided into subtasks.

        Returns:
            TextPrompt: A prompt for the subtasks generated by the agent.
        """
        # TODO: Maybe include roles information.
        self.reset()

        planner_msg = BaseMessage.make_user_message(
            role_name="Task Planner",
            content=self.task_planner_prompt.format(task=task_prompt),
        )
        response = self.step(planner_msg)

        if response.terminated:
            raise RuntimeError("Task planning failed.")
        if not response.msgs:
            raise RuntimeError("Got no task planning message.")

        return TextPrompt(response.msgs[0].content)
|
||||
|
||||
|
||||
@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
    r"""An agent that helps create new tasks based on the objective
    and last completed task. Compared to :obj:`TaskPlannerAgent`,
    it's still a task planner, but it has more context information
    like last task and incomplete task list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_creation_prompt (TextPrompt): A prompt for the agent to
            create new tasks.

    Args:
        role_name (str): The role name of the Agent to create the task.
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no
            windowing is performed. (default: :obj:`None`)
        max_task_num (int, optional): The maximum number of planned
            tasks in one round. (default: :obj:`3`)
    """

    def __init__(
        self,
        role_name: str,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
        max_task_num: Optional[int] = 3,
    ) -> None:
        # Fixed grammar of the opening sentence ("Create new a task").
        task_creation_prompt = TextPrompt(
            """Create a new task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The new created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:

#. First Task
#. Second Task
#. Third Task

You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
        )

        # Pre-fill everything except {task_list}, which changes per call.
        self.task_creation_prompt = task_creation_prompt.format(
            objective=objective, role_name=role_name, max_task_num=max_task_num
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Creator",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task creator.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Generate subtasks based on the previous task results and
        incomplete task list.

        Args:
            task_list (List[str]): The completed or in-progress
                tasks which should not overlap with new created tasks.

        Returns:
            List[str]: The new task list generated by the Agent.
        """
        # Collapse the duplicated if/else branches: an empty task list is
        # rendered as an empty string, exactly as before.
        task_creation_prompt = self.task_creation_prompt.format(
            task_list=task_list if task_list else ""
        )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Creator", content=task_creation_prompt
        )
        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task creation failed.")
        if not task_response.msgs:
            raise RuntimeError("Got no task creation message.")

        return get_task_list(task_response.msgs[0].content)
|
||||
|
||||
|
||||
@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
    r"""An agent that helps re-prioritize the task list and
    returns numbered prioritized list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_prioritization_prompt (TextPrompt): A prompt for the agent to
            prioritize tasks.

    Args:
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no
            windowing is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        prompt_template = TextPrompt(
            """Prioritize the following tasks : {task_list}.
Consider the ultimate objective of you: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:

#. First task
#. Second task

The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
        )

        # Bind the objective now; {task_list} is substituted in `run`.
        self.task_prioritization_prompt = prompt_template.format(
            objective=objective
        )
        self.objective = objective

        super().__init__(
            BaseMessage(
                role_name="Task Prioritizer",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You are a helpful task prioritizer.",
            ),
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Prioritize the task list given the agent objective.

        Args:
            task_list (List[str]): The unprioritized tasks of agent.

        Returns:
            List[str]: The new prioritized task list generated by the Agent.
        """
        prioritizer_msg = BaseMessage.make_user_message(
            role_name="Task Prioritizer",
            content=self.task_prioritization_prompt.format(
                task_list=task_list
            ),
        )

        response = self.step(prioritizer_msg)

        if response.terminated:
            raise RuntimeError("Task prioritization failed.")
        if not response.msgs:
            raise RuntimeError("Got no task prioritization message.")

        return get_task_list(response.msgs[0].content)
|
||||
@@ -1,20 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .base import BaseToolAgent
|
||||
from .hugging_face_tool_agent import HuggingFaceToolAgent
|
||||
|
||||
__all__ = [
|
||||
'BaseToolAgent',
|
||||
'HuggingFaceToolAgent',
|
||||
]
|
||||
@@ -1,39 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from camel.agents import BaseAgent
|
||||
|
||||
|
||||
class BaseToolAgent(BaseAgent):
    r"""Creates a :obj:`BaseToolAgent` object with the specified name and
    description.

    Args:
        name (str): The name of the tool agent.
        description (str): The description of the tool agent.
    """

    def __init__(self, name: str, description: str) -> None:
        # The name/description pair is what `__str__` exposes to callers
        # that render available tools for an LLM.
        self.name = name
        self.description = description

    def reset(self) -> None:
        r"""Resets the agent to its initial state."""
        # No state to reset by default; subclasses override as needed.

    def step(self) -> None:
        r"""Performs a single step of the agent."""
        # Intentionally a no-op; subclasses provide the behavior.

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
|
||||
@@ -1,206 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Optional
|
||||
|
||||
from camel.agents.tool_agents.base import BaseToolAgent
|
||||
|
||||
|
||||
# flake8: noqa :E501
class HuggingFaceToolAgent(BaseToolAgent):
    r"""Tool agent for calling HuggingFace models. This agent is a wrapper
    around agents from the `transformers` library. For more information
    about the available models, please see the `transformers` documentation
    at https://huggingface.co/docs/transformers/transformers_agents.

    Args:
        name (str): The name of the agent.
        *args (Any): Additional positional arguments to pass to the underlying
            Agent class.
        remote (bool, optional): Flag indicating whether to run the agent
            remotely. (default: :obj:`True`)
        **kwargs (Any): Additional keyword arguments to pass to the underlying
            Agent class.
    """

    def __init__(
        self,
        name: str,
        *args: Any,
        remote: bool = True,
        **kwargs: Any,
    ) -> None:
        try:
            # TODO: Support other tool agents
            import transformers
            from packaging import version

            if version.parse(transformers.__version__) < version.parse(
                "4.31.0"
            ):
                raise ValueError(
                    "The version of \"transformers\" package should >= 4.31.0"
                )

            from transformers.tools import OpenAiAgent
            from transformers.tools.agent_types import AgentImage
        except (ImportError, ValueError) as e:
            # Chain the original exception so the real cause (missing
            # package vs. wrong version) stays visible in the traceback.
            raise ValueError(
                "Could not import transformers tool agents. "
                "Please setup the environment with "
                "pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
            ) from e
        self.agent_image_type = AgentImage
        self.agent = OpenAiAgent(*args, **kwargs)
        description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
- Text question answering: given a long text and a question, answer the question in the text
- Unconditional image captioning: Caption the image!
- Image question answering: given an image, answer a question on this image
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
- Text to speech: convert text to speech
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
- Text summarization: summarize a long text in one or a few sentences
- Translation: translate the text into a given language
- Text downloading: to download a text from a web URL
- Text to image: generate an image according to a prompt, leveraging stable diffusion
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- Text to video: generate a small video according to a prompt

Here are some python code examples of what you can do with this agent:

Single execution (step) mode, the single execution method is when using the step() method of the agent:
```
# Text to image
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")

# Text to image -> Image transformation
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
sea_add_island_image.save("./sea_add_island_image.png")

# If you'd like to keep a state across executions or to pass non-text objects to the agent,
# you can do so by specifying variables that you would like the agent to use. For example,
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
picture = {name}.step("Generate a picture of rivers and lakes.")
picture.save("./picture.png")
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
updated_picture.save("./updated_picture.png")

capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
capybara_sea_image.save("./capybara_sea_image.png")

# Document question answering
answer = {name}.step(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
print(answer)


# Text to image
boat_image = {name}.step("Generate an image of a boat in the water")
boat_image.save("./boat_image.png")

# Unconditional image captioning
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
print(boat_image_caption)

# Text to image -> Unconditional image captioning -> Text to speech
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")

# Text downloading
document = {name}.step("Download the text from http://hf.co")
print(document)

# Text summarization
summary = {name}.step("Summarize the following text: `document`", document=document)
print(summary)

# Text downloading -> Text summarization -> Text to speech
audio = {name}.step("Read out loud the summary of http://hf.co")
```

Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
```
# Clean the chat history
{name}.reset()

# Text to image
capybara_image = {name}.chat("Show me an an image of a capybara")
capybara_image.save("./capybara_image.png")

# Image transformation
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
transformed_capybara_image.save("./transformed_capybara_image.png")

# Image segmentation
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
```
"""
        # Python 3 zero-argument super() instead of the legacy
        # two-argument form.
        super().__init__(name, description)
        self.remote = remote

    def reset(self) -> None:
        r"""Resets the chat history of the agent."""
        self.agent.prepare_for_new_chat()

    def step(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in single execution mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.run(*args, remote=remote, **kwargs)
        # Unwrap transformers' AgentImage wrapper into the raw image object.
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

    def chat(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in a chat conversation mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.chat(*args, remote=remote, **kwargs)
        # Unwrap transformers' AgentImage wrapper into the raw image object.
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output
|
||||
@@ -1,17 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .base import BaseBenchmark
|
||||
|
||||
__all__ = ["BaseBenchmark"]
|
||||
@@ -1,152 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseBenchmark(ABC):
    r"""Base class for benchmarks.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (str): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. :(default: :obj:`1`)
    """

    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. :(default: :obj:`1`)

        """
        self.name = name
        self.data_dir = Path(data_dir)
        self.processes = processes
        self.save_to = save_to
        # Create the data directory on demand instead of failing later.
        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )
        self._data: Dict[str, List[Dict[str, Any]]] = dict()
        self._results: List[Dict[str, Any]] = []

    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    def _get_split(self, split: str) -> List[Dict[str, Any]]:
        # Shared lazy-loading path for the train/valid/test properties.
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data[split]

    @property
    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data.

        Returns:
            List[Dict[str, Any]]: The training data.
        """
        return self._get_split("train")

    @property
    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data.

        Returns:
            List[Dict[str, Any]]: The validation data.
        """
        return self._get_split("valid")

    @property
    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data.

        Returns:
            List[Dict[str, Any]]: The test data.
        """
        return self._get_split("test")

    @abstractmethod
    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @property
    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results
|
||||
@@ -1,34 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .discord_app import DiscordApp
|
||||
from .slack.models import (
|
||||
SlackAppMentionEventBody,
|
||||
SlackAppMentionEventProfile,
|
||||
SlackAuthProfile,
|
||||
SlackEventBody,
|
||||
SlackEventProfile,
|
||||
)
|
||||
from .slack.slack_app import SlackApp
|
||||
from .telegram_bot import TelegramBot
|
||||
|
||||
__all__ = [
|
||||
'DiscordApp',
|
||||
'SlackApp',
|
||||
'SlackAppMentionEventBody',
|
||||
'SlackAppMentionEventProfile',
|
||||
'SlackAuthProfile',
|
||||
'SlackEventBody',
|
||||
'SlackEventProfile',
|
||||
'TelegramBot',
|
||||
]
|
||||
@@ -1,138 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from camel.utils import dependencies_required
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord import Message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiscordApp:
|
||||
r"""A class representing a Discord app that uses the `discord.py` library
|
||||
to interact with Discord servers.
|
||||
|
||||
This bot can respond to messages in specific channels and only reacts to
|
||||
messages that mention the bot.
|
||||
|
||||
Attributes:
|
||||
channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
|
||||
provided, the bot will only respond to messages in these channels.
|
||||
token (Optional[str]): The Discord bot token used for authentication.
|
||||
"""
|
||||
|
||||
@dependencies_required('discord')
|
||||
def __init__(
|
||||
self,
|
||||
channel_ids: Optional[List[int]] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> None:
|
||||
r"""Initialize the DiscordApp instance by setting up the Discord client
|
||||
and event handlers.
|
||||
|
||||
Args:
|
||||
channel_ids (Optional[List[int]]): A list of allowed channel IDs.
|
||||
The bot will only respond to messages in these channels if
|
||||
provided.
|
||||
token (Optional[str]): The Discord bot token for authentication.
|
||||
If not provided, the token will be retrieved from the
|
||||
environment variable `DISCORD_TOKEN`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `DISCORD_TOKEN` is not found in environment
|
||||
variables.
|
||||
"""
|
||||
self.token = token or os.getenv('DISCORD_TOKEN')
|
||||
self.channel_ids = channel_ids
|
||||
|
||||
if not self.token:
|
||||
raise ValueError(
|
||||
"`DISCORD_TOKEN` not found in environment variables. Get it"
|
||||
" here: `https://discord.com/developers/applications`."
|
||||
)
|
||||
|
||||
import discord
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
self._client = discord.Client(intents=intents)
|
||||
|
||||
# Register event handlers
|
||||
self._client.event(self.on_ready)
|
||||
self._client.event(self.on_message)
|
||||
|
||||
async def start(self):
|
||||
r"""Asynchronously start the Discord bot using its token.
|
||||
|
||||
This method starts the bot and logs into Discord asynchronously using
|
||||
the provided token. It should be awaited when used in an async
|
||||
environment.
|
||||
"""
|
||||
await self._client.start(self.token)
|
||||
|
||||
def run(self) -> None:
|
||||
r"""Start the Discord bot using its token.
|
||||
|
||||
This method starts the bot and logs into Discord synchronously using
|
||||
the provided token. It blocks execution and keeps the bot running.
|
||||
"""
|
||||
self._client.run(self.token) # type: ignore[arg-type]
|
||||
|
||||
async def on_ready(self) -> None:
    r"""Handle the ``ready`` event fired once the bot has connected.

    Logs the bot's username to confirm a successful login.
    """
    logger.info(f'We have logged in as {self._client.user}')
|
||||
|
||||
async def on_message(self, message: 'Message') -> None:
    r"""Handle an incoming Discord message.

    The bot ignores its own messages, messages outside the allowed
    channels (when ``channel_ids`` was configured), and messages that do
    not mention it; anything that passes the filters is logged.

    Args:
        message (discord.Message): The message object received from
            Discord.
    """
    # Never react to the bot's own messages.
    if message.author == self._client.user:
        return

    # Enforce the channel allow-list, if one was configured.
    if self.channel_ids and message.channel.id not in self.channel_ids:
        return

    # React only when the bot itself is mentioned.
    if not self._client.user or not self._client.user.mentioned_in(
        message
    ):
        return

    logger.info(f"Received message: {message.content}")
|
||||
|
||||
@property
def client(self):
    r"""Return the underlying ``discord.Client`` instance."""
    return self._client
|
||||
@@ -1,30 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .models import (
    SlackAppMentionEventBody,
    SlackAppMentionEventProfile,
    SlackAuthProfile,
    SlackEventBody,
    SlackEventProfile,
)
from .slack_app import SlackApp

# Public API of the Slack bot package.
__all__ = [
    'SlackApp',
    'SlackAppMentionEventBody',
    'SlackAppMentionEventProfile',
    'SlackAuthProfile',
    'SlackEventBody',
    'SlackEventProfile',
]
|
||||
@@ -1,158 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SlackAuthProfile(BaseModel):
    r"""Authorization profile carried inside a Slack event.

    Slack events include a single, compact ``authorizations`` field that
    shows one installation of your app the event is visible to; lists of
    authorizations are truncated to one element. If your app is installed
    by more than one party, do not rely on the single listed party being
    any particular one.

    To get the full list of who can see an event, call the
    ``apps.event.authorizations.list`` method after obtaining an
    app-level token. These semantics took effect for existing apps on
    February 24, 2021.

    References:

    - https://api.slack.com/apis/events-api#authorizations
    - https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
    """

    enterprise_id: Optional[str] = None
    """ID of the enterprise associated with the authorization."""

    team_id: str
    """ID of the team associated with the authorization."""

    user_id: str
    """ID of the user associated with the authorization."""

    is_bot: bool
    """Whether the authorized user is a bot."""

    is_enterprise_install: bool
    """Whether the authorization is for an enterprise installation."""
|
||||
|
||||
|
||||
class SlackEventProfile(BaseModel):
    r"""Detailed profile of a Slack event, covering the user, message,
    and context data.
    """

    user: str
    """ID of the user associated with the event."""

    type: str
    """Type of the event (e.g., 'message')."""

    ts: str
    """Timestamp of when the event was triggered."""

    thread_ts: Optional[str] = None
    """Timestamp of the parent message in a thread, if any."""

    client_msg_id: str
    """Client-generated unique ID for the message (if available)."""

    text: str
    """The message content text."""

    team: str
    """ID of the team the event is associated with."""

    blocks: list
    """Message blocks carrying structured information."""

    channel: str
    """ID of the Slack channel where the event happened."""

    event_ts: str
    """Event-specific timestamp of when it occurred."""

    channel_type: Optional[str]
    """Type of Slack channel (e.g., 'channel', 'im')."""
|
||||
|
||||
|
||||
class SlackEventBody(BaseModel):
    r"""Full body of a Slack event, including the event profile,
    authorization, and context.
    """

    token: str
    """Token used to verify the source of the event."""

    team_id: str
    """ID of the team where the event is happening."""

    context_team_id: Optional[str]
    """Team ID for the shared channel context, if applicable."""

    context_enterprise_id: Optional[str] = None
    """Enterprise ID for the shared channel context, if applicable."""

    api_app_id: str
    """Unique identifier of the Slack app that received the event."""

    event: SlackEventProfile
    """A detailed profile of the event"""

    type: str
    """Overall type of event received (e.g., 'event_callback')."""

    event_id: str
    """Unique identifier assigned to this event by Slack."""

    event_time: int
    """Timestamp (in seconds) of when the event was triggered."""

    authorizations: Optional[list[SlackAuthProfile]] = None
    """Optional list of authorizations describing which installation can
    see the event."""

    is_ext_shared_channel: bool
    """Whether the event is part of a channel shared between different
    organizations."""

    event_context: str
    """Unique string representing the context of the event."""
|
||||
|
||||
|
||||
class SlackAppMentionEventProfile(SlackEventProfile):
    r"""Detailed profile of a Slack event in which the app was mentioned
    in a message.
    """

    channel_type: Optional[str] = None
    """Type of Slack channel; None for app mentions."""
|
||||
|
||||
|
||||
class SlackAppMentionEventBody(SlackEventBody):
    r"""Represents the entire body of a Slack event where the app was
    mentioned in a message.
    """

    # Overridden to be optional: app_mention payloads omit this field.
    context_team_id: Optional[str] = None
    """The team ID for the shared channel context; None for app
    mentions."""

    event: SlackAppMentionEventProfile
    """A detailed profile of the event"""
|
||||
@@ -1,255 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from slack_sdk.oauth.installation_store.async_installation_store import (
|
||||
AsyncInstallationStore,
|
||||
)
|
||||
from starlette import requests, responses
|
||||
|
||||
from camel.bots.slack.models import (
|
||||
SlackAppMentionEventBody,
|
||||
SlackAppMentionEventProfile,
|
||||
SlackEventBody,
|
||||
SlackEventProfile,
|
||||
)
|
||||
from camel.utils import dependencies_required
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from slack_bolt.context.async_context import AsyncBoltContext
|
||||
from slack_bolt.context.say.async_say import AsyncSay
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
# Module-wide logger; basicConfig is a no-op if logging was already
# configured elsewhere.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackApp:
    r"""A Slack app powered by a Slack Bolt ``AsyncApp``.

    Initializes and manages the Slack application: sets up event
    handlers, runs the app server, and handles events such as messages
    and mentions from Slack.

    Args:
        token (Optional[str]): Slack API token for authentication.
        scopes (Optional[str]): Slack app scopes for permissions.
        signing_secret (Optional[str]): Signing secret for verifying Slack
            requests.
        client_id (Optional[str]): Slack app client ID.
        client_secret (Optional[str]): Slack app client secret.
        redirect_uri_path (str): The URI path for OAuth redirect, defaults
            to "/slack/oauth_redirect".
        installation_store (Optional[AsyncInstallationStore]): The
            installation store for handling OAuth installations.
    """

    @dependencies_required('slack_bolt')
    def __init__(
        self,
        token: Optional[str] = None,
        scopes: Optional[str] = None,
        signing_secret: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri_path: str = "/slack/oauth_redirect",
        installation_store: Optional[AsyncInstallationStore] = None,
    ) -> None:
        r"""Build the Slack Bolt app and wire up event handlers and OAuth.

        Args:
            token (Optional[str]): The Slack API token.
            scopes (Optional[str]): The scopes for Slack app permissions.
            signing_secret (Optional[str]): The signing secret for
                verifying requests.
            client_id (Optional[str]): The Slack app client ID.
            client_secret (Optional[str]): The Slack app client secret.
            redirect_uri_path (str): The URI path for handling OAuth
                redirects (default is "/slack/oauth_redirect").
            installation_store (Optional[AsyncInstallationStore]): An
                optional installation store for OAuth installations.

        Raises:
            ValueError: If the token, scopes, or signing secret cannot be
                resolved from arguments or environment variables.
        """
        from slack_bolt.adapter.starlette.async_handler import (
            AsyncSlackRequestHandler,
        )
        from slack_bolt.app.async_app import AsyncApp
        from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings

        # Explicit arguments win; otherwise fall back to the environment.
        self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
        self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
        self.signing_secret: Optional[str] = signing_secret or os.getenv(
            "SLACK_SIGNING_SECRET"
        )
        self.client_id: Optional[str] = client_id or os.getenv(
            "SLACK_CLIENT_ID"
        )
        self.client_secret: Optional[str] = client_secret or os.getenv(
            "SLACK_CLIENT_SECRET"
        )

        if not all([self.token, self.scopes, self.signing_secret]):
            raise ValueError(
                "`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
                "environment variables must be set. Get it here: "
                "`https://api.slack.com/apps`."
            )

        # Common constructor arguments shared by both app variants.
        app_kwargs: Dict[str, Any] = dict(
            logger=logger,
            signing_secret=self.signing_secret,
            installation_store=installation_store,
            token=self.token,
        )
        # The OAuth flow is enabled only when both client credentials
        # are available.
        if self.client_id and self.client_secret:
            app_kwargs["oauth_settings"] = AsyncOAuthSettings(
                client_id=self.client_id,
                client_secret=self.client_secret,
                scopes=self.scopes,
                redirect_uri_path=redirect_uri_path,
            )
        self._app = AsyncApp(**app_kwargs)

        self._handler = AsyncSlackRequestHandler(self._app)
        self.setup_handlers()

    def setup_handlers(self) -> None:
        r"""Register the ``app_mention`` and ``message`` event handlers.

        Wires :meth:`app_mention` and :meth:`on_message` into the Slack
        Bolt app so it responds to Slack events.
        """
        self._app.event("app_mention")(self.app_mention)
        self._app.event("message")(self.on_message)

    def run(
        self,
        port: int = 3000,
        path: str = "/slack/events",
        host: Optional[str] = None,
    ) -> None:
        r"""Start the Slack Bolt app server and listen for Slack events.

        Args:
            port (int): The port on which the server should run (default
                is 3000).
            path (str): The endpoint path for receiving Slack events
                (default is "/slack/events").
            host (Optional[str]): The hostname to bind the server
                (default is None).
        """
        self._app.start(port=port, path=path, host=host)

    async def handle_request(
        self, request: requests.Request
    ) -> responses.Response:
        r"""Forward an incoming Slack request to the Bolt request handler.

        Args:
            request (Request): A Starlette request object representing the
                incoming request.

        Returns:
            The response generated by the Slack Bolt handler.
        """
        return await self._handler.handle(request)

    async def app_mention(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Handle an ``app_mention`` event (someone mentioned the app).

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the
                event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the app mention.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the
                channel.
        """
        # Validate the raw payloads into typed pydantic models.
        event_profile = SlackAppMentionEventProfile(**event)
        event_body = SlackAppMentionEventBody(**body)

        logger.info(f"app_mention, context: {context}")
        logger.info(f"app_mention, client: {client}")
        logger.info(f"app_mention, event_profile: {event_profile}")
        logger.info(f"app_mention, event_body: {event_body}")
        logger.info(f"app_mention, say: {say}")

    async def on_message(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Handle a ``message`` event received by the app.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the
                event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the message.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the
                channel.
        """
        # Acknowledge receipt so Slack does not retry the delivery.
        await context.ack()

        # Validate the raw payloads into typed pydantic models.
        event_profile = SlackEventProfile(**event)
        event_body = SlackEventBody(**body)

        logger.info(f"on_message, context: {context}")
        logger.info(f"on_message, client: {client}")
        logger.info(f"on_message, event_profile: {event_profile}")
        logger.info(f"on_message, event_body: {event_body}")
        logger.info(f"on_message, say: {say}")

        logger.info(f"Received message: {event_profile.text}")

    def mention_me(
        self, context: "AsyncBoltContext", body: SlackEventBody
    ) -> bool:
        r"""Check whether the bot is mentioned in the message.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the
                event.
            body (SlackEventBody): The body of the Slack event.

        Returns:
            bool: True if the bot is mentioned in the message, False
                otherwise.
        """
        # Slack encodes user mentions as "<@USER_ID>" in message text.
        return f"<@{context.bot_user_id}>" in body.event.text
|
||||
@@ -1,82 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.utils import dependencies_required
|
||||
|
||||
# Conditionally import telebot types only for type checking
|
||||
if TYPE_CHECKING:
|
||||
from telebot.types import ( # type: ignore[import-untyped]
|
||||
Message,
|
||||
)
|
||||
|
||||
|
||||
class TelegramBot:
    r"""A Telegram bot powered by a CAMEL chat agent.

    Attributes:
        chat_agent (ChatAgent): Chat agent that will power the bot.
        telegram_token (str, optional): The bot token.
    """

    @dependencies_required('telebot')
    def __init__(
        self,
        chat_agent: ChatAgent,
        telegram_token: Optional[str] = None,
    ) -> None:
        r"""Create the bot and register its message handler.

        Args:
            chat_agent (ChatAgent): Agent that generates the replies.
            telegram_token (Optional[str]): Bot token; falls back to the
                ``TELEGRAM_TOKEN`` environment variable when omitted.

        Raises:
            ValueError: If no token is supplied and ``TELEGRAM_TOKEN`` is
                not set in the environment.
        """
        self.chat_agent = chat_agent

        # Prefer the explicit token, then the environment variable.
        self.token = telegram_token or os.getenv('TELEGRAM_TOKEN')
        if not self.token:
            raise ValueError(
                "`TELEGRAM_TOKEN` not found in environment variables. "
                "Get it from t.me/BotFather."
            )

        import telebot  # type: ignore[import-untyped]

        self.bot = telebot.TeleBot(token=self.token)

        # Route every incoming message to ``on_message``.
        self.bot.message_handler(func=lambda message: True)(self.on_message)

    def run(self) -> None:
        r"""Start polling Telegram for updates; blocks until stopped."""
        print("Telegram bot is running...")
        self.bot.infinity_polling()

    def on_message(self, message: 'Message') -> None:
        r"""Reply to a single incoming Telegram message.

        Args:
            message (types.Message): The incoming message object.
        """
        # Each message is answered from a fresh agent state.
        self.chat_agent.reset()

        # Ignore non-text updates (photos, stickers, etc.).
        if not message.text:
            return

        user_msg = BaseMessage.make_user_message(
            role_name="User", content=message.text
        )
        assistant_response = self.chat_agent.step(user_msg)

        self.bot.reply_to(message, assistant_response.msg.content)
|
||||
@@ -1,76 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .cohere_config import COHERE_API_PARAMS, CohereConfig
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
from .qwen_config import QWEN_API_PARAMS, QwenConfig
from .reka_config import REKA_API_PARAMS, RekaConfig
from .samba_config import (
    SAMBA_CLOUD_API_PARAMS,
    SAMBA_VERSE_API_PARAMS,
    SambaCloudAPIConfig,
    SambaVerseAPIConfig,
)
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
from .yi_config import YI_API_PARAMS, YiConfig
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig

# Public API: each provider's config class plus its API-parameter set.
__all__ = [
    'BaseConfig',
    'ChatGPTConfig',
    'OPENAI_API_PARAMS',
    'AnthropicConfig',
    'ANTHROPIC_API_PARAMS',
    'GROQ_API_PARAMS',
    'GroqConfig',
    'LiteLLMConfig',
    'LITELLM_API_PARAMS',
    'NvidiaConfig',
    'NVIDIA_API_PARAMS',
    'OllamaConfig',
    'OLLAMA_API_PARAMS',
    'ZhipuAIConfig',
    'ZHIPUAI_API_PARAMS',
    'GeminiConfig',
    'Gemini_API_PARAMS',
    'VLLMConfig',
    'VLLM_API_PARAMS',
    'MistralConfig',
    'MISTRAL_API_PARAMS',
    'RekaConfig',
    'REKA_API_PARAMS',
    'SambaVerseAPIConfig',
    'SAMBA_VERSE_API_PARAMS',
    'SambaCloudAPIConfig',
    'SAMBA_CLOUD_API_PARAMS',
    'TogetherAIConfig',
    'TOGETHERAI_API_PARAMS',
    'CohereConfig',
    'COHERE_API_PARAMS',
    'YiConfig',
    'YI_API_PARAMS',
    'QwenConfig',
    'QWEN_API_PARAMS',
    'DeepSeekConfig',
    'DEEPSEEK_API_PARAMS',
]
|
||||
@@ -1,69 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class AnthropicConfig(BaseConfig):
    r"""Parameters for generating chat completions with the Anthropic API.

    See: https://docs.anthropic.com/claude/reference/complete_post
    Args:
        max_tokens (int, optional): Maximum number of tokens to generate
            before stopping. Anthropic models may stop earlier; this is
            only the absolute upper bound.
            (default: :obj:`256`)
        stop_sequences (List[str], optional): Extra sequences that cause
            generation to stop. Anthropic models always stop on
            "\n\nHuman:" and may gain additional built-in stop sequences
            in the future; entries given here are added on top of those.
        temperature (float, optional): Amount of randomness injected into
            the response, in the range 0 to 1. Values near 0 suit
            analytical / multiple-choice tasks; values near 1 suit
            creative and generative tasks.
            (default: :obj:`1`)
        top_p (float, optional): Nucleus sampling: tokens are considered
            in decreasing probability order until their cumulative
            probability reaches `top_p`. Alter either `temperature` or
            `top_p`, but not both.
            (default: :obj:`0.7`)
        top_k (int, optional): Sample only from the top K options for
            each subsequent token, removing "long tail" low-probability
            responses.
            (default: :obj:`5`)
        metadata: An object describing metadata about the request.
        stream (bool, optional): Whether to incrementally stream the
            response using server-sent events. (default: :obj:`False`)
    """

    max_tokens: int = 256
    stop_sequences: Union[List[str], NotGiven] = NOT_GIVEN
    temperature: float = 1
    top_p: Union[float, NotGiven] = NOT_GIVEN
    top_k: Union[int, NotGiven] = NOT_GIVEN
    metadata: NotGiven = NOT_GIVEN
    stream: bool = False
|
||||
|
||||
|
||||
# Names of all config fields accepted by the Anthropic API.
# set(...) replaces the redundant identity set comprehension (ruff C401).
ANTHROPIC_API_PARAMS = set(AnthropicConfig.model_fields.keys())
|
||||
@@ -1,89 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
|
||||
class BaseConfig(ABC, BaseModel):
    r"""Common base class for all model configurations.

    Gives every model configuration a consistent set of attributes and
    helpers.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        # UserWarning: conflict with protected namespace "model_"
        protected_namespaces=(),
    )

    tools: Optional[List[Any]] = None
    """A list of tools the model may
    call. Currently, only functions are supported as a tool. Use this
    to provide a list of functions the model may generate JSON inputs
    for. A max of 128 functions are supported.
    """

    @field_validator("tools", mode="before")
    @classmethod
    def fields_type_checking(cls, tools):
        r"""Validate that every configured tool is a ``FunctionTool``.

        Raises:
            ValueError: If any entry in ``tools`` is not an instance of
                ``FunctionTool``.
        """
        if tools is not None:
            # Imported lazily to avoid a circular import at module load.
            from camel.toolkits import FunctionTool

            for tool in tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        return tools

    def as_dict(self) -> dict[str, Any]:
        r"""Serialize this configuration to a plain dictionary.

        Tool objects are replaced by their OpenAI tool schemas so the
        result can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()

        tools_schema = None
        if self.tools:
            # Imported lazily to avoid a circular import at module load.
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        # NOTE(review): placement inferred from the ``tools_schema = None``
        # initializer — the "tools" entry appears to be overwritten
        # unconditionally (None when no tools); confirm against upstream.
        config_dict["tools"] = tools_schema
        return config_dict
|
||||
@@ -1,76 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class CohereConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cohere API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        documents (list, optional): A list of relevant documents that the
            model can cite to generate a more accurate reply. Each document is
            either a string or document object with content and metadata.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens the model
            will generate as part of the response. (default: :obj:`None`)
        stop_sequences (List(str), optional): A list of up to 5 strings that
            the model will use to stop generation. If the model generates a
            string that matches any of the strings in the list, it will stop
            generating tokens and return the generated text up to that point
            not including the stop sequence. (default: :obj:`None`)
        seed (int, optional): If specified, the backend will make a best
            effort to sample tokens deterministically, such that repeated
            requests with the same seed and parameters should return the same
            result. However, determinism cannot be totally guaranteed.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. The
            higher the value, the stronger a penalty is applied to previously
            present tokens, proportional to how many times they have already
            appeared in the prompt or prior generation. (default: :obj:`0.0`)
        presence_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. Similar
            to `frequency_penalty`, except that this penalty is applied
            equally to all tokens that have already appeared, regardless of
            their exact frequencies. (default: :obj:`0.0`)
        k (int, optional): Ensures only the top k most likely tokens are
            considered for generation at each step. Min value of `0`, max
            value of `500`. (default: :obj:`0`)
        p (float, optional): Ensures that only the most likely tokens, with
            total probability mass of `p`, are considered for generation at
            each step. If both k and p are enabled, `p` acts after `k`. Min
            value of `0.01`, max value of `0.99`. (default: :obj:`0.75`)
    """

    # NOTE(review): the docstring previously claimed a temperature default of
    # 0.3 while the field default has always been 0.2; the docstring was
    # corrected to match the field.
    temperature: Optional[float] = 0.2
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    k: Optional[int] = 0
    p: Optional[float] = 0.75
|
||||
|
||||
|
||||
# Names of all request parameters accepted by CohereConfig. Access
# ``model_fields`` on the class rather than an instance: instance access is
# deprecated in Pydantic v2, and class access matches every other
# ``*_API_PARAMS`` set in this package.
COHERE_API_PARAMS = set(CohereConfig.model_fields.keys())
|
||||
@@ -1,134 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class DeepSeekConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    DeepSeek API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`1.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        stream (bool, optional): If set, partial message deltas will be sent.
            Tokens will be sent as data-only server-sent events (SSE) as
            they become available, with the stream terminated by a
            data: [DONE] message. (default: :obj:`False`)
        stop (Union[str, list[str]], optional): Up to 16 sequences where
            the API will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens that can
            be generated in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on whether they
            appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on their existing
            frequency in the text so far, decreasing the model's likelihood
            to repeat the same line verbatim. (default: :obj:`0`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use
            this to provide a list of functions the model may generate JSON
            inputs for. A max of 128 functions are supported.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. "none" means the model
            will not call any tool and instead generates a message. "auto"
            means the model can pick between generating a message or calling
            one or more tools. "required" means the model must call one or
            more tools. Specifying a particular tool via
            {"type": "function", "function": {"name": "my_function"}} forces
            the model to call that tool. "none" is the default when no tools
            are present. "auto" is the default if tools are present.
            (default: :obj:`"auto"`)
        logprobs (bool, optional): Whether to return log probabilities of
            the output tokens or not. If true, returns the log probabilities
            of each output token returned in the content of message.
            (default: :obj:`False`)
        top_logprobs (int, optional): An integer between 0 and 20 specifying
            the number of most likely tokens to return at each token
            position, each with an associated log probability. logprobs
            must be set to true if this parameter is used.
            (default: :obj:`None`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`. (default:
            :obj:`True`)
    """

    temperature: float = 0.2  # deepseek default: 1.0
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    tool_choice: Optional[Union[dict[str, str], str]] = None
    logprobs: bool = False
    top_logprobs: Optional[int] = None

    def __init__(self, include_usage: bool = True, **kwargs):
        r"""Initialize the configuration.

        Args:
            include_usage (bool): When streaming, whether to report token
                usage via ``stream_options``. (default: :obj:`True`)
            **kwargs: Field values forwarded to :obj:`BaseConfig`.
        """
        super().__init__(**kwargs)
        # Only set stream_options when stream is True.
        # Otherwise, it will raise error when calling the API.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Note:
            ``tools`` entries are validated but deliberately *not* forwarded:
            the ``tools`` key is replaced with :obj:`NOT_GIVEN` in the result.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.

        Raises:
            ValueError: If any entry of ``self.tools`` is not a
                :obj:`FunctionTool` instance.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            # Validate only; the previous implementation also collected the
            # OpenAI schemas into a list that was never used.
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        config_dict["tools"] = NOT_GIVEN
        return config_dict
|
||||
|
||||
|
||||
# Names of all request parameters accepted by DeepSeekConfig.
DEEPSEEK_API_PARAMS = set(DeepSeekConfig.model_fields.keys())
|
||||
@@ -1,114 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class GeminiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Gemini API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str]] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Note:
            ``tools`` entries are validated but deliberately *not* forwarded:
            the ``tools`` key is replaced with :obj:`NOT_GIVEN` in the result.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.

        Raises:
            ValueError: If any entry of ``self.tools`` is not a
                :obj:`FunctionTool` instance.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            # Validate only; the previous implementation also collected the
            # OpenAI schemas into a list that was never used.
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        config_dict["tools"] = NOT_GIVEN
        return config_dict
|
||||
|
||||
|
||||
# Names of all request parameters accepted by GeminiConfig.
# (The lowercase "Gemini_" prefix is kept for backward compatibility.)
Gemini_API_PARAMS = set(GeminiConfig.model_fields.keys())
|
||||
@@ -1,104 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class GroqConfig(BaseConfig):
    r"""Parameters for chat completions through Groq's OpenAI-compatible API.

    Reference: https://console.groq.com/docs/openai

    Args:
        temperature (float, optional): Sampling temperature between :obj:`0`
            and :obj:`2`. Larger values produce more random output; smaller
            values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): Nucleus-sampling alternative to temperature:
            only the tokens within the top ``top_p`` probability mass are
            considered, e.g. :obj:`0.1` keeps just the top 10%.
            (default: :obj:`1.0`)
        n (int, optional): Number of chat completion choices generated per
            input message. (default: :obj:`1`)
        response_format (object, optional): Format constraint on the model
            output, compatible with GPT-4 Turbo and GPT-3.5 Turbo models newer
            than gpt-3.5-turbo-1106. Setting ``{"type": "json_object"}``
            enables JSON mode, which guarantees the generated message is valid
            JSON. Important: in JSON mode you must also ask for JSON in a
            system or user message, otherwise the model may emit whitespace
            until it hits the token limit, making the request appear "stuck".
            The content may also be truncated when ``finish_reason="length"``,
            i.e. the generation exceeded max_tokens or the conversation
            exceeded the maximum context length.
        stream (bool, optional): When :obj:`True`, partial message deltas are
            sent as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences at which the
            API stops generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): Upper bound on tokens generated in the
            chat completion; input plus output tokens are limited by the
            model's context length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens that already appeared
            in the text, nudging the model toward new topics.
            (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens in proportion to how
            often they already occurred, reducing verbatim repetition.
            (default: :obj:`0.0`)
        user (str, optional): Stable end-user identifier that helps the
            provider monitor and detect abuse. (default: :obj:`""`)
        tools (list[FunctionTool], optional): Tools the model may call.
            Currently only functions are supported; this provides the
            function definitions the model may generate JSON inputs for, with
            a maximum of 128 functions.
        tool_choice (Union[dict[str, str], str], optional): Which tool (if
            any) the model calls: :obj:`"none"` forbids tool calls and forces
            a plain message, :obj:`"auto"` lets the model choose,
            :obj:`"required"` forces at least one tool call, and
            ``{"type": "function", "function": {"name": "my_function"}}``
            forces that specific tool. :obj:`"none"` is the default without
            tools; :obj:`"auto"` is the default when tools are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = "auto"
|
||||
|
||||
|
||||
# Names of all request parameters accepted by GroqConfig.
GROQ_API_PARAMS = set(GroqConfig.model_fields.keys())
|
||||
@@ -1,97 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class LiteLLMConfig(BaseConfig):
    r"""Parameters for generating chat completions through the LiteLLM API.

    Every field defaults to :obj:`None`, in which case the option is omitted
    from the request and LiteLLM's own default applies.

    Args:
        timeout (Optional[Union[float, str]], optional): Request timeout.
            (default: None)
        temperature (Optional[float], optional): Sampling temperature
            controlling randomness. (default: None)
        top_p (Optional[float], optional): Nucleus-sampling parameter.
            (default: None)
        n (Optional[int], optional): Number of completions to generate.
            (default: None)
        stream (Optional[bool], optional): Whether to return a streaming
            response. (default: None)
        stream_options (Optional[dict], optional): Options for the streaming
            response. (default: None)
        stop (Optional[Union[str, List[str]]], optional): Sequences at which
            the API stops generating further tokens. (default: None)
        max_tokens (Optional[int], optional): Maximum number of tokens to
            generate. (default: None)
        presence_penalty (Optional[float], optional): Penalizes new tokens
            that already exist in the text so far. (default: None)
        frequency_penalty (Optional[float], optional): Penalizes new tokens
            by their frequency in the text so far. (default: None)
        logit_bias (Optional[dict], optional): Adjusts the probability of
            specific tokens appearing in the completion. (default: None)
        user (Optional[str], optional): Unique identifier representing the
            end-user. (default: None)
        response_format (Optional[dict], optional): Response format
            parameters. (default: None)
        seed (Optional[int], optional): Random seed. (default: None)
        tools (Optional[List], optional): List of tools. (default: None)
        tool_choice (Optional[Union[str, dict]], optional): Tool choice
            parameters. (default: None)
        logprobs (Optional[bool], optional): Whether to return log
            probabilities of the output tokens. (default: None)
        top_logprobs (Optional[int], optional): Number of most likely tokens
            to return at each token position. (default: None)
        deployment_id (Optional[str], optional): Deployment ID. (default: None)
        extra_headers (Optional[dict], optional): Additional headers for the
            request. (default: None)
        api_version (Optional[str], optional): API version. (default: None)
        mock_response (Optional[str], optional): Mock completion response for
            testing or debugging. (default: None)
        custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
            provider. (default: None)
        max_retries (Optional[int], optional): Maximum number of retries.
            (default: None)
    """

    timeout: Optional[Union[float, str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stream_options: Optional[dict] = None
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    user: Optional[str] = None
    response_format: Optional[dict] = None
    seed: Optional[int] = None
    tool_choice: Optional[Union[str, dict]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    deployment_id: Optional[str] = None
    extra_headers: Optional[dict] = None
    api_version: Optional[str] = None
    mock_response: Optional[str] = None
    custom_llm_provider: Optional[str] = None
    max_retries: Optional[int] = None
|
||||
|
||||
|
||||
# Names of all request parameters accepted by LiteLLMConfig.
LITELLM_API_PARAMS = set(LiteLLMConfig.model_fields.keys())
|
||||
@@ -1,79 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class MistralConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Mistral API.

    reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195

    #TODO: Support stream mode

    Args:
        temperature (Optional[float], optional): temperature the temperature
            to use for sampling, e.g. 0.5.
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. Defaults to None.
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. Defaults to None.
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list.
        random_seed (Optional[int], optional): the random seed to use for
            sampling, e.g. 42. Defaults to None.
        safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
            Defaults to False.
        response_format (Union[Dict[str, str], ResponseFormat], optional):
            format of the response.
        tool_choice (str, optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"any"` means the
            model must call one or more tools. :obj:`"auto"` is the default
            value.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    random_seed: Optional[int] = None
    safe_prompt: bool = False
    response_format: Optional[Union[Dict[str, str], Any]] = None
    tool_choice: Optional[str] = "auto"

    @field_validator("response_format", mode="before")
    @classmethod
    def fields_type_checking(cls, response_format):
        r"""Validate that ``response_format`` is a dict or a ``mistralai``
        ``ResponseFormat`` instance.

        Raises:
            ValueError: If a non-dict value is not a
                :obj:`mistralai.models.ResponseFormat`.
        """
        # Plain dicts pass through untouched; anything else non-empty must be
        # a ResponseFormat (imported lazily so mistralai is only required
        # when such a value is actually supplied).
        if response_format and not isinstance(response_format, dict):
            from mistralai.models import ResponseFormat

            if not isinstance(response_format, ResponseFormat):
                # Fixed message: this validator checks response_format, not a
                # tool, so the error no longer says "The tool ...".
                raise ValueError(
                    f"The response_format {response_format} should be an "
                    "instance of `mistralai.models.ResponseFormat`."
                )
        return response_format
|
||||
|
||||
|
||||
# Names of all request parameters accepted by MistralConfig. Access
# ``model_fields`` on the class rather than an instance: instance access is
# deprecated in Pydantic v2, and class access matches the other
# ``*_API_PARAMS`` sets in this package.
MISTRAL_API_PARAMS = set(MistralConfig.model_fields.keys())
|
||||
@@ -1,70 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class NvidiaConfig(BaseConfig):
    r"""Configuration class for NVIDIA API models.

    Defines the sampling, penalty, and tool parameters accepted by
    NVIDIA's language-model endpoint.

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls randomness in the response.
            Higher values make output more random, lower values make it
            more deterministic. Range: [0.0, 2.0]. (default: :obj:`0.7`)
        top_p (float, optional): Controls diversity via nucleus sampling.
            Range: [0.0, 1.0]. (default: :obj:`0.95`)
        presence_penalty (float, optional): Penalizes new tokens based on
            whether they appear in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        frequency_penalty (float, optional): Penalizes new tokens based on
            their frequency in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        max_tokens (Union[int, NotGiven], optional): Maximum number of
            tokens to generate. If not provided, the model uses its
            default maximum. (default: :obj:`NOT_GIVEN`)
        seed (Optional[int], optional): Random seed for deterministic
            sampling. (default: :obj:`None`)
        tools (Optional[List[Dict]], optional): List of tools available to
            the model, such as a text editor, a calculator, or a search
            engine. (default: :obj:`None`)
        tool_choice (Optional[str], optional): Tool choice configuration.
            (default: :obj:`None`)
        stop (Optional[List[str]], optional): List of stop sequences.
            (default: :obj:`None`)
    """

    # Plain assignments are equivalent to ``Field(default=...)`` for
    # simple (non-factory) defaults in Pydantic.
    stream: bool = False
    temperature: float = 0.7
    top_p: float = 0.95
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    seed: Optional[int] = None
    tool_choice: Optional[str] = None
    stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
# Keyword arguments accepted by the NVIDIA API, derived from the declared
# fields of ``NvidiaConfig``.  ``set(mapping)`` collects the keys directly;
# the ``{p for p in d.keys()}`` comprehension was redundant.
NVIDIA_API_PARAMS = set(NvidiaConfig.model_fields)
|
||||
@@ -1,82 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class OllamaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using
    Ollama's OpenAI-compatible endpoint.

    Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more
            random, while lower values make it more focused and
            deterministic. (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with
            temperature, called nucleus sampling, where the model considers
            the results of the tokens with top_p probability mass. So
            :obj:`0.1` means only the tokens comprising the top 10%
            probability mass are considered. (default: :obj:`1.0`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when
            using JSON mode, you must also instruct the model to produce
            JSON yourself via a system or user message; otherwise the model
            may emit whitespace until the token limit is reached. The
            message content may be partially cut off if
            finish_reason="length", which indicates the generation exceeded
            max_tokens or the conversation exceeded the max context length.
        stream (bool, optional): If True, partial message deltas will be
            sent as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the
            API will stop generating further tokens.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to
            generate in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on
            whether they appear in the text so far, increasing the model's
            likelihood to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim.
            (default: :obj:`0.0`)
    """

    temperature: float = 0.2
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
|
||||
|
||||
|
||||
# Keyword arguments accepted by the Ollama endpoint, derived from the
# declared fields of ``OllamaConfig``.  ``set(mapping)`` collects the keys
# directly; the ``{p for p in d.keys()}`` comprehension was redundant.
OLLAMA_API_PARAMS = set(OllamaConfig.model_fields)
|
||||
@@ -1,139 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class ChatGPTConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more
            random, while lower values make it more focused and
            deterministic. (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with
            temperature, called nucleus sampling, where the model considers
            the results of the tokens with top_p probability mass. So
            :obj:`0.1` means only the tokens comprising the top 10%
            probability mass are considered. (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when
            using JSON mode, you must also instruct the model to produce
            JSON yourself via a system or user message; otherwise the model
            may emit whitespace until the token limit is reached. The
            message content may be partially cut off if
            finish_reason="length", which indicates the generation exceeded
            max_tokens or the conversation exceeded the max context length.
        stream (bool, optional): If True, partial message deltas will be
            sent as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the
            API will stop generating further tokens.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to
            generate in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on
            whether they appear in the text so far, increasing the model's
            likelihood to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim.
            (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified
            tokens appearing in the completion. Accepts a json object that
            maps tokens (specified by their token ID in the tokenizer) to
            an associated bias value from :obj:`-100` to :obj:`100`.
            Mathematically, the bias is added to the logits generated by
            the model prior to sampling. Values between :obj:`-1` and
            :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token.
            (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your
            end-user, which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. A max
            of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. :obj:`"none"` means the
            model will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means
            the model must call one or more tools. Specifying a particular
            tool via {"type": "function", "function": {"name":
            "my_function"}} forces the model to call that tool.
            :obj:`"none"` is the default when no tools are present;
            :obj:`"auto"` is the default if tools are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        Validates that every configured tool is a ``FunctionTool`` and
        strips ``tools`` from the payload (set to ``NOT_GIVEN``), since
        tools are not forwarded through this dictionary.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.

        Raises:
            ValueError: If any entry in ``tools`` is not a
                ``FunctionTool`` instance.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            # Fix: the original built a ``tools_schema`` list of
            # ``get_openai_tool_schema()`` results and never used it; only
            # the type validation is kept.
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
            config_dict["tools"] = NOT_GIVEN
        return config_dict
|
||||
|
||||
|
||||
# Keyword arguments accepted by the OpenAI API, derived from the declared
# fields of ``ChatGPTConfig``.  ``set(mapping)`` collects the keys
# directly; the ``{p for p in d.keys()}`` comprehension was redundant.
OPENAI_API_PARAMS = set(ChatGPTConfig.model_fields)
|
||||
@@ -1,91 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class QwenConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Qwen API. You can refer to the following link for more details:
    https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more
            focused, while higher values make it more diverse.
            (default: :obj:`0.3`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused.
            (default: :obj:`0.9`)
        presence_penalty (float, optional): Controls the repetition of
            content in the generated results. Positive values reduce the
            repetition of content, while negative values increase it.
            (default: :obj:`0.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}`
            or `{"type": "json_object"}`. Setting it to
            `{"type": "json_object"}` will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        max_tokens (Union[int, NotGiven], optional): Allows the model to
            generate the maximum number of tokens.
            (default: :obj:`NOT_GIVEN`)
        seed (int, optional): Sets the seed parameter to make the text
            generation process more deterministic. Passing the same seed
            value in each call while keeping other parameters unchanged
            makes the model likely to return the same result.
            (default: :obj:`None`)
        stop (str or list, optional): The model automatically stops
            generating text when it is about to include the specified
            string or token_id; can be used to filter sensitive words.
            (default: :obj:`None`)
        tools (list, optional): Specifies an array of tools that the model
            can call. During a function call the model selects one tool
            from the array. (default: :obj:`None`)
        extra_body (dict, optional): Additional parameters to be sent to
            the Qwen API. Set to `{"enable_search": True}` to enable
            internet search. (default: :obj:`{"enable_search": False}`)
        include_usage (bool, optional): When streaming, specifies whether
            to include usage information in `stream_options`.
            (default: :obj:`True`)
    """

    stream: bool = False
    temperature: float = 0.3
    top_p: float = 0.9
    presence_penalty: float = 0.0
    response_format: ClassVar[dict] = {"type": "text"}
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    seed: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    extra_body: ClassVar[dict] = {"enable_search": False}

    def __init__(self, include_usage: bool = True, **kwargs):
        super().__init__(**kwargs)
        # ``stream_options`` is only valid when streaming; sending it with
        # ``stream=False`` makes the API raise an error, so it is set
        # conditionally here.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}
|
||||
|
||||
|
||||
# Keyword arguments accepted by the Qwen API, derived from the declared
# fields of ``QwenConfig``.  ``set(mapping)`` collects the keys directly;
# the ``{p for p in d.keys()}`` comprehension was redundant.
QWEN_API_PARAMS = set(QwenConfig.model_fields)
|
||||
@@ -1,74 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class RekaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Reka API.

    Reference: https://docs.reka.ai/api-reference/chat/create

    Args:
        temperature (Optional[float], optional): The temperature to use for
            sampling, e.g. 0.5.
        top_p (Optional[float], optional): The cumulative probability of
            tokens to generate, e.g. 0.9. Defaults to None.
        top_k (Optional[int], optional): Forces the model to only consider
            the tokens with the `top_k` highest probabilities at the next
            step. Defaults to 1024.
        max_tokens (Optional[int], optional): The maximum number of tokens
            to generate, e.g. 100. Defaults to None.
        stop (Optional[Union[str,list[str]]]): Stop generation if this
            token is detected, or if one of these tokens is detected when
            providing a string list.
        seed (Optional[int], optional): The random seed to use for
            sampling, e.g. 42. Defaults to None.
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on
            whether they appear in the text so far, increasing the model's
            likelihood to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim.
            (default: :obj:`0.0`)
        use_search_engine (Optional[bool]): Whether to consider using a
            search engine to complete the request. Even when set to
            `True`, the model might decide not to use search.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    seed: Optional[int] = None
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    use_search_engine: Optional[bool] = False

    def as_dict(self) -> dict[str, Any]:
        r"""Return the configuration as a dictionary, excluding ``tools``
        since Reka does not support tool calling."""
        config_dict = super().as_dict()
        config_dict.pop("tools", None)
        return config_dict
|
||||
|
||||
|
||||
# Keyword arguments accepted by the Reka API, derived from the declared
# fields of ``RekaConfig``.  Accessed on the class (not an instance):
# instance access of ``model_fields`` is deprecated in Pydantic v2, and
# instantiating the config here was unnecessary.
REKA_API_PARAMS = set(RekaConfig.model_fields)
|
||||
@@ -1,170 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class SambaVerseAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    SambaVerse API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more
            random, while lower values make it more focused and
            deterministic. (default: :obj:`0.7`)
        top_p (float, optional): An alternative to sampling with
            temperature, called nucleus sampling, where the model considers
            the results of the tokens with top_p probability mass. So
            :obj:`0.1` means only the tokens comprising the top 10%
            probability mass are considered. (default: :obj:`0.95`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low probability
            responses. (default: :obj:`50`)
        max_tokens (Optional[int], optional): The maximum number of tokens
            to generate, e.g. 100. (default: :obj:`2048`)
        repetition_penalty (Optional[float], optional): The parameter for
            repetition penalty. 1.0 means no penalty.
            (default: :obj:`1.0`)
        stop (Optional[Union[str,list[str]]]): Stop generation if this
            token is detected, or if one of these tokens is detected when
            providing a string list. (default: :obj:`""`)
        stream (Optional[bool]): If True, partial message deltas will be
            sent as data-only server-sent events as they become available.
            Currently the SambaVerse API doesn't support stream mode.
            (default: :obj:`False`)
    """

    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 50
    max_tokens: Optional[int] = 2048
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[Union[str, list[str]]] = ""
    stream: Optional[bool] = False

    def as_dict(self) -> dict[str, Any]:
        r"""Return the configuration as a dictionary, excluding ``tools``
        since SambaNova does not support tool calling."""
        config_dict = super().as_dict()
        config_dict.pop("tools", None)
        return config_dict
|
||||
|
||||
|
||||
# Keyword arguments accepted by the SambaVerse API, derived from the
# declared fields of ``SambaVerseAPIConfig``.  Accessed on the class:
# instance access of ``model_fields`` is deprecated in Pydantic v2, and
# instantiating the config here was unnecessary.
SAMBA_VERSE_API_PARAMS = set(SambaVerseAPIConfig.model_fields)
|
||||
|
||||
|
||||
class SambaCloudAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more
            random, while lower values make it more focused and
            deterministic. (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with
            temperature, called nucleus sampling, where the model considers
            the results of the tokens with top_p probability mass. So
            :obj:`0.1` means only the tokens comprising the top 10%
            probability mass are considered. (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when
            using JSON mode, you must also instruct the model to produce
            JSON yourself via a system or user message; otherwise the model
            may emit whitespace until the token limit is reached. The
            message content may be partially cut off if
            finish_reason="length", which indicates the generation exceeded
            max_tokens or the conversation exceeded the max context length.
        stream (bool, optional): If True, partial message deltas will be
            sent as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the
            API will stop generating further tokens.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to
            generate in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on
            whether they appear in the text so far, increasing the model's
            likelihood to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim.
            (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified
            tokens appearing in the completion. Accepts a json object that
            maps tokens (specified by their token ID in the tokenizer) to
            an associated bias value from :obj:`-100` to :obj:`100`.
            Mathematically, the bias is added to the logits generated by
            the model prior to sampling. Values between :obj:`-1` and
            :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token.
            (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your
            end-user, which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. A max
            of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. :obj:`"none"` means the
            model will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means
            the model must call one or more tools. Specifying a particular
            tool via {"type": "function", "function": {"name":
            "my_function"}} forces the model to call that tool.
            :obj:`"none"` is the default when no tools are present;
            :obj:`"auto"` is the default if tools are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
# Keyword arguments accepted by the SambaNova Cloud API, derived from the
# declared fields of ``SambaCloudAPIConfig``.  Accessed on the class:
# instance access of ``model_fields`` is deprecated in Pydantic v2, and
# instantiating the config here was unnecessary.
SAMBA_CLOUD_API_PARAMS = set(SambaCloudAPIConfig.model_fields)
|
||||
@@ -1,107 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class TogetherAIConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`0.2`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`1.0`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`1`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`False`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`0.0`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`""`)
|
||||
"""
|
||||
|
||||
temperature: float = 0.2 # openai default: 1.0
|
||||
top_p: float = 1.0
|
||||
n: int = 1
|
||||
stream: bool = False
|
||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
||||
presence_penalty: float = 0.0
|
||||
response_format: Union[dict, NotGiven] = NOT_GIVEN
|
||||
frequency_penalty: float = 0.0
|
||||
logit_bias: dict = Field(default_factory=dict)
|
||||
user: str = ""
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
config_dict = super().as_dict()
|
||||
if "tools" in config_dict:
|
||||
del config_dict["tools"] # Currently does not support tool calling
|
||||
return config_dict
|
||||
|
||||
|
||||
TOGETHERAI_API_PARAMS = {
|
||||
param for param in TogetherAIConfig.model_fields.keys()
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
# flake8: noqa: E501
|
||||
class VLLMConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`0.2`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`1.0`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`1`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`False`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`0.0`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`""`)
|
||||
logprobs: Whether to return log probabilities of the output tokens or
|
||||
not. If true, returns the log probabilities of each output token
|
||||
returned in the `logits` of `message`. (default: :obj:`None`)
|
||||
top_logprobs: An integer between 0 and 20 specifying the number of
|
||||
most likely tokens to return at each token position, each with an
|
||||
associated log probability. `logprobs` must be set to `true` if
|
||||
this parameter is used. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: float = 0.2 # openai default: 1.0
|
||||
top_p: float = 1.0
|
||||
n: int = 1
|
||||
stream: bool = False
|
||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
||||
presence_penalty: float = 0.0
|
||||
response_format: Union[dict, NotGiven] = NOT_GIVEN
|
||||
frequency_penalty: float = 0.0
|
||||
logit_bias: dict = Field(default_factory=dict)
|
||||
user: str = ""
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
|
||||
|
||||
VLLM_API_PARAMS = {param for param in VLLMConfig.model_fields.keys()}
|
||||
@@ -1,58 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class YiConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Yi API. You can refer to the following link for more details:
|
||||
https://platform.lingyiwanwu.com/docs/api-reference
|
||||
|
||||
Args:
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` or
|
||||
specifying a particular tool via
|
||||
{"type": "function", "function": {"name": "some_function"}}
|
||||
can be used to guide the model to use tools more strongly.
|
||||
(default: :obj:`None`)
|
||||
max_tokens (int, optional): Specifies the maximum number of tokens
|
||||
the model can generate. This sets an upper limit, but does not
|
||||
guarantee that this number will always be reached.
|
||||
(default: :obj:`5000`)
|
||||
top_p (float, optional): Controls the randomness of the generated
|
||||
results. Lower values lead to less randomness, while higher
|
||||
values increase randomness. (default: :obj:`0.9`)
|
||||
temperature (float, optional): Controls the diversity and focus of
|
||||
the generated results. Lower values make the output more focused,
|
||||
while higher values make it more diverse. (default: :obj:`0.3`)
|
||||
stream (bool, optional): If True, enables streaming output.
|
||||
(default: :obj:`False`)
|
||||
"""
|
||||
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
||||
top_p: float = 0.9
|
||||
temperature: float = 0.3
|
||||
stream: bool = False
|
||||
|
||||
|
||||
YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
|
||||
@@ -1,71 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class ZhipuAIConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using OpenAI
|
||||
compatibility
|
||||
|
||||
Reference: https://open.bigmodel.cn/dev/api#glm-4v
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`0.2`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`0.6`)
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`False`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
"""
|
||||
|
||||
temperature: float = 0.2
|
||||
top_p: float = 0.6
|
||||
stream: bool = False
|
||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
ZHIPUAI_API_PARAMS = {param for param in ZhipuAIConfig.model_fields.keys()}
|
||||
@@ -1,23 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .base import BaseDatasetManager
|
||||
from .huggingface import HuggingFaceDatasetManager
|
||||
from .models import Record
|
||||
|
||||
__all__ = [
|
||||
"BaseDatasetManager",
|
||||
"Record",
|
||||
"HuggingFaceDatasetManager",
|
||||
]
|
||||
@@ -1,136 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
from camel.datahubs.models import Record
|
||||
|
||||
|
||||
class BaseDatasetManager(ABC):
|
||||
r"""Abstract base class for dataset managers."""
|
||||
|
||||
@abstractmethod
|
||||
def create_dataset(self, name: str, **kwargs: Any) -> str:
|
||||
r"""Creates a new dataset.
|
||||
|
||||
Args:
|
||||
name (str): The name of the dataset.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The URL of the created dataset.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_datasets(
|
||||
self, username: str, limit: int = 100, **kwargs: Any
|
||||
) -> List[str]:
|
||||
r"""Lists all datasets for the current user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user whose datasets to list.
|
||||
limit (int): The maximum number of datasets to list.
|
||||
(default::obj:`100`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of dataset ids.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
|
||||
r"""Deletes a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset to delete.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Adds records to a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to add to the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Updates records in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to update in the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> List[Record]:
|
||||
r"""Lists records in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
# New method for record deletion
|
||||
@abstractmethod
|
||||
def delete_record(
|
||||
self,
|
||||
dataset_name: str,
|
||||
record_id: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Deletes a record from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
record_id (str): The ID of the record to delete.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
@@ -1,433 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from camel.datahubs.base import BaseDatasetManager
|
||||
from camel.datahubs.models import Record
|
||||
from camel.logger import get_logger
|
||||
from camel.types import HuggingFaceRepoType
|
||||
from camel.utils import api_keys_required, dependencies_required
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceDatasetManager(BaseDatasetManager):
|
||||
r"""A dataset manager for Hugging Face datasets. This class provides
|
||||
methods to create, add, update, delete, and list records in a dataset on
|
||||
the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
token (str): The Hugging Face API token. If not provided, the token
|
||||
will be read from the environment variable `HUGGING_FACE_TOKEN`.
|
||||
"""
|
||||
|
||||
@api_keys_required("HUGGING_FACE_TOKEN")
|
||||
@dependencies_required('huggingface_hub')
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
self._api_key = token or os.getenv("HUGGING_FACE_TOKEN")
|
||||
self.api = HfApi(token=self._api_key)
|
||||
|
||||
def create_dataset_card(
|
||||
self,
|
||||
dataset_name: str,
|
||||
description: str,
|
||||
license: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
authors: Optional[List[str]] = None,
|
||||
size_category: Optional[List[str]] = None,
|
||||
language: Optional[List[str]] = None,
|
||||
task_categories: Optional[List[str]] = None,
|
||||
content: Optional[str] = None,
|
||||
) -> None:
|
||||
r"""Creates and uploads a dataset card to the Hugging Face Hub in YAML
|
||||
format.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
description (str): A description of the dataset.
|
||||
license (str): The license of the dataset. (default: :obj:`None`)
|
||||
version (str): The version of the dataset. (default: :obj:`None`)
|
||||
tags (list): A list of tags for the dataset.(default: :obj:`None`)
|
||||
authors (list): A list of authors of the dataset. (default:
|
||||
:obj:`None`)
|
||||
size_category (list): A size category for the dataset. (default:
|
||||
:obj:`None`)
|
||||
language (list): A list of languages the dataset is in. (default:
|
||||
:obj:`None`)
|
||||
task_categories (list): A list of task categories. (default:
|
||||
:obj:`None`)
|
||||
content (str): Custom markdown content that the user wants to add
|
||||
to the dataset card. (default: :obj:`None`)
|
||||
"""
|
||||
import yaml
|
||||
|
||||
metadata = {
|
||||
"license": license,
|
||||
"authors": authors,
|
||||
"task_categories": task_categories,
|
||||
"language": language,
|
||||
"tags": tags,
|
||||
"pretty_name": dataset_name,
|
||||
"size_categories": size_category,
|
||||
"version": version,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
# Remove keys with None values
|
||||
metadata = {k: v for k, v in metadata.items() if v}
|
||||
|
||||
card_content = (
|
||||
"---\n"
|
||||
+ yaml.dump(metadata, default_flow_style=False, allow_unicode=True)
|
||||
+ "\n---"
|
||||
)
|
||||
|
||||
if content:
|
||||
card_content += f"\n\n# Additional Information\n{content}\n"
|
||||
|
||||
self._upload_file(
|
||||
file_content=card_content,
|
||||
dataset_name=dataset_name,
|
||||
filepath="README.md",
|
||||
file_type="md",
|
||||
)
|
||||
|
||||
def create_dataset(
|
||||
self, name: str, private: bool = False, **kwargs: Any
|
||||
) -> str:
|
||||
r"""Creates a new dataset on the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
name (str): The name of the dataset.
|
||||
private (bool): Whether the dataset should be private. defaults to
|
||||
False.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The URL of the created dataset.
|
||||
"""
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
try:
|
||||
self.api.repo_info(
|
||||
repo_id=name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
self.api.create_repo(
|
||||
repo_id=name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
private=private,
|
||||
)
|
||||
|
||||
return f"https://huggingface.co/datasets/{name}"
|
||||
|
||||
def list_datasets(
|
||||
self, username: str, limit: int = 100, **kwargs: Any
|
||||
) -> List[str]:
|
||||
r"""Lists all datasets for the current user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user whose datasets to list.
|
||||
limit (int): The maximum number of datasets to list.
|
||||
(default: :obj:`100`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of dataset ids.
|
||||
"""
|
||||
try:
|
||||
return [
|
||||
dataset.id
|
||||
for dataset in self.api.list_datasets(
|
||||
author=username, limit=limit, **kwargs
|
||||
)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing datasets: {e}")
|
||||
return []
|
||||
|
||||
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
|
||||
r"""Deletes a dataset from the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset to delete.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
try:
|
||||
self.api.delete_repo(
|
||||
repo_id=dataset_name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info(f"Dataset '{dataset_name}' deleted successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting dataset '{dataset_name}': {e}")
|
||||
raise
|
||||
|
||||
def add_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Adds records to a dataset on the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to add to the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset already has a records file.
|
||||
"""
|
||||
existing_records = self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
if existing_records:
|
||||
raise ValueError(
|
||||
f"Dataset '{filepath}' already exists. "
|
||||
f"Use `update_records` to modify."
|
||||
)
|
||||
|
||||
self._upload_records(
|
||||
records=records,
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def update_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Updates records in a dataset on the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to update in the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset does not have an existing file to update
|
||||
records in.
|
||||
"""
|
||||
existing_records = self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
if not existing_records:
|
||||
logger.warning(
|
||||
f"Dataset '{dataset_name}' does not have existing "
|
||||
"records. Adding new records."
|
||||
)
|
||||
self._upload_records(
|
||||
records=records,
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
return
|
||||
|
||||
old_dict = {record.id: record for record in existing_records}
|
||||
new_dict = {record.id: record for record in records}
|
||||
merged_dict = old_dict.copy()
|
||||
merged_dict.update(new_dict)
|
||||
|
||||
self._upload_records(
|
||||
records=list(merged_dict.values()),
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete_record(
|
||||
self,
|
||||
dataset_name: str,
|
||||
record_id: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Deletes a record from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
record_id (str): The ID of the record to delete.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset does not have an existing file to delete
|
||||
records from.
|
||||
"""
|
||||
existing_records = self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
if not existing_records:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' does not have an existing file to "
|
||||
f"delete records from."
|
||||
)
|
||||
|
||||
filtered_records = [
|
||||
record for record in existing_records if record.id != record_id
|
||||
]
|
||||
|
||||
self._upload_records(
|
||||
records=filtered_records,
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def list_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> List[Record]:
|
||||
r"""Lists all records in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Record]: A list of records in the dataset.
|
||||
"""
|
||||
return self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
def _download_records(
|
||||
self, dataset_name: str, filepath: str, **kwargs: Any
|
||||
) -> List[Record]:
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import EntryNotFoundError
|
||||
|
||||
try:
|
||||
downloaded_file_path = hf_hub_download(
|
||||
repo_id=dataset_name,
|
||||
filename=filepath,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
token=self._api_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
with open(downloaded_file_path, "r") as f:
|
||||
records_data = json.load(f)
|
||||
|
||||
return [Record(**record) for record in records_data]
|
||||
except EntryNotFoundError:
|
||||
logger.info(f"No records found for dataset '{dataset_name}'.")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading or processing records: {e}")
|
||||
raise e
|
||||
|
||||
def _upload_records(
|
||||
self,
|
||||
records: List[Record],
|
||||
dataset_name: str,
|
||||
filepath: str,
|
||||
**kwargs: Any,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, mode="w", newline="", encoding="utf-8"
|
||||
) as f:
|
||||
json.dump([record.model_dump() for record in records], f)
|
||||
temp_file_path = f.name
|
||||
|
||||
try:
|
||||
self.api.upload_file(
|
||||
path_or_fileobj=temp_file_path,
|
||||
path_in_repo=filepath,
|
||||
repo_id=dataset_name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading records file: {e}")
|
||||
raise
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
def _upload_file(
|
||||
self,
|
||||
file_content: str,
|
||||
dataset_name: str,
|
||||
filepath: str,
|
||||
file_type: str = "json",
|
||||
**kwargs: Any,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=f".{file_type}"
|
||||
) as f:
|
||||
if file_type == "json":
|
||||
if isinstance(file_content, str):
|
||||
try:
|
||||
json_content = json.loads(file_content)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
"Invalid JSON string provided for file_content."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
json.dumps(file_content)
|
||||
json_content = file_content
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"file_content is not JSON serializable."
|
||||
)
|
||||
|
||||
json.dump(json_content, f)
|
||||
elif file_type == "md" or file_type == "txt":
|
||||
f.write(file_content)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {file_type}")
|
||||
|
||||
temp_file_path = f.name
|
||||
|
||||
try:
|
||||
self.api.upload_file(
|
||||
path_or_fileobj=temp_file_path,
|
||||
path_in_repo=filepath,
|
||||
repo_id=dataset_name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info(f"File uploaded successfully: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading file: {e}")
|
||||
raise
|
||||
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
@@ -1,22 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
id: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
content: Dict[str, Any]
|
||||
@@ -1,28 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .base import BaseEmbedding
|
||||
from .mistral_embedding import MistralEmbedding
|
||||
from .openai_compatible_embedding import OpenAICompatibleEmbedding
|
||||
from .openai_embedding import OpenAIEmbedding
|
||||
from .sentence_transformers_embeddings import SentenceTransformerEncoder
|
||||
from .vlm_embedding import VisionLanguageEmbedding
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbedding",
|
||||
"OpenAIEmbedding",
|
||||
"SentenceTransformerEncoder",
|
||||
"VisionLanguageEmbedding",
|
||||
"MistralEmbedding",
|
||||
"OpenAICompatibleEmbedding",
|
||||
]
|
||||
@@ -1,67 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class BaseEmbedding(ABC, Generic[T]):
|
||||
r"""Abstract base class for text embedding functionalities."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_list(
|
||||
self,
|
||||
objs: list[T],
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
r"""Generates embeddings for the given texts.
|
||||
|
||||
Args:
|
||||
objs (list[T]): The objects for which to generate the embeddings.
|
||||
**kwargs (Any): Extra kwargs passed to the embedding API.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: A list that represents the
|
||||
generated embedding as a list of floating-point numbers.
|
||||
"""
|
||||
pass
|
||||
|
||||
def embed(
|
||||
self,
|
||||
obj: T,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
r"""Generates an embedding for the given text.
|
||||
|
||||
Args:
|
||||
obj (T): The object for which to generate the embedding.
|
||||
**kwargs (Any): Extra kwargs passed to the embedding API.
|
||||
|
||||
Returns:
|
||||
list[float]: A list of floating-point numbers representing the
|
||||
generated embedding.
|
||||
"""
|
||||
return self.embed_list([obj], **kwargs)[0]
|
||||
|
||||
@abstractmethod
|
||||
def get_output_dim(self) -> int:
|
||||
r"""Returns the output dimension of the embeddings.
|
||||
|
||||
Returns:
|
||||
int: The dimensionality of the embedding for the current model.
|
||||
"""
|
||||
pass
|
||||
@@ -1,89 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from camel.embeddings.base import BaseEmbedding
|
||||
from camel.types import EmbeddingModelType
|
||||
from camel.utils import api_keys_required
|
||||
|
||||
|
||||
class MistralEmbedding(BaseEmbedding[str]):
|
||||
r"""Provides text embedding functionalities using Mistral's models.
|
||||
|
||||
Args:
|
||||
model_type (EmbeddingModelType, optional): The model type to be
|
||||
used for text embeddings.
|
||||
(default: :obj:`MISTRAL_EMBED`)
|
||||
api_key (str, optional): The API key for authenticating with the
|
||||
Mistral service. (default: :obj:`None`)
|
||||
dimensions (int, optional): The text embedding output dimensions.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an unsupported model type is specified.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: EmbeddingModelType = (EmbeddingModelType.MISTRAL_EMBED),
|
||||
api_key: str | None = None,
|
||||
dimensions: int | None = None,
|
||||
) -> None:
|
||||
from mistralai import Mistral
|
||||
|
||||
if not model_type.is_mistral:
|
||||
raise ValueError("Invalid Mistral embedding model type.")
|
||||
self.model_type = model_type
|
||||
if dimensions is None:
|
||||
self.output_dim = model_type.output_dim
|
||||
else:
|
||||
assert isinstance(dimensions, int)
|
||||
self.output_dim = dimensions
|
||||
self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
|
||||
self._client = Mistral(api_key=self._api_key)
|
||||
|
||||
@api_keys_required("MISTRAL_API_KEY")
|
||||
def embed_list(
|
||||
self,
|
||||
objs: list[str],
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
r"""Generates embeddings for the given texts.
|
||||
|
||||
Args:
|
||||
objs (list[str]): The texts for which to generate the embeddings.
|
||||
**kwargs (Any): Extra kwargs passed to the embedding API.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: A list that represents the generated embedding
|
||||
as a list of floating-point numbers.
|
||||
"""
|
||||
# TODO: count tokens
|
||||
response = self._client.embeddings.create(
|
||||
inputs=objs,
|
||||
model=self.model_type.value,
|
||||
**kwargs,
|
||||
)
|
||||
return [data.embedding for data in response.data] # type: ignore[misc,union-attr]
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
r"""Returns the output dimension of the embeddings.
|
||||
|
||||
Returns:
|
||||
int: The dimensionality of the embedding for the current model.
|
||||
"""
|
||||
return self.output_dim
|
||||
@@ -1,91 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from camel.embeddings.base import BaseEmbedding
|
||||
from camel.utils import api_keys_required
|
||||
|
||||
|
||||
class OpenAICompatibleEmbedding(BaseEmbedding[str]):
|
||||
r"""Provides text embedding functionalities supporting OpenAI
|
||||
compatibility.
|
||||
|
||||
Args:
|
||||
model_type (str): The model type to be used for text embeddings.
|
||||
api_key (str): The API key for authenticating with the model service.
|
||||
url (str): The url to the model service.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
api_key: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model_type = model_type
|
||||
self.output_dim: Optional[int] = None
|
||||
|
||||
self._api_key = api_key or os.environ.get(
|
||||
"OPENAI_COMPATIBILIY_API_KEY"
|
||||
)
|
||||
self._url = url or os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL")
|
||||
self._client = OpenAI(
|
||||
timeout=60,
|
||||
max_retries=3,
|
||||
api_key=self._api_key,
|
||||
base_url=self._url,
|
||||
)
|
||||
|
||||
@api_keys_required("OPENAI_COMPATIBILIY_API_KEY")
|
||||
def embed_list(
|
||||
self,
|
||||
objs: list[str],
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
r"""Generates embeddings for the given texts.
|
||||
|
||||
Args:
|
||||
objs (list[str]): The texts for which to generate the embeddings.
|
||||
**kwargs (Any): Extra kwargs passed to the embedding API.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: A list that represents the generated embedding
|
||||
as a list of floating-point numbers.
|
||||
"""
|
||||
|
||||
response = self._client.embeddings.create(
|
||||
input=objs,
|
||||
model=self.model_type,
|
||||
**kwargs,
|
||||
)
|
||||
self.output_dim = len(response.data[0].embedding)
|
||||
return [data.embedding for data in response.data]
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
r"""Returns the output dimension of the embeddings.
|
||||
|
||||
Returns:
|
||||
int: The dimensionality of the embedding for the current model.
|
||||
"""
|
||||
if self.output_dim is None:
|
||||
raise ValueError(
|
||||
"Output dimension is not yet determined. Call "
|
||||
"'embed_list' first."
|
||||
)
|
||||
return self.output_dim
|
||||
@@ -1,99 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from camel.embeddings.base import BaseEmbedding
|
||||
from camel.types import NOT_GIVEN, EmbeddingModelType, NotGiven
|
||||
from camel.utils import api_keys_required
|
||||
|
||||
|
||||
class OpenAIEmbedding(BaseEmbedding[str]):
|
||||
r"""Provides text embedding functionalities using OpenAI's models.
|
||||
|
||||
Args:
|
||||
model_type (EmbeddingModelType, optional): The model type to be
|
||||
used for text embeddings.
|
||||
(default: :obj:`TEXT_EMBEDDING_3_SMALL`)
|
||||
api_key (str, optional): The API key for authenticating with the
|
||||
OpenAI service. (default: :obj:`None`)
|
||||
dimensions (int, optional): The text embedding output dimensions.
|
||||
(default: :obj:`NOT_GIVEN`)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an unsupported model type is specified.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: EmbeddingModelType = (
|
||||
EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
|
||||
),
|
||||
api_key: str | None = None,
|
||||
dimensions: int | NotGiven = NOT_GIVEN,
|
||||
) -> None:
|
||||
if not model_type.is_openai:
|
||||
raise ValueError("Invalid OpenAI embedding model type.")
|
||||
self.model_type = model_type
|
||||
if dimensions == NOT_GIVEN:
|
||||
self.output_dim = model_type.output_dim
|
||||
else:
|
||||
assert isinstance(dimensions, int)
|
||||
self.output_dim = dimensions
|
||||
self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||
self.client = OpenAI(timeout=60, max_retries=3, api_key=self._api_key)
|
||||
|
||||
@api_keys_required("OPENAI_API_KEY")
|
||||
def embed_list(
|
||||
self,
|
||||
objs: list[str],
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
r"""Generates embeddings for the given texts.
|
||||
|
||||
Args:
|
||||
objs (list[str]): The texts for which to generate the embeddings.
|
||||
**kwargs (Any): Extra kwargs passed to the embedding API.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: A list that represents the generated embedding
|
||||
as a list of floating-point numbers.
|
||||
"""
|
||||
# TODO: count tokens
|
||||
if self.model_type == EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
|
||||
response = self.client.embeddings.create(
|
||||
input=objs,
|
||||
model=self.model_type.value,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
response = self.client.embeddings.create(
|
||||
input=objs,
|
||||
model=self.model_type.value,
|
||||
dimensions=self.output_dim,
|
||||
**kwargs,
|
||||
)
|
||||
return [data.embedding for data in response.data]
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
r"""Returns the output dimension of the embeddings.
|
||||
|
||||
Returns:
|
||||
int: The dimensionality of the embedding for the current model.
|
||||
"""
|
||||
return self.output_dim
|
||||
@@ -1,80 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from numpy import ndarray
|
||||
|
||||
from camel.embeddings.base import BaseEmbedding
|
||||
|
||||
|
||||
class SentenceTransformerEncoder(BaseEmbedding[str]):
|
||||
r"""This class provides functionalities to generate text
|
||||
embeddings using `Sentence Transformers`.
|
||||
|
||||
References:
|
||||
https://www.sbert.net/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "intfloat/e5-large-v2",
|
||||
**kwargs,
|
||||
):
|
||||
r"""Initializes the: obj: `SentenceTransformerEmbedding` class
|
||||
with the specified transformer model.
|
||||
|
||||
Args:
|
||||
model_name (str, optional): The name of the model to use.
|
||||
(default: :obj:`intfloat/e5-large-v2`)
|
||||
**kwargs (optional): Additional arguments of
|
||||
:class:`SentenceTransformer`, such as :obj:`prompts` etc.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
self.model = SentenceTransformer(model_name, **kwargs)
|
||||
|
||||
def embed_list(
|
||||
self,
|
||||
objs: list[str],
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
r"""Generates embeddings for the given texts using the model.
|
||||
|
||||
Args:
|
||||
objs (list[str]): The texts for which to generate the
|
||||
embeddings.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: A list that represents the generated embedding
|
||||
as a list of floating-point numbers.
|
||||
"""
|
||||
if not objs:
|
||||
raise ValueError("Input text list is empty")
|
||||
embeddings = self.model.encode(
|
||||
objs, normalize_embeddings=True, **kwargs
|
||||
)
|
||||
assert isinstance(embeddings, ndarray)
|
||||
return embeddings.tolist()
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
r"""Returns the output dimension of the embeddings.
|
||||
|
||||
Returns:
|
||||
int: The dimensionality of the embeddings.
|
||||
"""
|
||||
output_dim = self.model.get_sentence_embedding_dimension()
|
||||
assert isinstance(output_dim, int)
|
||||
return output_dim
|
||||
@@ -1,149 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from camel.embeddings import BaseEmbedding
|
||||
from camel.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class VisionLanguageEmbedding(BaseEmbedding[Union[str, Image.Image]]):
|
||||
r"""Provides image embedding functionalities using multimodal model.
|
||||
|
||||
Args:
|
||||
model_name : The model type to be used for generating embeddings.
|
||||
And the default value is: obj:`openai/clip-vit-base-patch32`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an unsupported model type is specified.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model_name: str = "openai/clip-vit-base-patch32"
|
||||
) -> None:
|
||||
r"""Initializes the: obj: `VisionLanguageEmbedding` class with a
|
||||
specified model and return the dimension of embeddings.
|
||||
|
||||
Args:
|
||||
model_name (str, optional): The version name of the model to use.
|
||||
(default: :obj:`openai/clip-vit-base-patch32`)
|
||||
"""
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
try:
|
||||
self.model = AutoModel.from_pretrained(model_name)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load model '{model_name}': {e}")
|
||||
|
||||
self.valid_processor_kwargs = []
|
||||
self.valid_model_kwargs = []
|
||||
|
||||
try:
|
||||
self.valid_processor_kwargs = (
|
||||
self.processor.image_processor._valid_processor_keys
|
||||
)
|
||||
self.valid_model_kwargs = [
|
||||
"pixel_values",
|
||||
"return_dict",
|
||||
"interpolate_pos_encoding",
|
||||
]
|
||||
except Exception:
|
||||
logger.warning("not typically processor and model structure")
|
||||
pass
|
||||
self.dim: Optional[int] = None
|
||||
|
||||
def embed_list(
|
||||
self, objs: List[Union[Image.Image, str]], **kwargs: Any
|
||||
) -> List[List[float]]:
|
||||
"""Generates embeddings for the given images or texts.
|
||||
|
||||
Args:
|
||||
objs (List[Image.Image|str]): The list of images or texts for
|
||||
which to generate the embeddings.
|
||||
image_processor_kwargs: Extra kwargs passed to the image processor.
|
||||
tokenizer_kwargs: Extra kwargs passed to the text tokenizer
|
||||
(processor).
|
||||
model_kwargs: Extra kwargs passed to the main model.
|
||||
|
||||
Returns:
|
||||
List[List[float]]: A list that represents the generated embedding
|
||||
as a list of floating-point numbers.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input type is not `Image.Image` or `str`.
|
||||
"""
|
||||
if not objs:
|
||||
raise ValueError("Input objs list is empty.")
|
||||
|
||||
image_processor_kwargs: Optional[dict] = kwargs.get(
|
||||
'image_processor_kwargs', {}
|
||||
)
|
||||
tokenizer_kwargs: Optional[dict] = kwargs.get('tokenizer_kwargs', {})
|
||||
model_kwargs: Optional[dict] = kwargs.get('model_kwargs', {})
|
||||
|
||||
result_list = []
|
||||
for obj in objs:
|
||||
if isinstance(obj, Image.Image):
|
||||
image_input = self.processor(
|
||||
images=obj,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
**image_processor_kwargs,
|
||||
)
|
||||
image_feature = (
|
||||
self.model.get_image_features(
|
||||
**image_input, **model_kwargs
|
||||
)
|
||||
.squeeze(dim=0)
|
||||
.tolist()
|
||||
)
|
||||
result_list.append(image_feature)
|
||||
elif isinstance(obj, str):
|
||||
text_input = self.processor(
|
||||
text=obj,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
text_feature = (
|
||||
self.model.get_text_features(**text_input, **model_kwargs)
|
||||
.squeeze(dim=0)
|
||||
.tolist()
|
||||
)
|
||||
result_list.append(text_feature)
|
||||
else:
|
||||
raise ValueError("Input type is not image nor text.")
|
||||
|
||||
self.dim = len(result_list[0])
|
||||
|
||||
if any(len(result) != self.dim for result in result_list):
|
||||
raise ValueError("Dimensionality is not consistent.")
|
||||
|
||||
return result_list
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
r"""Returns the output dimension of the embeddings.
|
||||
|
||||
Returns:
|
||||
int: The dimensionality of the embedding for the current model.
|
||||
"""
|
||||
if self.dim is None:
|
||||
text = 'dimension'
|
||||
inputs = self.processor(text=[text], return_tensors="pt")
|
||||
self.dim = self.model.get_text_features(**inputs).shape[1]
|
||||
return self.dim
|
||||
@@ -1,375 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Dict, Generator, List, Optional, Set, Tuple
|
||||
|
||||
from camel.messages import BaseMessage
|
||||
from camel.prompts import PromptTemplateGenerator, TextPrompt
|
||||
from camel.types import RoleType, TaskType
|
||||
|
||||
|
||||
class SystemMessageGenerator:
|
||||
r"""System message generator for agents.
|
||||
|
||||
Args:
|
||||
task_type (TaskType, optional): The task type.
|
||||
(default: :obj:`TaskType.AI_SOCIETY`)
|
||||
sys_prompts (Optional[Dict[RoleType, str]], optional): The prompts of
|
||||
the system messages for each role type. (default: :obj:`None`)
|
||||
sys_msg_meta_dict_keys (Optional[Set[str]], optional): The set of keys
|
||||
of the meta dictionary used to fill the prompts.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_type: TaskType = TaskType.AI_SOCIETY,
|
||||
sys_prompts: Optional[Dict[RoleType, str]] = None,
|
||||
sys_msg_meta_dict_keys: Optional[Set[str]] = None,
|
||||
) -> None:
|
||||
self.sys_prompts: Dict[RoleType, str]
|
||||
|
||||
if sys_prompts is not None:
|
||||
self.sys_prompts = sys_prompts
|
||||
self.sys_msg_meta_dict_keys = sys_msg_meta_dict_keys or set()
|
||||
else:
|
||||
assistant_prompt_template = (
|
||||
PromptTemplateGenerator().get_system_prompt(
|
||||
task_type,
|
||||
RoleType.ASSISTANT,
|
||||
)
|
||||
)
|
||||
user_prompt_template = PromptTemplateGenerator().get_system_prompt(
|
||||
task_type,
|
||||
RoleType.USER,
|
||||
)
|
||||
critic_prompt_template = (
|
||||
PromptTemplateGenerator().get_system_prompt(
|
||||
task_type,
|
||||
RoleType.CRITIC,
|
||||
)
|
||||
)
|
||||
embodiment_prompt_template = (
|
||||
PromptTemplateGenerator().get_system_prompt(
|
||||
task_type,
|
||||
RoleType.EMBODIMENT,
|
||||
)
|
||||
)
|
||||
|
||||
self.sys_prompts = dict()
|
||||
self.sys_prompts[RoleType.ASSISTANT] = assistant_prompt_template
|
||||
self.sys_prompts[RoleType.USER] = user_prompt_template
|
||||
self.sys_prompts[RoleType.CRITIC] = critic_prompt_template
|
||||
self.sys_prompts[RoleType.EMBODIMENT] = embodiment_prompt_template
|
||||
|
||||
self.sys_msg_meta_dict_keys = (
|
||||
assistant_prompt_template.key_words
|
||||
| user_prompt_template.key_words
|
||||
| critic_prompt_template.key_words
|
||||
| embodiment_prompt_template.key_words
|
||||
)
|
||||
|
||||
if RoleType.DEFAULT not in self.sys_prompts:
|
||||
self.sys_prompts[RoleType.DEFAULT] = "You are a helpful assistant."
|
||||
|
||||
def validate_meta_dict_keys(self, meta_dict: Dict[str, str]) -> None:
|
||||
r"""Validates the keys of the meta_dict.
|
||||
|
||||
Args:
|
||||
meta_dict (Dict[str, str]): The dictionary to validate.
|
||||
"""
|
||||
if not set(meta_dict.keys()).issubset(self.sys_msg_meta_dict_keys):
|
||||
raise ValueError(
|
||||
"The keys of the meta_dict should be in "
|
||||
f"{self.sys_msg_meta_dict_keys}. "
|
||||
f"Got {set(meta_dict.keys())} instead."
|
||||
)
|
||||
|
||||
def from_dict(
    self,
    meta_dict: Dict[str, str],
    role_tuple: Tuple[str, RoleType] = ("", RoleType.DEFAULT),
) -> BaseMessage:
    r"""Builds a system message by filling a role's prompt template.

    Args:
        meta_dict (Dict[str, str]): Values substituted into the template;
            keys must be a subset of ``sys_msg_meta_dict_keys``.
        role_tuple (Tuple[str, RoleType], optional): The role name and role
            type of the message. (default: ("", RoleType.DEFAULT))

    Returns:
        BaseMessage: The generated system message.
    """
    self.validate_meta_dict_keys(meta_dict)
    role_name, role_type = role_tuple
    template = self.sys_prompts[role_type]
    filled_prompt = template.format(**meta_dict)
    return BaseMessage(
        role_name=role_name,
        role_type=role_type,
        meta_dict=meta_dict,
        content=filled_prompt,
    )
|
||||
|
||||
def from_dicts(
    self,
    meta_dicts: List[Dict[str, str]],
    role_tuples: List[Tuple[str, RoleType]],
) -> List[BaseMessage]:
    r"""Builds one system message per (meta_dict, role_tuple) pair.

    Args:
        meta_dicts (List[Dict[str, str]]): Template values for each message.
        role_tuples (List[Tuple[str, RoleType]]): Role name and role type
            for each message; must have the same length as ``meta_dicts``.

    Returns:
        List[BaseMessage]: The generated system messages, in input order.

    Raises:
        ValueError: If the number of meta_dicts and role_tuples are
            different.
    """
    if len(meta_dicts) != len(role_tuples):
        raise ValueError(
            "The number of meta_dicts and role_types should be the same."
        )

    messages: List[BaseMessage] = []
    for meta_dict, role_tuple in zip(meta_dicts, role_tuples):
        messages.append(self.from_dict(meta_dict, role_tuple))
    return messages
|
||||
|
||||
|
||||
class RoleNameGenerator:
    r"""Yields (assistant, user) role-name pairs for role-playing workers.

    Args:
        assistant_role_names_path (str, optional): The path to the file
            containing the assistant role names.
            (default: :obj:`"data/ai_society/assistant_roles.txt"`)
        user_role_names_path (str, optional): The path to the file
            containing the user role names.
            (default: :obj:`"data/ai_society/user_roles.txt"`)
        assistant_role_names (Optional[List[str]], optional): Explicit list
            of assistant role names; when given, the file is not read.
            (default: :obj:`None`)
        user_role_names (Optional[List[str]], optional): Explicit list of
            user role names; when given, the file is not read.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
        assistant_role_names: Optional[List[str]] = None,
        user_role_names: Optional[List[str]] = None,
    ) -> None:
        if assistant_role_names is not None:
            self.assistant_role_names = assistant_role_names
        else:
            self.assistant_role_names = self._load_names(
                assistant_role_names_path
            )

        if user_role_names is not None:
            self.user_role_names = user_role_names
        else:
            self.user_role_names = self._load_names(user_role_names_path)

    @staticmethod
    def _load_names(path: str) -> List[str]:
        # Each line looks like "<index>. <name>"; drop the first
        # space-delimited token (the index) and keep the rest.
        with open(path, "r") as f:
            raw_lines: List[str] = f.read().splitlines()
        return [" ".join(line.split(" ")[1:]) for line in raw_lines]

    def from_role_files(self) -> Generator[Tuple, None, None]:
        r"""Yield every (assistant role, user role) combination.

        Returns:
            Generator[Tuple, None, None]: A generator over all pairs, with
                the assistant role varying slowest (assistant-major order).
        """
        for assistant_name in self.assistant_role_names:
            for user_name in self.user_role_names:
                yield (assistant_name, user_name)
|
||||
|
||||
|
||||
class AISocietyTaskPromptGenerator:
    r"""Task prompt generator for AI society tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`10`)
    """

    def __init__(
        self,
        num_tasks: int = 10,
    ) -> None:
        # Template expects `assistant_role`, `user_role` and `num_tasks` keys.
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(
                TaskType.AI_SOCIETY
            )
        )
        self.num_tasks = num_tasks

    def _fill_template(self, assistant_role: str, user_role: str) -> str:
        # Build one formatted task-generation prompt for a role pairing.
        return self.generate_tasks_prompt.format(
            assistant_role=assistant_role,
            user_role=user_role,
            num_tasks=self.num_tasks,
        )

    # TODO: Return role names for user and assistant with the generator.
    def from_role_files(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks for every role pair read from the role files.

        Args:
            assistant_role_names_path (str, optional): The path to the file
                containing the assistant role names.
                (default: :obj:`"data/ai_society/assistant_roles.txt"`)
            user_role_names_path (str, optional): The path to the file
                containing the user role names.
                (default: :obj:`"data/ai_society/user_roles.txt"`)

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        role_pairs = RoleNameGenerator(
            assistant_role_names_path, user_role_names_path
        ).from_role_files()
        for assistant_role, user_role in role_pairs:
            yield (
                self._fill_template(assistant_role, user_role),
                (assistant_role, user_role),
            )

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks for every role pair yielded by a generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        for assistant_role, user_role in role_generator:
            yield (
                self._fill_template(assistant_role, user_role),
                (assistant_role, user_role),
            )
|
||||
|
||||
|
||||
class SingleTxtGenerator:
    r"""Single text generator for role-playing workers.

    Args:
        text_file_path (str): The path to the file containing the text data.
    """

    def __init__(
        self,
        text_file_path: str,
    ) -> None:
        with open(text_file_path, "r") as f:
            raw_lines: List[str] = f.read().splitlines()
        # Each line looks like "<index>. <text>"; drop the first
        # space-delimited token (the index) and keep the rest.
        self.data_list = [
            " ".join(line.split(" ")[1:]) for line in raw_lines
        ]

    def from_role_files(self) -> Generator[str, None, None]:
        r"""Yield each text entry in file order.

        Returns:
            Generator[str, None, None]: A generator that yields the text data.
        """
        yield from self.data_list
|
||||
|
||||
|
||||
class CodeTaskPromptGenerator:
    r"""Code task prompt generator for code tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`50`)
    """

    def __init__(
        self,
        num_tasks: int = 50,
    ) -> None:
        # Template expects `language`, `domain` and `num_tasks` keys.
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(TaskType.CODE)
        )
        self.num_tasks = num_tasks

    def from_role_files(
        self,
        languages_path: str = "data/code/languages.txt",
        domains_path: str = "data/code/domains.txt",
    ) -> Generator[Tuple[TextPrompt, str, str], None, None]:
        r"""Generate one task prompt per (language, domain) combination.

        Args:
            languages_path (str, optional): The path to the file containing
                the language names.
                (default: :obj:`"data/code/languages.txt"`)
            domains_path (str, optional): The path to the file containing
                the domain names. (default: :obj:`"data/code/domains.txt"`)

        Returns:
            Generator[Tuple[TextPrompt, str, str], None, None]: A generator
                that yields tuples of task prompts, language names, and
                domain names.
        """
        for language in SingleTxtGenerator(languages_path).from_role_files():
            # Re-create the domain generator per language so the inner
            # iteration restarts from the beginning each time.
            for domain in SingleTxtGenerator(domains_path).from_role_files():
                task_prompt = self.generate_tasks_prompt.format(
                    language=language, domain=domain, num_tasks=self.num_tasks
                )
                yield task_prompt, language, domain

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[str, None, None]:
        r"""Generate tasks from a role generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[str, None, None]: A generator that yields the task
                prompts.

        Raises:
            NotImplementedError: Always; this entry point is not supported.
        """
        raise NotImplementedError
|
||||
@@ -1,138 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.messages import BaseMessage
|
||||
from camel.responses import ChatAgentResponse
|
||||
from camel.utils import print_text_animated
|
||||
|
||||
|
||||
class Human:
    r"""A console-interactive human participant that picks among proposals.

    Args:
        name (str): The name of the human user.
            (default: :obj:`"Kill Switch Engineer"`).
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)

    Attributes:
        name (str): The name of the human user.
        logger_color (Any): The color of the menu options displayed to the
            user.
        input_button (str): The text displayed for the input button.
        kill_button (str): The text displayed for the kill button.
        options_dict (Dict[str, str]): Maps menu numbers (as strings) to
            the option text shown for them.
    """

    def __init__(
        self,
        name: str = "Kill Switch Engineer",
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.name = name
        self.logger_color = logger_color
        self.input_button = f"Input by {self.name}."
        self.kill_button = "Stop!!!"
        self.options_dict: Dict[str, str] = dict()

    def display_options(self, messages: Sequence[BaseMessage]) -> None:
        r"""Prints the proposal menu and records it in ``options_dict``.

        Args:
            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.

        Returns:
            None
        """
        options = [message.content for message in messages]
        # Two extra entries: free-form input and the kill switch.
        options += [self.input_button, self.kill_button]
        print_text_animated(
            self.logger_color + "\n> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for position, option_text in enumerate(options, start=1):
            print_text_animated(
                self.logger_color
                + f"\x1b[3mOption {position}:\n{option_text}\x1b[0m\n"
            )
            self.options_dict[str(position)] = option_text

    def get_input(self) -> str:
        r"""Prompts until the user enters a valid menu number.

        Returns:
            str: The user's input.
        """
        while True:
            human_input = input(
                self.logger_color
                + f"Please enter your choice ([1-{len(self.options_dict)}]): "
            )
            print("\n")
            if human_input in self.options_dict:
                return human_input
            print_text_animated(
                self.logger_color + "\n> Invalid choice. Please try again.\n"
            )

    def parse_input(self, human_input: str) -> str:
        r"""Resolves a validated menu choice into message content.

        Args:
            human_input (str): The user's input.

        Returns:
            content: A `str` object representing the user's input.
        """
        chosen = self.options_dict[human_input]
        if chosen == self.input_button:
            return input(self.logger_color + "Please enter your message: ")
        if chosen == self.kill_button:
            # Terminates the whole process, not just this conversation.
            exit(self.logger_color + f"Killed by {self.name}.")
        return chosen

    def reduce_step(
        self, messages: Sequence[BaseMessage]
    ) -> ChatAgentResponse:
        r"""Runs one menu round: display, read, parse, wrap as a response.

        Args:
            messages (Sequence[BaseMessage]): A list of BaseMessage objects.

        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object representing the
                user's choice.
        """
        template_message = BaseMessage(
            role_name=messages[0].role_name,
            role_type=messages[0].role_type,
            meta_dict=messages[0].meta_dict,
            content="",
        )
        self.display_options(messages)
        chosen_content = self.parse_input(self.get_input())
        reply = template_message.create_new_instance(chosen_content)
        return ChatAgentResponse(msgs=[reply], terminated=False, info={})
|
||||
@@ -1,29 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .base import BaseInterpreter
|
||||
from .docker_interpreter import DockerInterpreter
|
||||
from .internal_python_interpreter import InternalPythonInterpreter
|
||||
from .interpreter_error import InterpreterError
|
||||
from .ipython_interpreter import JupyterKernelInterpreter
|
||||
from .subprocess_interpreter import SubprocessInterpreter
|
||||
|
||||
# Public API of the interpreters package; kept in sync with the imports above.
__all__ = [
    'BaseInterpreter',
    'InterpreterError',
    'InternalPythonInterpreter',
    'SubprocessInterpreter',
    'DockerInterpreter',
    'JupyterKernelInterpreter',
]
|
||||
@@ -1,49 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class BaseInterpreter(ABC):
    r"""Abstract interface that every code interpreter implements."""

    @abstractmethod
    def run(self, code: str, code_type: str) -> str:
        r"""Execute ``code`` according to its ``code_type``.

        Args:
            code (str): The code to be executed.
            code_type (str): The type of the code, which must be one of the
                types returned by `supported_code_types()`.

        Returns:
            str: The result of the code execution. If the execution fails,
                this should include sufficient information to diagnose and
                correct the issue.

        Raises:
            InterpreterError: If the code execution encounters errors that
                could be resolved by modifying or regenerating the code.
        """
        ...

    @abstractmethod
    def supported_code_types(self) -> List[str]:
        r"""Return the code types this interpreter can execute."""
        ...

    @abstractmethod
    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Update the action space of a *python* interpreter."""
        ...
|
||||
@@ -1,245 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import io
|
||||
import shlex
|
||||
import tarfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.interpreters.base import BaseInterpreter
|
||||
from camel.interpreters.interpreter_error import InterpreterError
|
||||
from camel.logger import get_logger
|
||||
from camel.utils import is_docker_running
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from docker.models.containers import Container
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DockerInterpreter(BaseInterpreter):
    r"""A class for executing code files or code strings in a docker container.

    This class handles the execution of code in different scripting languages
    (currently Python and Bash) within a docker container, capturing their
    stdout and stderr streams, and allowing user checking before executing code
    strings.

    Args:
        require_confirm (bool, optional): If `True`, prompt user before
            running code strings for security. Defaults to `True`.
        print_stdout (bool, optional): If `True`, print the standard
            output of the executed code. Defaults to `False`.
        print_stderr (bool, optional): If `True`, print the standard error
            of the executed code. Defaults to `True`.
    """

    # Shell command used to execute a file of each canonical code type.
    _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python {file_name}",
        "bash": "bash {file_name}",
    }

    # File extension associated with each canonical code type.
    _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "py",
        "bash": "sh",
    }

    # Aliases accepted from callers, normalized to a canonical code type.
    _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python",
        "py3": "python",
        "python3": "python",
        "py": "python",
        "shell": "bash",
        "bash": "bash",
        "sh": "bash",
    }

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr

        # Lazy initialization: the container is only started on first run.
        self._container: Optional[Container] = None

    def __del__(self) -> None:
        r"""Destructor for the DockerInterpreter class.

        This method ensures that the Docker container is removed when the
        interpreter is deleted.
        """
        if self._container is not None:
            self._container.remove(force=True)

    def _initialize_if_needed(self) -> None:
        # Start the backing container on first use; no-op afterwards.
        if self._container is not None:
            return

        if not is_docker_running():
            raise InterpreterError(
                "Docker daemon is not running. Please install/start docker "
                "and try again."
            )

        import docker

        client = docker.from_env()
        # `tail -f /dev/null` keeps the container alive so we can exec into it.
        self._container = client.containers.run(
            "python:3.10",
            detach=True,
            name=f"camel-interpreter-{uuid.uuid4()}",
            command="tail -f /dev/null",
        )

    def _create_file_in_container(self, content: str) -> Path:
        # Copies `content` into a uniquely named file under /tmp inside the
        # container and returns that in-container path.
        # get a random name for the file
        filename = str(uuid.uuid4())
        # Encode once: the tar member size must be the byte length of the
        # payload, not the character count (they differ for non-ASCII code).
        payload = content.encode('utf-8')
        # create a tar in memory
        tar_stream = io.BytesIO()
        with tarfile.open(fileobj=tar_stream, mode='w') as tar:
            tarinfo = tarfile.TarInfo(name=filename)
            tarinfo.size = len(payload)
            tar.addfile(tarinfo, io.BytesIO(payload))
        tar_stream.seek(0)

        # copy the tar into the container
        if self._container is None:
            raise InterpreterError(
                "Container is not initialized. Try running the code again."
            )
        self._container.put_archive("/tmp", tar_stream)
        # Fix: return the path of the file actually written (was a literal
        # placeholder string that never matched the uploaded file).
        return Path(f"/tmp/{filename}")

    def _run_file_in_container(
        self,
        file: Path,
        code_type: str,
    ) -> str:
        # Executes the file with the command registered for its code type and
        # returns combined stdout plus a "(stderr: ...)" suffix when present.
        code_type = self._check_code_type(code_type)
        commands = shlex.split(
            self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
                file_name=file.as_posix()
            )
        )
        if self._container is None:
            raise InterpreterError(
                "Container is not initialized. Try running the code again."
            )
        stdout, stderr = self._container.exec_run(
            commands,
            demux=True,
        ).output

        if self.print_stdout and stdout:
            print("======stdout======")
            print(Fore.GREEN + stdout.decode() + Fore.RESET)
            print("==================")
        if self.print_stderr and stderr:
            print("======stderr======")
            print(Fore.RED + stderr.decode() + Fore.RESET)
            print("==================")
        exec_result = f"{stdout.decode()}" if stdout else ""
        exec_result += f"(stderr: {stderr.decode()})" if stderr else ""
        return exec_result

    def run(
        self,
        code: str,
        code_type: str,
    ) -> str:
        r"""Executes the given code in the container attached to the
        interpreter, and captures the stdout and stderr streams.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            InterpreterError: If the user declines to run the code, or the
                code type is unsupported, or there is an error in the docker
                API/container
        """
        import docker.errors

        code_type = self._check_code_type(code_type)

        # Print code for security checking
        if self.require_confirm:
            # Fix: second fragment is now an f-string so the code is actually
            # interpolated (previously logged the literal "{code}").
            logger.info(
                f"The following {code_type} code will run on your "
                f"computer: {code}"
            )
            while True:
                choice = input("Running code? [Y/n]:").lower()
                if choice in ["y", "yes", "ye", ""]:
                    break
                elif choice not in ["no", "n"]:
                    continue
                raise InterpreterError(
                    "Execution halted: User opted not to run the code. "
                    "This choice stops the current operation and any "
                    "further code execution."
                )

        self._initialize_if_needed()

        try:
            temp_file_path = self._create_file_in_container(code)
            result = self._run_file_in_container(temp_file_path, code_type)
        except docker.errors.APIError as e:
            raise InterpreterError(
                f"Execution halted due to docker API error: {e.explanation}. "
                "This choice stops the current operation and any "
                "further code execution."
            ) from e
        except docker.errors.DockerException as e:
            raise InterpreterError(
                f"Execution halted due to docker exception: {e}. "
                "This choice stops the current operation and any "
                "further code execution."
            ) from e
        return result

    def _check_code_type(self, code_type: str) -> str:
        # Normalize aliases (e.g. "py3" -> "python"); reject unknown types.
        if code_type not in self._CODE_TYPE_MAPPING:
            raise InterpreterError(
                f"Unsupported code type {code_type}. Currently "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_EXTENSION_MAPPING.keys())}."
            )
        return self._CODE_TYPE_MAPPING[code_type]

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return list(self._CODE_EXTENSION_MAPPING.keys())

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter"""
        # Fix: error message previously named the wrong class
        # ("SubprocessInterpreter").
        raise RuntimeError(
            "DockerInterpreter doesn't support `action_space`."
        )
|
||||
@@ -1,516 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import ast
|
||||
import difflib
|
||||
import importlib
|
||||
import typing
|
||||
from typing import Any, ClassVar, Dict, List, Optional
|
||||
|
||||
from camel.interpreters.base import BaseInterpreter
|
||||
from camel.interpreters.interpreter_error import InterpreterError
|
||||
|
||||
|
||||
class InternalPythonInterpreter(BaseInterpreter):
|
||||
r"""A customized python interpreter to control the execution of
|
||||
LLM-generated codes. The interpreter makes sure the code can only execute
|
||||
functions given in action space and import white list. It also supports
|
||||
fuzzy variable matching to retrieve uncertain input variable name.
|
||||
|
||||
.. highlight:: none
|
||||
|
||||
This class is adapted from the hugging face implementation
|
||||
`python_interpreter.py <https://github.com/huggingface/transformers/blob/8f
|
||||
093fb799246f7dd9104ff44728da0c53a9f67a/src/transformers/tools/python_interp
|
||||
reter.py>`_. The original license applies::
|
||||
|
||||
Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied. See the License for the specific language governing
|
||||
permissions and limitations under the License.
|
||||
|
||||
We have modified the original code to suit our requirements. We have
|
||||
encapsulated the original functions within a class and saved the
|
||||
interpreter state after execution. We have added support for "import"
|
||||
statements, "for" statements, and several binary and unary operators. We
|
||||
have added import white list to keep `import` statement safe. Additionally,
|
||||
we have modified the variable matching logic and introduced the
|
||||
:obj:`fuzz_state` for fuzzy matching.
|
||||
|
||||
Modifications copyright (C) 2023 CAMEL-AI.org
|
||||
|
||||
Args:
|
||||
action_space (Dict[str, Any], optional): A dictionary that maps action
|
||||
names to their corresponding functions or objects. The interpreter
|
||||
can only execute functions that are either directly listed in this
|
||||
dictionary or are member functions of objects listed in this
|
||||
dictionary. The concept of :obj:`action_space` is derived from
|
||||
EmbodiedAgent, representing the actions that an agent is capable of
|
||||
performing. If `None`, set to empty dict. (default: :obj:`None`)
|
||||
import_white_list (List[str], optional): A list that stores
|
||||
the Python modules or functions that can be imported in the code.
|
||||
All submodules and functions of the modules listed in this list are
|
||||
importable. Any other import statements will be rejected. The
|
||||
module and its submodule or function name are separated by a period
|
||||
(:obj:`.`). (default: :obj:`None`)
|
||||
unsafe_mode (bool, optional): If `True`, the interpreter runs the code
|
||||
by `eval()` without any security check. (default: :obj:`False`)
|
||||
raise_error (bool, optional): Raise error if the interpreter fails.
|
||||
(default: :obj:`False`)
|
||||
"""
|
||||
|
||||
_CODE_TYPES: ClassVar[List[str]] = ["python", "py", "python3", "python2"]
|
||||
|
||||
def __init__(
    self,
    action_space: Optional[Dict[str, Any]] = None,
    import_white_list: Optional[List[str]] = None,
    unsafe_mode: bool = False,
    raise_error: bool = False,
) -> None:
    # Empty-by-default containers; `or` also swaps in a fresh container
    # when a falsy (empty) one is passed, exactly as before.
    self.action_space = action_space or {}
    # Interpreter state starts as a shallow copy of the action space.
    self.state = self.action_space.copy()
    self.fuzz_state: Dict[str, Any] = {}
    self.import_white_list = import_white_list or []
    self.unsafe_mode = unsafe_mode
    self.raise_error = raise_error
|
||||
|
||||
def run(self, code: str, code_type: str) -> str:
    r"""Executes the given code with specified code type in the
    interpreter.

    Rejects unsupported code types, then evaluates the code either in the
    controlled environment (`execute`) or — when `unsafe_mode` is on —
    directly via `eval()` with the action space as globals.

    Args:
        code (str): The python code to be executed.
        code_type (str): The type of the code, which should be one of the
            supported code types (`python`, `py`, `python3`, `python2`).

    Returns:
        str: The string representation of the output of the executed code.

    Raises:
        InterpreterError: If the `code_type` is not supported or if any
            runtime error occurs during the execution of the code.
    """
    if code_type not in self._CODE_TYPES:
        raise InterpreterError(
            f"Unsupported code type {code_type}. "
            f"`{self.__class__.__name__}` only supports "
            f"{', '.join(self._CODE_TYPES)}."
        )
    if self.unsafe_mode:
        # SECURITY: raw eval on generated code — no sandboxing at all.
        result = eval(code, self.action_space)
    else:
        result = self.execute(code)
    return str(result)
|
||||
|
||||
def update_action_space(self, action_space: Dict[str, Any]) -> None:
    r"""Updates action space for *python* interpreter."""
    # Note: only ``action_space`` is updated here; ``state`` (copied in
    # __init__) is not refreshed until ``clear_state`` runs.
    self.action_space.update(action_space)
|
||||
|
||||
def supported_code_types(self) -> List[str]:
    r"""Provides supported code types by the interpreter."""
    # Returns the shared class-level list itself (not a copy).
    return self._CODE_TYPES
|
||||
|
||||
def execute(
    self,
    code: str,
    state: Optional[Dict[str, Any]] = None,
    fuzz_state: Optional[Dict[str, Any]] = None,
    keep_state: bool = True,
) -> Any:
    r"""Execute the input python codes in a security environment.

    Args:
        code (str): Generated python code to be executed.
        state (Optional[Dict[str, Any]], optional): External variables that
            may be used in the generated code. (default: :obj:`None`)
        fuzz_state (Optional[Dict[str, Any]], optional): External variables
            that do not have certain variable names. The interpreter will
            use fuzzy matching to access these variables. For example, if
            :obj:`fuzz_state` has a variable :obj:`image`, the generated
            code can use :obj:`input_image` to access it. (default:
            :obj:`None`)
        keep_state (bool, optional): If :obj:`True`, :obj:`state` and
            :obj:`fuzz_state` will be kept for later execution. Otherwise,
            they will be cleared. (default: :obj:`True`)

    Returns:
        Any: The value of the last statement (excluding "import") in the
            code. For this interpreter, the value of an expression is its
            value, the value of an "assign" statement is the assigned
            value, and the value of an "if" and "for" block statement is
            the value of the last statement in the block.
    """
    # Merge caller-provided variables into the interpreter state before
    # parsing, so the code can reference them by (fuzzy) name.
    if state is not None:
        self.state.update(state)
    if fuzz_state is not None:
        self.fuzz_state.update(fuzz_state)

    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        # On parse failure either raise or return the traceback text,
        # depending on the configured error policy.
        if self.raise_error:
            raise InterpreterError(f"Syntax error in code: {e}")
        else:
            import traceback

            return traceback.format_exc()

    # Evaluate top-level nodes in order; the result is the value of the
    # last node whose evaluation produced a non-None value.
    result = None
    for idx, node in enumerate(expression.body):
        try:
            line_result = self._execute_ast(node)
        except InterpreterError as e:
            # Clear transient state before reporting if it must not leak
            # into later runs.
            if not keep_state:
                self.clear_state()
            msg = (
                f"Evaluation of the code stopped at node {idx}. "
                f"See:\n{e}"
            )
            # More information can be provided by `ast.unparse()`,
            # which is new in python 3.9.
            if self.raise_error:
                raise InterpreterError(msg)
            else:
                import traceback

                return traceback.format_exc()
        if line_result is not None:
            result = line_result

    if not keep_state:
        self.clear_state()

    return result
|
||||
|
||||
def clear_state(self) -> None:
    r"""Reset :obj:`state` to a fresh copy of the action space and drop
    all fuzzy-matched variables."""
    self.state = dict(self.action_space)
    self.fuzz_state = dict()
|
||||
|
||||
# ast.Index is deprecated after python 3.9, which cannot pass type check,
# but is still necessary for older versions.
@typing.no_type_check
def _execute_ast(self, expression: ast.AST) -> Any:
    r"""Recursively evaluate a single AST node and return its value.

    This is the central dispatch for the interpreter: each supported
    node type has a branch below; anything else raises
    InterpreterError.
    """
    if isinstance(expression, ast.Assign):
        # Assignment -> evaluate the assignment which should
        # update the state. We return the variable assigned as it may
        # be used to determine the final result.
        return self._execute_assign(expression)
    elif isinstance(expression, ast.Attribute):
        # Attribute access -> evaluate the object, then getattr.
        value = self._execute_ast(expression.value)
        return getattr(value, expression.attr)
    elif isinstance(expression, ast.BinOp):
        # Binary Operator -> return the result value
        return self._execute_binop(expression)
    elif isinstance(expression, ast.Call):
        # Function call -> return the value of the function call
        return self._execute_call(expression)
    elif isinstance(expression, ast.Compare):
        # Compare -> return True or False
        return self._execute_condition(expression)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Dict):
        # Dict -> evaluate all keys and values
        result: Dict = {}
        for k, v in zip(expression.keys, expression.values):
            if k is not None:
                result[self._execute_ast(k)] = self._execute_ast(v)
            else:
                # A None key is a **-unpacking entry (`{**other}`).
                result.update(self._execute_ast(v))
        return result
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return self._execute_ast(expression.value)
    elif isinstance(expression, ast.For):
        return self._execute_for(expression)
    elif isinstance(expression, ast.FormattedValue):
        # Formatted value (part of f-string) -> evaluate the content
        # and return
        return self._execute_ast(expression.value)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        return self._execute_if(expression)
    elif isinstance(expression, ast.Import):
        # Import -> add imported names in self.state and return None.
        self._execute_import(expression)
        return None
    elif isinstance(expression, ast.ImportFrom):
        self._execute_import_from(expression)
        return None
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        # cannot pass type check
        return self._execute_ast(expression.value)
    elif isinstance(expression, ast.JoinedStr):
        # f-string -> stringify and concatenate every part.
        return "".join(
            [str(self._execute_ast(v)) for v in expression.values]
        )
    elif isinstance(expression, ast.List):
        # List -> evaluate all elements
        return [self._execute_ast(elt) for elt in expression.elts]
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return self._execute_name(expression)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return self._execute_subscript(expression)
    elif isinstance(expression, ast.Tuple):
        return tuple([self._execute_ast(elt) for elt in expression.elts])
    elif isinstance(expression, ast.UnaryOp):
        # Binary Operator -> return the result value
        return self._execute_unaryop(expression)
    else:
        # For now we refuse anything else. Let's add things as we need
        # them.
        raise InterpreterError(
            f"{expression.__class__.__name__} is not supported."
        )
|
||||
|
||||
def _execute_assign(self, assign: ast.Assign) -> Any:
    r"""Evaluate the right-hand side once and bind it to every
    assignment target, returning the assigned value."""
    value = self._execute_ast(assign.value)
    for tgt in assign.targets:
        self._assign(tgt, value)
    return value
|
||||
|
||||
def _assign(self, target: ast.expr, value: Any):
    r"""Bind ``value`` to an assignment ``target`` in ``self.state``.

    Supports simple names and flat tuple unpacking.

    Args:
        target (ast.expr): The assignment target node.
        value (Any): The already-evaluated right-hand-side value.

    Raises:
        InterpreterError: If the target node type is unsupported, or a
            tuple target does not match the value's type or length.
    """
    if isinstance(target, ast.Name):
        self.state[target.id] = value
    elif isinstance(target, ast.Tuple):
        if not isinstance(value, tuple):
            # Bug fix: the two f-string fragments concatenated without a
            # space, producing e.g. "but gotlist instead.".
            raise InterpreterError(
                f"Expected type tuple, but got "
                f"{value.__class__.__name__} instead."
            )
        if len(target.elts) != len(value):
            raise InterpreterError(
                f"Expected {len(target.elts)} values but got"
                f" {len(value)}."
            )
        for t, v in zip(target.elts, value):
            # An ast.Name in Store context evaluates to its identifier.
            self.state[self._execute_ast(t)] = v
    else:
        raise InterpreterError(
            f"Unsupported variable type. Expected "
            f"ast.Name or ast.Tuple, got "
            f"{target.__class__.__name__} instead."
        )
|
||||
|
||||
def _execute_call(self, call: ast.Call) -> Any:
    r"""Evaluate a call node: resolve the callable, evaluate the
    positional and keyword arguments, then invoke it."""
    func = self._execute_ast(call.func)

    # TODO: handle *args / **kwargs unpacking in arguments.
    positional = [self._execute_ast(a) for a in call.args]
    named = {
        kw.arg: self._execute_ast(kw.value) for kw in call.keywords
    }
    return func(*positional, **named)
|
||||
|
||||
def _execute_subscript(self, subscript: ast.Subscript):
    r"""Evaluate a read-only subscript expression (``value[index]``).

    Lists/tuples are indexed by ``int(index)``; mappings are tried with
    the exact key first, then (for string keys on dicts) the closest
    fuzzy-matched key.

    Raises:
        InterpreterError: If the subscript is not a Load context or no
            lookup strategy succeeds.
    """
    index = self._execute_ast(subscript.slice)
    value = self._execute_ast(subscript.value)
    # Only reads are supported; Store/Del contexts are rejected.
    if not isinstance(subscript.ctx, ast.Load):
        raise InterpreterError(
            f"{subscript.ctx.__class__.__name__} is not supported for "
            "subscript."
        )
    if isinstance(value, (list, tuple)):
        return value[int(index)]
    # NOTE(review): `index in value` raises TypeError (not
    # InterpreterError) when `value` is not a container — confirm this
    # is the intended contract.
    if index in value:
        return value[index]
    if isinstance(index, str) and isinstance(value, dict):
        # Fall back to the closest string key, mirroring the fuzzy
        # variable lookup in _get_value_from_state.
        close_matches = difflib.get_close_matches(
            index,
            [key for key in list(value.keys()) if isinstance(key, str)],
        )
        if len(close_matches) > 0:
            return value[close_matches[0]]

    raise InterpreterError(f"Could not index {value} with '{index}'.")
|
||||
|
||||
def _execute_name(self, name: ast.Name):
    r"""Resolve an ``ast.Name``: in Store context return the identifier
    itself, in Load context look the value up from the state."""
    ctx = name.ctx
    if isinstance(ctx, ast.Store):
        return name.id
    if isinstance(ctx, ast.Load):
        return self._get_value_from_state(name.id)
    raise InterpreterError(f"{ctx} is not supported.")
|
||||
|
||||
def _execute_condition(self, condition: ast.Compare):
    r"""Evaluate a single-operator comparison node and return its
    boolean result. Chained comparisons (``a < b < c``) are rejected."""
    if len(condition.ops) > 1:
        raise InterpreterError(
            "Cannot evaluate conditions with multiple operators"
        )

    lhs = self._execute_ast(condition.left)
    op = condition.ops[0]
    rhs = self._execute_ast(condition.comparators[0])

    # Map each AST comparison operator to its evaluation; checked in
    # order with isinstance so subclassing behaves like the original
    # if/elif chain.
    dispatch = [
        (ast.Eq, lambda a, b: a == b),
        (ast.NotEq, lambda a, b: a != b),
        (ast.Lt, lambda a, b: a < b),
        (ast.LtE, lambda a, b: a <= b),
        (ast.Gt, lambda a, b: a > b),
        (ast.GtE, lambda a, b: a >= b),
        (ast.Is, lambda a, b: a is b),
        (ast.IsNot, lambda a, b: a is not b),
        (ast.In, lambda a, b: a in b),
        (ast.NotIn, lambda a, b: a not in b),
    ]
    for op_type, evaluate in dispatch:
        if isinstance(op, op_type):
            return evaluate(lhs, rhs)
    raise InterpreterError(f"Unsupported operator: {op}")
|
||||
|
||||
def _execute_if(self, if_statement: ast.If):
    r"""Execute an ``if`` statement whose test is a single comparison.

    Returns:
        Any: The value of the last non-``None`` statement in whichever
            branch was taken, or ``None`` if no statement produced one.

    Raises:
        InterpreterError: If the test expression is not an
            ``ast.Compare`` node.
    """
    result = None
    if not isinstance(if_statement.test, ast.Compare):
        # Bug fix: the error message misspelled "Compare" as "Campare".
        raise InterpreterError(
            "Only Compare expr supported in if statement, get"
            f" {if_statement.test.__class__.__name__}"
        )
    if self._execute_condition(if_statement.test):
        for line in if_statement.body:
            line_result = self._execute_ast(line)
            if line_result is not None:
                result = line_result
    else:
        for line in if_statement.orelse:
            line_result = self._execute_ast(line)
            if line_result is not None:
                result = line_result
    return result
|
||||
|
||||
def _execute_for(self, for_statement: ast.For):
    r"""Execute a ``for`` loop node, returning the value of the last
    non-``None`` statement evaluated in the body."""
    result = None
    for item in self._execute_ast(for_statement.iter):
        # Bind the loop variable(s) before executing the body.
        self._assign(for_statement.target, item)
        for stmt in for_statement.body:
            stmt_result = self._execute_ast(stmt)
            if stmt_result is not None:
                result = stmt_result

    return result
|
||||
|
||||
def _execute_import(self, import_module: ast.Import) -> None:
    r"""Import each white-listed module and bind it into the state under
    its alias (or its own name when no alias is given)."""
    for entry in import_module.names:
        self._validate_import(entry.name)
        bound_name = entry.asname if entry.asname else entry.name
        self.state[bound_name] = importlib.import_module(entry.name)
|
||||
|
||||
def _execute_import_from(self, import_from: ast.ImportFrom):
    r"""Handle ``from <module> import <name>`` statements, validating the
    fully-qualified name of each import against the white list."""
    module_name = import_from.module
    if module_name is None:
        # Relative imports have no module to resolve against.
        raise InterpreterError("\"from . import\" is not supported.")
    for entry in import_from.names:
        self._validate_import(f"{module_name}.{entry.name}")
        imported = importlib.import_module(module_name)
        bound_name = entry.asname if entry.asname else entry.name
        self.state[bound_name] = getattr(imported, entry.name)
|
||||
|
||||
def _validate_import(self, full_name: str):
    r"""Check that ``full_name`` (a dotted module path) is permitted.

    A name is allowed when any dotted prefix of it appears in
    ``self.import_white_list``; e.g. white-listing ``os`` also permits
    ``os.path.join``.

    Raises:
        InterpreterError: If no prefix of ``full_name`` is white-listed.
    """
    # Simplified: the original kept a `found_name` flag that was dead
    # code — it was assigned True immediately before returning, so the
    # post-loop `if not found_name` check could never observe it.
    tmp_name = ""
    for name in full_name.split("."):
        tmp_name += name if tmp_name == "" else f".{name}"
        if tmp_name in self.import_white_list:
            return

    raise InterpreterError(
        f"It is not permitted to import modules "
        f"than module white list (try to import "
        f"{full_name})."
    )
|
||||
|
||||
def _execute_binop(self, binop: ast.BinOp):
    r"""Evaluate a binary-operation node and return its value."""
    lhs = self._execute_ast(binop.left)
    op = binop.op
    rhs = self._execute_ast(binop.right)

    # Operator dispatch table, checked in order with isinstance so it
    # behaves exactly like the original if/elif chain.
    dispatch = [
        (ast.Add, lambda a, b: a + b),
        (ast.Sub, lambda a, b: a - b),
        (ast.Mult, lambda a, b: a * b),
        (ast.Div, lambda a, b: a / b),
        (ast.FloorDiv, lambda a, b: a // b),
        (ast.Mod, lambda a, b: a % b),
        (ast.Pow, lambda a, b: a**b),
        (ast.LShift, lambda a, b: a << b),
        (ast.RShift, lambda a, b: a >> b),
        (ast.MatMult, lambda a, b: a @ b),
    ]
    for op_type, apply_op in dispatch:
        if isinstance(op, op_type):
            return apply_op(lhs, rhs)
    raise InterpreterError(f"Operator not supported: {op}")
|
||||
|
||||
def _execute_unaryop(self, unaryop: ast.UnaryOp):
    r"""Evaluate a unary operation (``+``, ``-`` or ``not``)."""
    value = self._execute_ast(unaryop.operand)
    op = unaryop.op

    if isinstance(op, ast.UAdd):
        return +value
    if isinstance(op, ast.USub):
        return -value
    if isinstance(op, ast.Not):
        return not value
    raise InterpreterError(f"Operator not supported: {op}")
|
||||
|
||||
def _get_value_from_state(self, key: str) -> Any:
    r"""Look ``key`` up in the exact state first, then fall back to the
    closest fuzzy-matched name in ``fuzz_state``."""
    if key in self.state:
        return self.state[key]
    matches = difflib.get_close_matches(
        key, list(self.fuzz_state.keys()), n=1
    )
    if not matches:
        raise InterpreterError(f"The variable `{key}` is not defined.")
    return self.fuzz_state[matches[0]]
|
||||
@@ -1,19 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
# TODO: Do we need a file to store this error class?
|
||||
class InterpreterError(Exception):
    r"""Exception raised for errors that can be solved by regenerating code.

    Interpreters raise this (instead of the underlying ``SyntaxError``,
    ``TypeError``, etc.) so that a caller can catch a single exception
    type and retry with newly generated code.
    """

    pass
|
||||
@@ -1,168 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import queue
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from camel.interpreters.base import BaseInterpreter
|
||||
from camel.interpreters.interpreter_error import InterpreterError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jupyter_client import BlockingKernelClient, KernelManager
|
||||
|
||||
TIMEOUT = 30
|
||||
|
||||
|
||||
class JupyterKernelInterpreter(BaseInterpreter):
    r"""A class for executing code strings in a Jupyter Kernel.

    Args:
        require_confirm (bool, optional): If `True`, prompt user before
            running code strings for security. Defaults to `True`.
        print_stdout (bool, optional): If `True`, print the standard
            output of the executed code. Defaults to `False`.
        print_stderr (bool, optional): If `True`, print the standard error
            of the executed code. Defaults to `True`.
    """

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr

        # Lazily created by `_initialize_if_needed` on first use.
        self.kernel_manager: Optional[KernelManager] = None
        self.client: Optional[BlockingKernelClient] = None

    def __del__(self) -> None:
        r"""Clean up the kernel and client."""
        if self.kernel_manager:
            self.kernel_manager.shutdown_kernel()
        if self.client:
            self.client.stop_channels()

    def _initialize_if_needed(self) -> None:
        r"""Initialize the kernel manager and client if they are not already
        initialized.
        """
        if self.kernel_manager is not None:
            return

        # Imported lazily so merely constructing the interpreter does not
        # require jupyter_client to be installed.
        from jupyter_client.manager import start_new_kernel

        self.kernel_manager, self.client = start_new_kernel()

    @staticmethod
    def _clean_ipython_output(output: str) -> str:
        r"""Remove ANSI escape sequences from the output."""
        ansi_escape = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]')
        return ansi_escape.sub('', output)

    def _execute(self, code: str, timeout: float) -> str:
        r"""Execute the code in the Jupyter kernel and return the result.

        Collects iopub messages until the kernel reports an "idle"
        execution state, concatenating error tracebacks, stream text and
        display data into one string.

        Raises:
            InterpreterError: If the kernel client is not initialized.
        """
        if not self.kernel_manager or not self.client:
            raise InterpreterError("Jupyter client is not initialized.")

        self.client.execute(code)
        outputs = []
        while True:
            try:
                msg = self.client.get_iopub_msg(timeout=timeout)
                msg_content = msg["content"]
                msg_type = msg.get("msg_type", None)

                # The kernel reports "idle" once execution has finished.
                if msg_content.get("execution_state", None) == "idle":
                    break

                if msg_type == "error":
                    # Removed leftover debug `print()` calls that leaked
                    # the raw message payload to stdout.
                    traceback = "\n".join(msg_content["traceback"])
                    outputs.append(traceback)
                elif msg_type == "stream":
                    outputs.append(msg_content["text"])
                elif msg_type in ["execute_result", "display_data"]:
                    outputs.append(msg_content["data"]["text/plain"])
                    if "image/png" in msg_content["data"]:
                        # NOTE(review): this appends only blank lines; the
                        # image payload markup appears to have been lost —
                        # confirm against upstream before relying on it.
                        outputs.append(
                            f"\n\n"
                        )
            except queue.Empty:
                outputs.append("Time out")
                break
            except Exception as e:
                outputs.append(f"Exception occurred: {e!s}")
                break

        exec_result = "\n".join(outputs)
        return self._clean_ipython_output(exec_result)

    def run(self, code: str, code_type: str) -> str:
        r"""Executes the given code in the Jupyter kernel.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured result of the
                executed code.

        Raises:
            InterpreterError: If there is an error when doing code execution.
        """
        self._initialize_if_needed()

        if code_type == "bash":
            # Wrap shell code in an IPython cell magic so the kernel runs it.
            code = f"%%bash\n({code})"
        try:
            result = self._execute(code, timeout=TIMEOUT)
        except Exception as e:
            raise InterpreterError(f"Execution failed: {e!s}")

        return result

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter.

        Returns:
            List[str]: Supported code types.
        """
        return ["python", "bash"]

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates the action space for the interpreter.

        Args:
            action_space (Dict[str, Any]): A dictionary representing the
                new or updated action space.

        Raises:
            RuntimeError: Always raised because `JupyterKernelInterpreter`
                does not support updating the action space.
        """
        # Bug fix: the message previously named the wrong class
        # ("SubprocessInterpreter").
        raise RuntimeError(
            "JupyterKernelInterpreter doesn't support `action_space`."
        )
|
||||
@@ -1,212 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import shlex
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.interpreters.base import BaseInterpreter
|
||||
from camel.interpreters.interpreter_error import InterpreterError
|
||||
from camel.logger import get_logger
|
||||
import os
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SubprocessInterpreter(BaseInterpreter):
    r"""SubprocessInterpreter is a class for executing code files or code
    strings in a subprocess.

    This class handles the execution of code in different scripting languages
    (currently Python, Bash, and Node.js) within a subprocess, capturing their
    stdout and stderr streams, and allowing user checking before executing code
    strings.

    Args:
        require_confirm (bool, optional): If True, prompt user before running
            code strings for security. (default: :obj:`True`)
        print_stdout (bool, optional): If True, print the standard output of
            the executed code. (default: :obj:`False`)
        print_stderr (bool, optional): If True, print the standard error of the
            executed code. (default: :obj:`True`)
    """

    # Shell command template used to run a file of each canonical code type.
    _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python {file_name}",
        "bash": "bash {file_name}",
        "node": "node {file_name}",
    }

    # File extension used when materializing a code string to disk.
    _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "py",
        "bash": "sh",
        "node": "js",
    }

    # Caller-supplied aliases normalized to canonical code types.
    _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python",
        "py3": "python",
        "python3": "python",
        "py": "python",
        "shell": "bash",
        "bash": "bash",
        "sh": "bash",
        "node": "node",
        "javascript": "node",
        "js": "node",
    }

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr

    def run_file(
        self,
        file: Path,
        code_type: str,
    ) -> str:
        r"""Executes a code file in a subprocess and captures its output.

        Args:
            file (Path): The path object of the file to run.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash', 'node').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            RuntimeError: If the provided file path does not point to a file.
            InterpreterError: If the code type provided is not supported.
        """
        if not file.is_file():
            raise RuntimeError(f"{file} is not a file.")
        code_type = self._check_code_type(code_type)
        # NOTE(review): shlex.split mis-splits paths containing whitespace;
        # tempfile-generated paths normally have none — confirm for
        # user-supplied paths.
        cmd = shlex.split(
            self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
                file_name=str(file)
            )
        )

        proc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        stdout, stderr = proc.communicate()
        if self.print_stdout and stdout:
            print("======stdout======")
            print(Fore.GREEN + stdout + Fore.RESET)
            print("==================")
        if self.print_stderr and stderr:
            print("======stderr======")
            print(Fore.RED + stderr + Fore.RESET)
            print("==================")
        exec_result = f"{stdout}"
        exec_result += f"(stderr: {stderr})" if stderr else ""
        return exec_result

    def run(
        self,
        code: str,
        code_type: str,
    ) -> str:
        r"""Generates a temporary file with the given code, executes it, and
        deletes the file afterward.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash', 'node').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            InterpreterError: If the user declines to run the code or if the
                code type is unsupported.
        """
        code_type = self._check_code_type(code_type)

        if self.require_confirm:
            # Bug fix: the second string fragment was not an f-string, so
            # the literal text "{code}" was logged instead of the code.
            logger.info(
                f"The following {code_type} code will run on your "
                f"computer: {code}"
            )
            while True:
                choice = input("Running code? [Y/n]:").lower()
                if choice in ["y", "yes", "ye", ""]:
                    break
                elif choice in ["no", "n"]:
                    raise InterpreterError(
                        "Execution halted: User opted not to run the code. "
                        "This choice stops the current operation and any "
                        "further code execution."
                    )

        temp_file_path = self._create_temp_file(
            code=code, extension=self._CODE_EXTENSION_MAPPING[code_type]
        )

        result = self.run_file(temp_file_path, code_type)

        temp_file_path.unlink()
        return result

    def _create_temp_file(self, code: str, extension: str) -> Path:
        # Materialize the code string to a named temp file (delete=False so
        # the subprocess can open it; the caller unlinks it afterwards).
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=f".{extension}"
        ) as f:
            f.write(code)
            name = f.name
        return Path(name)

    def _check_code_type(self, code_type: str) -> str:
        # Normalize caller-supplied aliases (e.g. "py3") to canonical types.
        if code_type not in self._CODE_TYPE_MAPPING:
            raise InterpreterError(
                f"Unsupported code type {code_type}. Currently "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_EXTENSION_MAPPING.keys())}."
            )
        return self._CODE_TYPE_MAPPING[code_type]

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return list(self._CODE_EXTENSION_MAPPING.keys())

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter"""
        raise RuntimeError(
            "SubprocessInterpreter doesn't support " "`action_space`."
        )
|
||||
@@ -1,29 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .apify_reader import Apify
|
||||
from .base_io import File
|
||||
from .chunkr_reader import ChunkrReader
|
||||
from .firecrawl_reader import Firecrawl
|
||||
from .jina_url_reader import JinaURLReader
|
||||
from .unstructured_io import UnstructuredIO
|
||||
|
||||
__all__ = [
|
||||
'File',
|
||||
'UnstructuredIO',
|
||||
'JinaURLReader',
|
||||
'Firecrawl',
|
||||
'Apify',
|
||||
'ChunkrReader',
|
||||
]
|
||||
@@ -1,223 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apify_client.clients import DatasetClient
|
||||
|
||||
from camel.utils import api_keys_required
|
||||
|
||||
|
||||
class Apify:
    r"""Thin client for the Apify platform, which automates web workflows
    via "actors" and stores their output in datasets.

    Args:
        api_key (Optional[str]): API key for authenticating with the Apify
            API. Falls back to the ``APIFY_API_KEY`` environment variable.
    """

    @api_keys_required("APIFY_API_KEY")
    def __init__(
        self,
        api_key: Optional[str] = None,
    ) -> None:
        # Imported lazily so the package is only required when used.
        from apify_client import ApifyClient

        token = api_key or os.environ.get("APIFY_API_KEY")
        self._api_key = token
        self.client = ApifyClient(token=token)

    def run_actor(
        self,
        actor_id: str,
        run_input: Optional[dict] = None,
        content_type: Optional[str] = None,
        build: Optional[str] = None,
        max_items: Optional[int] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
        webhooks: Optional[list] = None,
        wait_secs: Optional[int] = None,
    ) -> Optional[dict]:
        r"""Run an actor on the Apify platform.

        Args:
            actor_id (str): The ID of the actor to run.
            run_input (Optional[dict]): The input data for the actor.
                (default: :obj:`None`)
            content_type (str, optional): The content type of the input.
            build (str, optional): Actor build tag or number to run. By
                default the run uses the build from the actor's default run
                configuration (typically ``latest``).
            max_items (int, optional): Maximum number of results returned
                by this run; for pay-per-result actors you are not charged
                beyond this limit.
            memory_mbytes (int, optional): Memory limit for the run in
                megabytes; defaults to the actor's run configuration.
            timeout_secs (int, optional): Timeout for the run in seconds;
                defaults to the actor's run configuration.
            webhooks (list, optional): Webhooks
                (https://docs.apify.com/webhooks) attached to this run,
                e.g. to be notified when the actor finishes or fails. A
                webhook already configured on the actor need not be
                repeated here.
            wait_secs (int, optional): Maximum seconds the server waits for
                the run to finish; waits indefinitely when omitted.

        Returns:
            Optional[dict]: The output data from the actor if successful.
                Use its ``defaultDatasetId`` to fetch the dataset.

        Raises:
            RuntimeError: If the actor fails to run.
        """
        call_kwargs = dict(
            run_input=run_input,
            content_type=content_type,
            build=build,
            max_items=max_items,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
            webhooks=webhooks,
            wait_secs=wait_secs,
        )
        try:
            return self.client.actor(actor_id).call(**call_kwargs)
        except Exception as e:
            raise RuntimeError(f"Failed to run actor {actor_id}: {e}") from e

    def get_dataset_client(
        self,
        dataset_id: str,
    ) -> "DatasetClient":
        r"""Return the client object for a single dataset.

        Args:
            dataset_id (str): The ID of the dataset to get the client for.

        Returns:
            DatasetClient: The dataset client.

        Raises:
            RuntimeError: If the dataset client fails to be retrieved.
        """
        try:
            return self.client.dataset(dataset_id)
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset {dataset_id}: {e}"
            ) from e

    def get_dataset(
        self,
        dataset_id: str,
    ) -> Optional[dict]:
        r"""Fetch a dataset's metadata object from the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to get.

        Returns:
            dict: The dataset.

        Raises:
            RuntimeError: If the dataset fails to be retrieved.
        """
        try:
            return self.get_dataset_client(dataset_id).get()
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset {dataset_id}: {e}"
            ) from e

    def update_dataset(
        self,
        dataset_id: str,
        name: str,
    ) -> dict:
        r"""Rename a dataset on the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to update.
            name (str): The new name for the dataset.

        Returns:
            dict: The updated dataset.

        Raises:
            RuntimeError: If the dataset fails to be updated.
        """
        try:
            return self.get_dataset_client(dataset_id).update(name=name)
        except Exception as e:
            raise RuntimeError(
                f"Failed to update dataset {dataset_id}: {e}"
            ) from e

    def get_dataset_items(
        self,
        dataset_id: str,
    ) -> List:
        r"""List the items stored in a dataset.

        Args:
            dataset_id (str): The ID of the dataset to get items from.

        Returns:
            list: The items in the dataset.

        Raises:
            RuntimeError: If the items fail to be retrieved.
        """
        try:
            return self.get_dataset_client(dataset_id).list_items().items
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset items {dataset_id}: {e}"
            ) from e

    def get_datasets(
        self,
        unnamed: Optional[bool] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        desc: Optional[bool] = None,
    ) -> List[dict]:
        r"""List all named datasets on the Apify platform.

        Args:
            unnamed (bool, optional): Whether to include unnamed key-value
                stores in the list.
            limit (int, optional): How many key-value stores to retrieve.
            offset (int, optional): Which key-value store to include as
                first when retrieving the list.
            desc (bool, optional): Whether to sort the key-value stores in
                descending order based on their modification date.

        Returns:
            List[dict]: The datasets.

        Raises:
            RuntimeError: If the datasets fail to be retrieved.
        """
        try:
            listing = self.client.datasets().list(
                unnamed=unnamed, limit=limit, offset=offset, desc=desc
            )
            return listing.items
        except Exception as e:
            raise RuntimeError(f"Failed to get datasets: {e}") from e
|
||||
@@ -1,328 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from hashlib import md5
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from camel.utils import dependencies_required
|
||||
|
||||
|
||||
class File(ABC):
|
||||
r"""Represents an uploaded file comprised of Documents.
|
||||
|
||||
Args:
|
||||
name (str): The name of the file.
|
||||
file_id (str): The unique identifier of the file.
|
||||
metadata (Dict[str, Any], optional): Additional metadata
|
||||
associated with the file. Defaults to None.
|
||||
docs (List[Dict[str, Any]], optional): A list of documents
|
||||
contained within the file. Defaults to None.
|
||||
raw_bytes (bytes, optional): The raw bytes content of the file.
|
||||
Defaults to b"".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
file_id: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
docs: Optional[List[Dict[str, Any]]] = None,
|
||||
raw_bytes: bytes = b"",
|
||||
):
|
||||
self.name = name
|
||||
self.file_id = file_id
|
||||
self.metadata = metadata or {}
|
||||
self.docs = docs or []
|
||||
self.raw_bytes = raw_bytes
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_bytes(cls, file: BytesIO, filename: str) -> "File":
|
||||
r"""Creates a File object from a BytesIO object.
|
||||
|
||||
Args:
|
||||
file (BytesIO): A BytesIO object representing the contents of the
|
||||
file.
|
||||
filename (str): The name of the file.
|
||||
|
||||
Returns:
|
||||
File: A File object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_raw_bytes(cls, raw_bytes: bytes, filename: str) -> "File":
|
||||
r"""Creates a File object from raw bytes.
|
||||
|
||||
Args:
|
||||
raw_bytes (bytes): The raw bytes content of the file.
|
||||
filename (str): The name of the file.
|
||||
|
||||
Returns:
|
||||
File: A File object.
|
||||
"""
|
||||
file = BytesIO(raw_bytes)
|
||||
return cls.from_bytes(file, filename)
|
||||
|
||||
@staticmethod
|
||||
def create_file(file: BytesIO, filename: str) -> "File":
|
||||
r"""Reads an uploaded file and returns a File object.
|
||||
|
||||
Args:
|
||||
file (BytesIO): A BytesIO object representing the contents of the
|
||||
file.
|
||||
filename (str): The name of the file.
|
||||
|
||||
Returns:
|
||||
File: A File object.
|
||||
"""
|
||||
ext_to_cls = {
|
||||
"docx": DocxFile,
|
||||
"pdf": PdfFile,
|
||||
"txt": TxtFile,
|
||||
"json": JsonFile,
|
||||
"html": HtmlFile,
|
||||
}
|
||||
|
||||
ext = filename.split(".")[-1].lower()
|
||||
if ext not in ext_to_cls:
|
||||
raise NotImplementedError(f"File type {ext} not supported")
|
||||
|
||||
out_file = ext_to_cls[ext].from_bytes(file, filename)
|
||||
return out_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_from_raw_bytes(raw_bytes: bytes, filename: str) -> "File":
|
||||
r"""Reads raw bytes and returns a File object.
|
||||
|
||||
Args:
|
||||
raw_bytes (bytes): The raw bytes content of the file.
|
||||
filename (str): The name of the file.
|
||||
|
||||
Returns:
|
||||
File: A File object.
|
||||
"""
|
||||
file = BytesIO(raw_bytes)
|
||||
return File.create_file(file, filename)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"File(name={self.name}, id={self.file_id}, "
|
||||
f"metadata={self.metadata}, docs={self.docs})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"File(name={self.name}, id={self.file_id}, metadata="
|
||||
f"{self.metadata})"
|
||||
)
|
||||
|
||||
def copy(self) -> "File":
|
||||
r"""Create a deep copy of this File"""
|
||||
|
||||
return self.__class__(
|
||||
name=self.name,
|
||||
file_id=self.file_id,
|
||||
metadata=deepcopy(self.metadata),
|
||||
docs=deepcopy(self.docs),
|
||||
raw_bytes=self.raw_bytes,
|
||||
)
|
||||
|
||||
|
||||
# Any whitespace run containing at least one newline, hoisted so the
# pattern is compiled once at import time.
_NEWLINE_RUN = re.compile(r"\s*\n\s*")


def strip_consecutive_newlines(text: str) -> str:
    r"""Collapse every newline-containing whitespace run into one newline.

    Args:
        text (str): The string to strip.

    Returns:
        str: The string with consecutive newlines stripped.
    """
    return _NEWLINE_RUN.sub("\n", text)
|
||||
|
||||
|
||||
class DocxFile(File):
    @classmethod
    @dependencies_required('docx2txt')
    def from_bytes(cls, file: BytesIO, filename: str) -> "DocxFile":
        r"""Build a DocxFile from an in-memory ``.docx`` payload.

        Args:
            file (BytesIO): A BytesIO object representing the contents of
                the docx file.
            filename (str): The name of the file.

        Returns:
            DocxFile: A DocxFile object.
        """
        import docx2txt

        content = strip_consecutive_newlines(docx2txt.process(file))
        # The MD5 of the payload doubles as a stable file identifier.
        digest = md5(file.getvalue()).hexdigest()
        # Rewind so any later consumer sees the stream from the start.
        file.seek(0)
        return cls(
            name=filename,
            file_id=digest,
            docs=[{"page_content": content.strip()}],
            raw_bytes=file.getvalue(),
        )
|
||||
|
||||
|
||||
class PdfFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "PdfFile":
        r"""Build a PdfFile from an in-memory PDF payload, one document
        per page.

        Args:
            file (BytesIO): A BytesIO object representing the contents of
                the pdf file.
            filename (str): The name of the file.

        Returns:
            PdfFile: A PdfFile object.
        """
        # PyMuPDF (imported as ``fitz``) performs the text extraction.
        try:
            import fitz
        except ImportError:
            raise ImportError(
                "Please install `PyMuPDF` first. "
                "You can install it by running "
                "`pip install PyMuPDF`."
            )
        pdf = fitz.open(stream=file.read(), filetype="pdf")
        pages = []
        for page_number, page in enumerate(pdf, start=1):
            page_text = strip_consecutive_newlines(page.get_text(sort=True))
            pages.append(
                {"page_content": page_text.strip(), "page": page_number}
            )
        # The MD5 of the payload doubles as a stable file identifier.
        digest = md5(file.getvalue()).hexdigest()
        # Rewind so any later consumer sees the stream from the start.
        file.seek(0)
        return cls(
            name=filename,
            file_id=digest,
            docs=pages,
            raw_bytes=file.getvalue(),
        )
|
||||
|
||||
|
||||
class TxtFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "TxtFile":
        r"""Build a TxtFile from an in-memory UTF-8 text payload.

        Args:
            file (BytesIO): A BytesIO object representing the contents of
                the txt file.
            filename (str): The name of the file.

        Returns:
            TxtFile: A TxtFile object.
        """
        content = strip_consecutive_newlines(file.read().decode("utf-8"))
        # The MD5 of the payload doubles as a stable file identifier.
        digest = md5(file.getvalue()).hexdigest()
        # Rewind so any later consumer sees the stream from the start.
        file.seek(0)
        return cls(
            name=filename,
            file_id=digest,
            docs=[{"page_content": content.strip()}],
            raw_bytes=file.getvalue(),
        )
|
||||
|
||||
|
||||
class JsonFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "JsonFile":
        r"""Build a JsonFile from an in-memory JSON payload.

        The payload is parsed and re-serialized, so the stored content is
        normalized JSON rather than the original bytes verbatim.

        Args:
            file (BytesIO): A BytesIO object representing the contents of
                the json file.
            filename (str): The name of the file.

        Returns:
            JsonFile: A JsonFile object.
        """
        payload = json.load(file)
        # The MD5 of the payload doubles as a stable file identifier.
        digest = md5(file.getvalue()).hexdigest()
        # Rewind so any later consumer sees the stream from the start.
        file.seek(0)
        return cls(
            name=filename,
            file_id=digest,
            docs=[{"page_content": json.dumps(payload)}],
            raw_bytes=file.getvalue(),
        )
|
||||
|
||||
|
||||
class HtmlFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "HtmlFile":
        r"""Build an HtmlFile from an in-memory HTML payload, keeping only
        the visible text.

        Args:
            file (BytesIO): A BytesIO object representing the contents of
                the html file.
            filename (str): The name of the file.

        Returns:
            HtmlFile: A HtmlFile object.
        """
        try:
            from bs4 import BeautifulSoup
        except ImportError:
            raise ImportError(
                "Please install `beautifulsoup4` first. "
                "You can install it by running "
                "`pip install beautifulsoup4`."
            )
        markup = BeautifulSoup(file, "html.parser")
        content = strip_consecutive_newlines(markup.get_text())
        # The MD5 of the payload doubles as a stable file identifier.
        digest = md5(file.getvalue()).hexdigest()
        # Rewind so any later consumer sees the stream from the start.
        file.seek(0)
        return cls(
            name=filename,
            file_id=digest,
            docs=[{"page_content": content.strip()}],
            raw_bytes=file.getvalue(),
        )
|
||||
@@ -1,162 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import IO, Any, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from camel.utils import api_keys_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChunkrReader:
    r"""Chunkr Reader for processing documents and returning content in
    various formats.

    Args:
        api_key (Optional[str], optional): The API key for Chunkr API. If
            not provided, it will be retrieved from the environment
            variable `CHUNKR_API_KEY`. (default: :obj:`None`)
        url (Optional[str], optional): The url to the Chunkr service.
            (default: :obj:`https://api.chunkr.ai/api/v1/task`)
        timeout (int, optional): The maximum time in seconds to wait for
            the API responses. (default: :obj:`30`)
        **kwargs (Any): Additional keyword arguments for request headers.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = "https://api.chunkr.ai/api/v1/task",
        timeout: int = 30,
        **kwargs: Any,
    ) -> None:
        self._api_key = api_key or os.getenv('CHUNKR_API_KEY')
        # An explicit CHUNKR_API_URL environment variable wins over `url`.
        self._url = os.getenv('CHUNKR_API_URL') or url
        self._headers = {
            "Authorization": f"{self._api_key}",
            **kwargs,
        }
        self.timeout = timeout

    def submit_task(
        self,
        file_path: str,
        model: str = "Fast",
        ocr_strategy: str = "Auto",
        target_chunk_length: str = "512",
    ) -> str:
        r"""Upload a file to the Chunkr API and return the task ID.

        Args:
            file_path (str): The path to the file to be uploaded.
            model (str, optional): The model to be used for the task.
                (default: :obj:`Fast`)
            ocr_strategy (str, optional): The OCR strategy. Defaults to
                'Auto'.
            target_chunk_length (str, optional): The target chunk length.
                (default: :obj:`512`)

        Returns:
            str: The task ID.

        Raises:
            ValueError: If the upload fails or no task ID is returned.
        """
        with open(file_path, 'rb') as upload:
            # Multipart form fields; the filename slot is left as None so
            # the binary stream is sent as-is.
            form: dict[
                str, Union[tuple[None, IO[bytes]], tuple[None, str]]
            ] = {
                'file': (None, upload),
                'model': (None, model),
                'ocr_strategy': (None, ocr_strategy),
                'target_chunk_length': (None, target_chunk_length),
            }
            try:
                response = requests.post(
                    self._url,  # type: ignore[arg-type]
                    headers=self._headers,
                    files=form,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                task_id = response.json().get('task_id')
                if not task_id:
                    raise ValueError("Task ID not returned in the response.")
                logger.info(f"Task submitted successfully. Task ID: {task_id}")
                return task_id
            except Exception as e:
                logger.error(f"Failed to submit task: {e}")
                raise ValueError(f"Failed to submit task: {e}") from e

    def get_task_output(self, task_id: str, max_retries: int = 5) -> str:
        r"""Poll the Chunkr API until the task succeeds and return its
        result.

        Args:
            task_id (str): The task ID to check the status for.
            max_retries (int, optional): Maximum number of retry attempts.
                (default: :obj:`5`)

        Returns:
            str: The formatted task result in JSON format.

        Raises:
            ValueError: If the task status cannot be retrieved.
            RuntimeError: If the maximum number of retries is reached
                without a successful task completion.
        """
        poll_url = f"{self._url}/{task_id}"

        for _ in range(max_retries):
            try:
                response = requests.get(
                    poll_url, headers=self._headers, timeout=self.timeout
                )
                response.raise_for_status()
                task_status = response.json().get('status')

                if task_status == "Succeeded":
                    logger.info(f"Task {task_id} completed successfully.")
                    return self._pretty_print_response(response.json())
                logger.info(
                    f"Task {task_id} is still {task_status}. Retrying "
                    "in 5 seconds..."
                )
            except Exception as e:
                logger.error(f"Failed to retrieve task status: {e}")
                raise ValueError(f"Failed to retrieve task status: {e}") from e

            # Fixed 5s back-off between polls (also after the final poll,
            # matching the original loop's behavior).
            time.sleep(5)

        logger.error(f"Max retries reached for task {task_id}.")
        raise RuntimeError(f"Max retries reached for task {task_id}.")

    def _pretty_print_response(self, response_json: dict) -> str:
        r"""Pretty print a JSON response.

        Args:
            response_json (dict): The response JSON to pretty print.

        Returns:
            str: Formatted JSON as a string.
        """
        return json.dumps(response_json, indent=4)
|
||||
@@ -1,202 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Firecrawl:
    r"""Firecrawl allows you to turn entire websites into LLM-ready
    markdown.

    Args:
        api_key (Optional[str]): API key for authenticating with the
            Firecrawl API. Falls back to the ``FIRECRAWL_API_KEY``
            environment variable.
        api_url (Optional[str]): Base URL for the Firecrawl API. Falls
            back to the ``FIRECRAWL_API_URL`` environment variable.

    References:
        https://docs.firecrawl.dev/introduction
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
    ) -> None:
        # Imported lazily so the package is only required when used.
        from firecrawl import FirecrawlApp

        self._api_key = api_key or os.environ.get("FIRECRAWL_API_KEY")
        self._api_url = api_url or os.environ.get("FIRECRAWL_API_URL")

        self.app = FirecrawlApp(api_key=self._api_key, api_url=self._api_url)

    def crawl(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Crawl a URL and all accessible subpages. Customize the crawl by
        setting different parameters, and receive the full response or a
        job ID based on the specified options.

        Args:
            url (str): The URL to crawl.
            params (Optional[Dict[str, Any]]): Additional parameters for
                the crawl request. Defaults to `None`.
            **kwargs (Any): Additional keyword arguments, such as
                `poll_interval`, `idempotency_key`.

        Returns:
            Any: The crawl job ID or the crawl results if waiting until
                completion.

        Raises:
            RuntimeError: If the crawling process fails.
        """
        try:
            crawl_response = self.app.crawl_url(
                url=url,
                params=params,
                **kwargs,
            )
            return crawl_response
        except Exception as e:
            # Chain the cause so callers can inspect the underlying error.
            raise RuntimeError(f"Failed to crawl the URL: {e}") from e

    def markdown_crawl(self, url: str) -> str:
        r"""Crawl a URL and all accessible subpages and return the content
        in Markdown format.

        Args:
            url (str): The URL to crawl.

        Returns:
            str: The content of the URL in Markdown format.

        Raises:
            RuntimeError: If the crawling process fails.
        """
        try:
            crawl_result = self.app.crawl_url(
                url,
                {'formats': ['markdown']},
            )
            # NOTE(review): assumes the API returns a list of per-page
            # dicts with a 'markdown' key — confirm against the installed
            # firecrawl client version.
            if not isinstance(crawl_result, list):
                raise ValueError("Unexpected response format")
            markdown_contents = [
                result.get('markdown', '') for result in crawl_result
            ]
            return '\n'.join(markdown_contents)
        except Exception as e:
            raise RuntimeError(
                f"Failed to crawl the URL and retrieve markdown: {e}"
            ) from e

    def check_crawl_job(self, job_id: str) -> Dict:
        r"""Check the status of a crawl job.

        Args:
            job_id (str): The ID of the crawl job.

        Returns:
            Dict: The response including status of the crawl job.

        Raises:
            RuntimeError: If the check process fails.
        """
        try:
            return self.app.check_crawl_status(job_id)
        except Exception as e:
            raise RuntimeError(
                f"Failed to check the crawl job status: {e}"
            ) from e

    def scrape(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict:
        r"""To scrape a single URL. This function supports advanced
        scraping by setting different parameters and returns the full
        scraped data as a dictionary.

        Reference: https://docs.firecrawl.dev/advanced-scraping-guide

        Args:
            url (str): The URL to read.
            params (Optional[Dict[str, Any]]): Additional parameters for
                the scrape request.

        Returns:
            Dict: The scraped data.

        Raises:
            RuntimeError: If the scrape process fails.
        """
        try:
            return self.app.scrape_url(url=url, params=params)
        except Exception as e:
            raise RuntimeError(f"Failed to scrape the URL: {e}") from e

    def structured_scrape(self, url: str, response_format: BaseModel) -> Dict:
        r"""Use LLM to extract structured data from given URL.

        Args:
            url (str): The URL to read.
            response_format (BaseModel): A pydantic model that includes
                value types and field descriptions used to generate a
                structured response by LLM. This schema helps in defining
                the expected output format.

        Returns:
            Dict: The content of the URL.

        Raises:
            RuntimeError: If the scrape process fails.
        """
        try:
            data = self.app.scrape_url(
                url,
                {
                    'formats': ['extract'],
                    'extract': {'schema': response_format.model_json_schema()},
                },
            )
            return data.get("extract", {})
        except Exception as e:
            raise RuntimeError(
                f"Failed to perform structured scrape: {e}"
            ) from e

    def map_site(
        self, url: str, params: Optional[Dict[str, Any]] = None
    ) -> list:
        r"""Map a website to retrieve all accessible URLs.

        Args:
            url (str): The URL of the site to map.
            params (Optional[Dict[str, Any]]): Additional parameters for
                the map request. Defaults to `None`.

        Returns:
            list: A list containing the URLs found on the site.

        Raises:
            RuntimeError: If the mapping process fails.
        """
        try:
            return self.app.map_url(url=url, params=params)
        except Exception as e:
            raise RuntimeError(f"Failed to map the site: {e}") from e
|
||||
@@ -1,99 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from warnings import warn
|
||||
|
||||
from camel.types.enums import JinaReturnFormat
|
||||
|
||||
JINA_ENDPOINT = "https://r.jina.ai/"
|
||||
|
||||
|
||||
class JinaURLReader:
|
||||
r"""URL Reader provided by Jina AI. The output is cleaner and more
|
||||
LLM-friendly than the URL Reader of UnstructuredIO. Can be configured to
|
||||
replace the UnstructuredIO URL Reader in the pipeline.
|
||||
|
||||
Args:
|
||||
api_key (Optional[str], optional): The API key for Jina AI. If not
|
||||
provided, the reader will have a lower rate limit. Defaults to
|
||||
None.
|
||||
return_format (ReturnFormat, optional): The level of detail
|
||||
of the returned content, which is optimized for LLMs. For
|
||||
now screenshots are not supported. Defaults to
|
||||
ReturnFormat.DEFAULT.
|
||||
json_response (bool, optional): Whether to return the response
|
||||
in JSON format. Defaults to False.
|
||||
timeout (int, optional): The maximum time in seconds to wait for
|
||||
the page to be rendered. Defaults to 30.
|
||||
**kwargs (Any): Additional keyword arguments, including proxies,
|
||||
cookies, etc. It should align with the HTTP Header field and
|
||||
value pairs listed in the reference.
|
||||
|
||||
References:
|
||||
https://jina.ai/reader
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
return_format: JinaReturnFormat = JinaReturnFormat.DEFAULT,
|
||||
json_response: bool = False,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
api_key = api_key or os.getenv('JINA_API_KEY')
|
||||
if not api_key:
|
||||
warn(
|
||||
"JINA_API_KEY not set. This will result in a low rate limit "
|
||||
"of Jina URL Reader. Get API key here: https://jina.ai/reader."
|
||||
)
|
||||
|
||||
# if the following field not provided, it will be None
|
||||
api_field = f"Bearer {api_key}" if api_key else None
|
||||
json_field = "application/json" if json_response else None
|
||||
|
||||
raw_headers = {
|
||||
"Authorization": api_field,
|
||||
"X-Return-Format": return_format.value,
|
||||
"Accept": json_field,
|
||||
"X-Timeout": str(timeout),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# eliminate None values
|
||||
self._headers = {k: v for k, v in raw_headers.items() if v}
|
||||
|
||||
def read_content(self, url: str) -> str:
    r"""Reads the content of a URL and returns it as a string with
    given form.

    Args:
        url (str): The URL to read.

    Returns:
        str: The content of the URL.

    Raises:
        ValueError: If the request fails or the server returns a
            non-success status code.
    """

    import requests

    full_url = f"{JINA_ENDPOINT}{url}"

    # Fix: the original call had no client-side timeout, so a stalled
    # connection could hang forever. Reuse the server-side render timeout
    # (the "X-Timeout" header set in __init__) plus a margin for network
    # latency.
    try:
        render_timeout = int(self._headers.get("X-Timeout", 30))
    except (TypeError, ValueError):
        render_timeout = 30

    try:
        resp = requests.get(
            full_url,
            headers=self._headers,
            timeout=render_timeout + 15,
        )
        resp.raise_for_status()
    except Exception as e:
        raise ValueError(f"Failed to read content from {url}: {e}") from e

    return resp.text
|
||||
@@ -1,471 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
# Import heavy third-party types only for static type checking.
if TYPE_CHECKING:
    from unstructured.documents.elements import Element

# NOTE(review): debug leftover — `pdb` is only referenced by a commented-out
# `pdb.set_trace()` call in this module; consider removing this import.
import pdb
|
||||
|
||||
class UnstructuredIO:
    r"""A class to handle various functionalities provided by the
    Unstructured library, including version checking, parsing, cleaning,
    extracting, staging, chunking data, and integrating with cloud
    services like S3 and Azure for data connection.

    References:
        https://docs.unstructured.io/
    """

    @staticmethod
    def create_element_from_text(
        text: str,
        element_id: Optional[str] = None,
        embeddings: Optional[List[float]] = None,
        filename: Optional[str] = None,
        file_directory: Optional[str] = None,
        last_modified: Optional[str] = None,
        filetype: Optional[str] = None,
        parent_id: Optional[str] = None,
    ) -> "Element":
        r"""Creates a Text element from a given text input, with optional
        metadata and embeddings.

        Args:
            text (str): The text content for the element.
            element_id (Optional[str], optional): Unique identifier for the
                element. A random UUID is generated when omitted.
                (default: :obj:`None`)
            embeddings (List[float], optional): A list of float
                numbers representing the text embeddings.
                (default: :obj:`None`)
            filename (Optional[str], optional): The name of the file the
                element is associated with. (default: :obj:`None`)
            file_directory (Optional[str], optional): The directory path where
                the file is located. (default: :obj:`None`)
            last_modified (Optional[str], optional): The last modified date of
                the file. (default: :obj:`None`)
            filetype (Optional[str], optional): The type of the file.
                (default: :obj:`None`)
            parent_id (Optional[str], optional): The identifier of the parent
                element. (default: :obj:`None`)

        Returns:
            Element: An instance of Text with the provided content and
                metadata.
        """
        from unstructured.documents.elements import ElementMetadata, Text

        metadata = ElementMetadata(
            filename=filename,
            file_directory=file_directory,
            last_modified=last_modified,
            filetype=filetype,
            parent_id=parent_id,
        )

        return Text(
            text=text,
            element_id=element_id or str(uuid.uuid4()),
            metadata=metadata,
            embeddings=embeddings,
        )

    @staticmethod
    def parse_file_or_url(
        input_path: str,
        **kwargs: Any,
    ) -> Union[List["Element"], None]:
        r"""Loads a file or a URL and parses its contents into elements.

        Args:
            input_path (str): Path to the file or URL to be parsed.
            **kwargs: Extra kwargs passed to the partition function.

        Returns:
            Union[List[Element],None]: List of elements after parsing the file
                or URL if success, otherwise `None` (a warning is emitted
                on parse failure).

        Raises:
            FileNotFoundError: If the file does not exist at the path
                specified.

        Notes:
            Supported file types:
            "csv", "doc", "docx", "epub", "image", "md", "msg", "odt",
            "org", "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx".

        References:
            https://unstructured-io.github.io/unstructured/
        """
        import os
        from urllib.parse import urlparse

        from unstructured.partition.auto import partition

        # A string is treated as a URL only when it has both a scheme and a
        # network location; anything else is treated as a local file path.
        parsed_url = urlparse(input_path)
        is_url = all([parsed_url.scheme, parsed_url.netloc])

        # Handling URL
        if is_url:
            try:
                elements = partition(url=input_path, **kwargs)
                return elements
            except Exception:
                warnings.warn(f"Failed to parse the URL: {input_path}")
                return None

        # Handling file
        else:
            # Check if the file exists
            if not os.path.exists(input_path):
                raise FileNotFoundError(
                    f"The file {input_path} was not found."
                )

            # Read the file
            try:
                with open(input_path, "rb") as f:
                    elements = partition(file=f, **kwargs)
                    return elements
            except Exception:
                warnings.warn(f"Failed to partition the file: {input_path}")
                return None

    @staticmethod
    def parse_bytes(
        file: IO[bytes], **kwargs: Any
    ) -> Union[List["Element"], None]:
        r"""Parses a bytes stream and converts its contents into elements.

        Args:
            file (IO[bytes]): The file in bytes format to be parsed.
            **kwargs: Extra kwargs passed to the partition function.

        Returns:
            Union[List[Element], None]: List of elements after parsing the file
                if successful, otherwise `None`.

        Notes:
            Supported file types:
            "csv", "doc", "docx", "epub", "image", "md", "msg", "odt",
            "org", "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx".

        References:
            https://docs.unstructured.io/open-source/core-functionality/partitioning
        """

        from unstructured.partition.auto import partition

        try:
            # Use partition to process the bytes stream
            elements = partition(file=file, **kwargs)
            return elements
        except Exception as e:
            warnings.warn(f"Failed to partition the file stream: {e}")
            return None

    @staticmethod
    def clean_text_data(
        text: str,
        clean_options: Optional[List[Tuple[str, Dict[str, Any]]]] = None,
    ) -> str:
        r"""Cleans text data using a variety of cleaning functions provided by
        the `unstructured` library.

        This function applies multiple text cleaning utilities by calling the
        `unstructured` library's cleaning bricks for operations like
        replacing Unicode quotes, removing extra whitespace, dashes, non-ascii
        characters, and more.

        If no cleaning options are provided, a default set of cleaning
        operations is applied. These defaults include the operations
        "replace_unicode_quotes", "clean_non_ascii_chars",
        "group_broken_paragraphs", and "clean_extra_whitespace".

        Args:
            text (str): The text to be cleaned.
            clean_options (Optional[List[Tuple[str, Dict[str, Any]]]]): A
                list of (function name, parameters) pairs applied in order.
                Each name must match one of the supported cleaning
                functions, and the parameters dictionary is passed to that
                function as keyword arguments. Supported names:
                'clean_extra_whitespace', 'clean_bullets',
                'clean_ordered_bullets', 'clean_postfix', 'clean_prefix',
                'clean_dashes', 'clean_trailing_punctuation',
                'clean_non_ascii_chars', 'group_broken_paragraphs',
                'remove_punctuation', 'replace_unicode_quotes',
                'bytes_string_to_string', 'translate_text'.

        Returns:
            str: The cleaned text.

        Raises:
            ValueError: If a cleaning option does not correspond to a
                valid cleaning function in `unstructured`.

        Notes:
            The option names must correspond to valid cleaning
            brick names from the `unstructured` library.
            Each brick's parameters must be provided in a nested dictionary
            as the value for the key.

        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.cleaners.core import (
            bytes_string_to_string,
            clean_bullets,
            clean_dashes,
            clean_extra_whitespace,
            clean_non_ascii_chars,
            clean_ordered_bullets,
            clean_postfix,
            clean_prefix,
            clean_trailing_punctuation,
            group_broken_paragraphs,
            remove_punctuation,
            replace_unicode_quotes,
        )
        from unstructured.cleaners.translate import translate_text

        cleaning_functions: Any = {
            "clean_extra_whitespace": clean_extra_whitespace,
            "clean_bullets": clean_bullets,
            "clean_ordered_bullets": clean_ordered_bullets,
            "clean_postfix": clean_postfix,
            "clean_prefix": clean_prefix,
            "clean_dashes": clean_dashes,
            "clean_trailing_punctuation": clean_trailing_punctuation,
            "clean_non_ascii_chars": clean_non_ascii_chars,
            "group_broken_paragraphs": group_broken_paragraphs,
            "remove_punctuation": remove_punctuation,
            "replace_unicode_quotes": replace_unicode_quotes,
            "bytes_string_to_string": bytes_string_to_string,
            "translate_text": translate_text,
        }

        # Define default clean options if none are provided
        if clean_options is None:
            clean_options = [
                ("replace_unicode_quotes", {}),
                ("clean_non_ascii_chars", {}),
                ("group_broken_paragraphs", {}),
                ("clean_extra_whitespace", {}),
            ]

        cleaned_text = text
        for func_name, params in clean_options:
            if func_name in cleaning_functions:
                cleaned_text = cleaning_functions[func_name](
                    cleaned_text, **params
                )
            else:
                raise ValueError(
                    f"'{func_name}' is not a valid function in "
                    "`Unstructured IO`."
                )

        return cleaned_text

    @staticmethod
    def extract_data_from_text(
        text: str,
        extract_type: Literal[
            'extract_datetimetz',
            'extract_email_address',
            'extract_ip_address',
            'extract_ip_address_name',
            'extract_mapi_id',
            'extract_ordered_bullets',
            'extract_text_after',
            'extract_text_before',
            'extract_us_phone_number',
        ],
        **kwargs,
    ) -> Any:
        r"""Extracts various types of data from text using functions from
        unstructured.cleaners.extract.

        Args:
            text (str): Text to extract data from.
            extract_type (Literal['extract_datetimetz',
                'extract_email_address', 'extract_ip_address',
                'extract_ip_address_name', 'extract_mapi_id',
                'extract_ordered_bullets', 'extract_text_after',
                'extract_text_before', 'extract_us_phone_number']): Type of
                data to extract.
            **kwargs: Additional keyword arguments for specific
                extraction functions.

        Returns:
            Any: The extracted data, type depends on extract_type.

        Raises:
            ValueError: If ``extract_type`` is not a supported extractor.

        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.cleaners.extract import (
            extract_datetimetz,
            extract_email_address,
            extract_ip_address,
            extract_ip_address_name,
            extract_mapi_id,
            extract_ordered_bullets,
            extract_text_after,
            extract_text_before,
            extract_us_phone_number,
        )

        extraction_functions: Any = {
            "extract_datetimetz": extract_datetimetz,
            "extract_email_address": extract_email_address,
            "extract_ip_address": extract_ip_address,
            "extract_ip_address_name": extract_ip_address_name,
            "extract_mapi_id": extract_mapi_id,
            "extract_ordered_bullets": extract_ordered_bullets,
            "extract_text_after": extract_text_after,
            "extract_text_before": extract_text_before,
            "extract_us_phone_number": extract_us_phone_number,
        }

        if extract_type not in extraction_functions:
            raise ValueError(f"Unsupported extract_type: {extract_type}")

        return extraction_functions[extract_type](text, **kwargs)

    @staticmethod
    def stage_elements(
        elements: List[Any],
        stage_type: Literal[
            'convert_to_csv',
            'convert_to_dataframe',
            'convert_to_dict',
            'dict_to_elements',
            'stage_csv_for_prodigy',
            'stage_for_prodigy',
            'stage_for_baseplate',
            'stage_for_datasaur',
            'stage_for_label_box',
            'stage_for_label_studio',
            'stage_for_weaviate',
        ],
        **kwargs,
    ) -> Union[str, List[Dict], Any]:
        r"""Stages elements for various platforms based on the
        specified staging type.

        This function applies multiple staging utilities to format data
        for different NLP annotation and machine learning tools. It uses
        the 'unstructured.staging' module's functions for operations like
        converting to CSV, DataFrame, dictionary, or formatting for
        specific platforms like Prodigy, etc.

        Args:
            elements (List[Any]): List of Element objects to be staged.
            stage_type (Literal['convert_to_csv', 'convert_to_dataframe',
                'convert_to_dict', 'dict_to_elements',
                'stage_csv_for_prodigy', 'stage_for_prodigy',
                'stage_for_baseplate', 'stage_for_datasaur',
                'stage_for_label_box', 'stage_for_label_studio',
                'stage_for_weaviate']): Type of staging to perform.
            **kwargs: Additional keyword arguments specific to
                the staging type.

        Returns:
            Union[str, List[Dict], Any]: Staged data in the
                format appropriate for the specified staging type.

        Raises:
            ValueError: If the staging type is not supported or a required
                argument is missing.
        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.staging import (
            base,
            baseplate,
            datasaur,
            label_box,
            label_studio,
            prodigy,
            weaviate,
        )

        # The prodigy/datasaur entries pull their extra positional argument
        # out of kwargs (with an empty-list default) rather than forwarding
        # kwargs wholesale, because those APIs take it positionally.
        staging_functions: Any = {
            "convert_to_csv": base.convert_to_csv,
            "convert_to_dataframe": base.convert_to_dataframe,
            "convert_to_dict": base.convert_to_dict,
            "dict_to_elements": base.dict_to_elements,
            "stage_csv_for_prodigy": lambda els,
            **kw: prodigy.stage_csv_for_prodigy(els, kw.get('metadata', [])),
            "stage_for_prodigy": lambda els, **kw: prodigy.stage_for_prodigy(
                els, kw.get('metadata', [])
            ),
            "stage_for_baseplate": baseplate.stage_for_baseplate,
            "stage_for_datasaur": lambda els,
            **kw: datasaur.stage_for_datasaur(els, kw.get('entities', [])),
            "stage_for_label_box": lambda els,
            **kw: label_box.stage_for_label_box(els, **kw),
            "stage_for_label_studio": lambda els,
            **kw: label_studio.stage_for_label_studio(els, **kw),
            "stage_for_weaviate": weaviate.stage_for_weaviate,
        }

        if stage_type not in staging_functions:
            raise ValueError(f"Unsupported stage type: {stage_type}")

        return staging_functions[stage_type](elements, **kwargs)

    @staticmethod
    def chunk_elements(
        elements: List["Element"], chunk_type: str, **kwargs
    ) -> List["Element"]:
        r"""Chunks elements by titles.

        Args:
            elements (List[Element]): List of Element objects to be chunked.
            chunk_type (str): Type chunk going to apply. Supported types:
                'chunk_by_title'.
            **kwargs: Additional keyword arguments for chunking.

        Returns:
            List[Element]: List of chunked sections.

        Raises:
            ValueError: If ``chunk_type`` is not a supported chunker.

        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.chunking.title import chunk_by_title

        chunking_functions = {
            "chunk_by_title": chunk_by_title,
        }

        if chunk_type not in chunking_functions:
            raise ValueError(f"Unsupported chunk type: {chunk_type}")

        return chunking_functions[chunk_type](elements, **kwargs)
|
||||
@@ -1,112 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Private module-level logger. Using the 'camel' namespace lets applications
# configure or silence library logging independently of their own loggers.
_logger = logging.getLogger('camel')
|
||||
|
||||
|
||||
def _configure_library_logging():
    """Install a default logging configuration for the library.

    Does nothing when logging is disabled via the CAMEL_LOGGING_DISABLED
    environment variable, or when either the root logger or the library
    logger already has handlers attached (i.e. the application configured
    logging itself).
    """
    # Honor the opt-out switch first.
    if os.environ.get('CAMEL_LOGGING_DISABLED', 'False').lower() == 'true':
        return

    # Defer to any configuration the host application has already set up.
    if logging.root.handlers or _logger.handlers:
        _logger.debug("Existing logger configuration found, using that.")
        return

    logging.basicConfig(
        level=os.environ.get('LOGLEVEL', 'INFO').upper(),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        stream=sys.stdout,
    )
    logging.setLoggerClass(logging.Logger)
    _logger.info("Camel library logging has been configured.")
|
||||
|
||||
|
||||
def disable_logging():
    r"""Disable all logging for the Camel library.

    Raises the log level above CRITICAL so no messages are emitted, and
    attaches a NullHandler to suppress "no handlers could be found"
    warnings. Also records the disabled state in the
    CAMEL_LOGGING_DISABLED environment variable.
    """
    os.environ['CAMEL_LOGGING_DISABLED'] = 'true'
    _logger.setLevel(logging.CRITICAL + 1)
    # Attach a NullHandler only once, no matter how often this is called.
    already_has_null = any(
        isinstance(existing, logging.NullHandler)
        for existing in _logger.handlers
    )
    if not already_has_null:
        _logger.addHandler(logging.NullHandler())
    _logger.debug("Logging has been disabled.")
|
||||
|
||||
|
||||
def enable_logging():
    r"""Enable logging for the Camel library.

    Clears the disabled flag and applies the library's default logging
    configuration. An already-configured logging setup is left untouched.
    """
    os.environ['CAMEL_LOGGING_DISABLED'] = 'false'
    _configure_library_logging()
|
||||
|
||||
|
||||
def set_log_level(level):
    r"""Set the logging level for the Camel library.

    Args:
        level (Union[str, int]): The logging level to set. This can be a string
            (e.g., 'INFO') or a logging level constant (e.g., logging.INFO,
            logging.DEBUG).
            See https://docs.python.org/3/library/logging.html#levels

    Raises:
        ValueError: If the provided level is not a valid logging level.
    """
    valid_levels = ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
    if isinstance(level, str):
        normalized = level.upper()
        if normalized not in valid_levels:
            raise ValueError(
                f"Invalid logging level."
                f" Choose from: {', '.join(valid_levels)}"
            )
        level = normalized
    elif not isinstance(level, int):
        # Neither a recognized name nor a numeric level constant.
        raise ValueError(
            "Logging level must be an option from the logging module."
        )

    _logger.setLevel(level)
    _logger.debug(f"Logging level set to: {logging.getLevelName(level)}")
|
||||
|
||||
|
||||
def get_logger(name):
    r"""Get a logger namespaced under the Camel library.

    Args:
        name (str): The name to be appended to 'camel.' to create the logger.

    Returns:
        logging.Logger: A logger instance with the name 'camel.{name}'.
    """
    return logging.getLogger('camel.' + name)
|
||||
|
||||
|
||||
# Lazy configuration: apply the default logging setup at import time unless
# logging was explicitly disabled via the environment.
# NOTE(review): this check applies .strip() to the env value while
# _configure_library_logging() does not — presumably the two checks should
# match; confirm and unify.
if os.environ.get('CAMEL_LOGGING_DISABLED', 'False').strip().lower() != 'true':
    _configure_library_logging()
|
||||
@@ -1,38 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .agent_memories import (
|
||||
ChatHistoryMemory,
|
||||
LongtermAgentMemory,
|
||||
VectorDBMemory,
|
||||
)
|
||||
from .base import AgentMemory, BaseContextCreator, MemoryBlock
|
||||
from .blocks.chat_history_block import ChatHistoryBlock
|
||||
from .blocks.vectordb_block import VectorDBBlock
|
||||
from .context_creators.score_based import ScoreBasedContextCreator
|
||||
from .records import ContextRecord, MemoryRecord
|
||||
|
||||
# Public API of the camel.memories package.
# Fix: normalized quote style — one entry was double-quoted while the rest
# used single quotes.
__all__ = [
    'MemoryRecord',
    'ContextRecord',
    'MemoryBlock',
    'AgentMemory',
    'BaseContextCreator',
    'ScoreBasedContextCreator',
    'ChatHistoryMemory',
    'VectorDBMemory',
    'ChatHistoryBlock',
    'VectorDBBlock',
    'LongtermAgentMemory',
]
|
||||
@@ -1,176 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from camel.memories.base import AgentMemory, BaseContextCreator
|
||||
from camel.memories.blocks import ChatHistoryBlock, VectorDBBlock
|
||||
from camel.memories.records import ContextRecord, MemoryRecord
|
||||
from camel.storages import BaseKeyValueStorage, BaseVectorStorage
|
||||
from camel.types import OpenAIBackendRole
|
||||
|
||||
|
||||
class ChatHistoryMemory(AgentMemory):
    r"""An agent memory wrapper of :obj:`ChatHistoryBlock`.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        storage (BaseKeyValueStorage, optional): A storage backend for storing
            chat history. If `None`, an :obj:`InMemoryKeyValueStorage`
            will be used. (default: :obj:`None`)
        window_size (int, optional): The number of recent chat messages to
            retrieve. If not provided, the entire chat history will be
            retrieved. (default: :obj:`None`)
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        storage: Optional[BaseKeyValueStorage] = None,
        window_size: Optional[int] = None,
    ) -> None:
        # Validate the retrieval window before initializing any state.
        if window_size is not None:
            if not isinstance(window_size, int):
                raise TypeError("`window_size` must be an integer or None.")
            if window_size < 0:
                raise ValueError("`window_size` must be non-negative.")
        self._context_creator = context_creator
        self._window_size = window_size
        self._chat_history_block = ChatHistoryBlock(storage=storage)

    def retrieve(self) -> List[ContextRecord]:
        # Bounded by the configured window (None means "everything").
        return self._chat_history_block.retrieve(self._window_size)

    def write_records(self, records: List[MemoryRecord]) -> None:
        self._chat_history_block.write_records(records)

    def get_context_creator(self) -> BaseContextCreator:
        return self._context_creator

    def clear(self) -> None:
        self._chat_history_block.clear()
|
||||
|
||||
|
||||
class VectorDBMemory(AgentMemory):
    r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries
    messages stored in the vector database. Notice that the most recent
    messages will not be added to the context.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        storage (BaseVectorStorage, optional): A vector storage storage. If
            `None`, an :obj:`QdrantStorage` will be used.
            (default: :obj:`None`)
        retrieve_limit (int, optional): The maximum number of messages
            to be added into the context. (default: :obj:`3`)
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        storage: Optional[BaseVectorStorage] = None,
        retrieve_limit: int = 3,
    ) -> None:
        self._context_creator = context_creator
        self._retrieve_limit = retrieve_limit
        self._vectordb_block = VectorDBBlock(storage=storage)

        # The most recent user message; used as the retrieval query.
        self._current_topic: str = ""

    def retrieve(self) -> List[ContextRecord]:
        r"""Retrieves records semantically similar to the current topic.

        Returns:
            List[ContextRecord]: Up to `retrieve_limit` records from the
                vector database.
        """
        return self._vectordb_block.retrieve(
            self._current_topic,
            limit=self._retrieve_limit,
        )

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Writes records to the vector database and updates the current
        topic from the last user message.

        Args:
            records (List[MemoryRecord]): Records to be added.
        """
        # Assume the last user input is the current topic.
        for record in records:
            if record.role_at_backend == OpenAIBackendRole.USER:
                self._current_topic = record.message.content
        self._vectordb_block.write_records(records)

    def get_context_creator(self) -> BaseContextCreator:
        return self._context_creator

    def clear(self) -> None:
        r"""Removes all records from the vector database and resets the
        current topic.

        Added for consistency: the sibling `AgentMemory` implementations
        (`ChatHistoryMemory`, `LongtermAgentMemory`) both implement `clear`,
        and `MemoryBlock.clear` is abstract; this class previously omitted
        it.
        """
        self._current_topic = ""
        self._vectordb_block.clear()
|
||||
|
||||
|
||||
class LongtermAgentMemory(AgentMemory):
    r"""An implementation of the :obj:`AgentMemory` abstract base class for
    augmenting ChatHistoryMemory with VectorDBMemory.

    Args:
        context_creator (BaseContextCreator): A model context creator.
        chat_history_block (Optional[ChatHistoryBlock], optional): A chat
            history block. If `None`, a :obj:`ChatHistoryBlock` will be used.
            (default: :obj:`None`)
        vector_db_block (Optional[VectorDBBlock], optional): A vector database
            block. If `None`, a :obj:`VectorDBBlock` will be used.
            (default: :obj:`None`)
        retrieve_limit (int, optional): The maximum number of messages
            to be added into the context. (default: :obj:`3`)
    """

    def __init__(
        self,
        context_creator: BaseContextCreator,
        chat_history_block: Optional[ChatHistoryBlock] = None,
        vector_db_block: Optional[VectorDBBlock] = None,
        retrieve_limit: int = 3,
    ) -> None:
        self.chat_history_block = chat_history_block or ChatHistoryBlock()
        self.vector_db_block = vector_db_block or VectorDBBlock()
        self.retrieve_limit = retrieve_limit
        self._context_creator = context_creator
        # Most recent user message; used as the vector-DB retrieval query.
        self._current_topic: str = ""

    def get_context_creator(self) -> BaseContextCreator:
        r"""Returns the context creator used by the memory.

        Returns:
            BaseContextCreator: The context creator used by the memory.
        """
        return self._context_creator

    def retrieve(self) -> List[ContextRecord]:
        r"""Retrieves context records from both the chat history and the vector
        database.

        Returns:
            List[ContextRecord]: A list of context records retrieved from both
                the chat history and the vector database.
        """
        chat_history = self.chat_history_block.retrieve()
        vector_db_retrieve = self.vector_db_block.retrieve(
            self._current_topic, self.retrieve_limit
        )
        # Splice the vector-DB hits in right after the first chat record —
        # presumably so that the leading record (likely the system message;
        # TODO confirm) stays first in the context.
        return chat_history[:1] + vector_db_retrieve + chat_history[1:]

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Converts the provided chat messages into vector representations and
        writes them to the vector database.

        Args:
            records (List[MemoryRecord]): Messages to be added to the vector
                database.
        """
        self.vector_db_block.write_records(records)
        self.chat_history_block.write_records(records)

        # Track the last user message as the current retrieval topic.
        # NOTE(review): unlike VectorDBMemory, the topic here is updated
        # AFTER writing, so the just-written records were retrieved-for
        # under the previous topic — confirm this ordering is intentional.
        for record in records:
            if record.role_at_backend == OpenAIBackendRole.USER:
                self._current_topic = record.message.content

    def clear(self) -> None:
        r"""Removes all records from the memory."""
        self.chat_history_block.clear()
        self.vector_db_block.clear()
|
||||
@@ -1,140 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
from camel.memories.records import ContextRecord, MemoryRecord
|
||||
from camel.messages import OpenAIMessage
|
||||
from camel.utils import BaseTokenCounter
|
||||
|
||||
|
||||
class MemoryBlock(ABC):
    r"""The fundamental storage component of the agent memory system.

    A memory block exposes "write" and "clear" operations only. It
    deliberately defines no retrieval interface, since the shape of the
    data to retrieve differs between concrete block types.
    """

    @abstractmethod
    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Writes records to the memory, appending them to existing ones.

        Args:
            records (List[MemoryRecord]): Records to be added to the memory.
        """

    def write_record(self, record: MemoryRecord) -> None:
        r"""Writes a single record to the memory.

        Convenience wrapper that delegates to :meth:`write_records`.

        Args:
            record (MemoryRecord): Record to be added to the memory.
        """
        self.write_records([record])

    @abstractmethod
    def clear(self) -> None:
        r"""Clears all messages from the memory."""
|
||||
|
||||
|
||||
class BaseContextCreator(ABC):
    r"""Interface for strategies that turn context records into a
    conversational context.

    A concrete strategy builds a message list from a list of context
    records while keeping the result within a specified token budget.
    Subclasses supply :obj:`token_counter`, :obj:`token_limit`, and
    :obj:`create_context` to define their particular approach.

    Attributes:
        token_counter (BaseTokenCounter): A token counter instance
            responsible for counting tokens in a message.
        token_limit (int): The maximum number of tokens allowed in the
            generated context.
    """

    @property
    @abstractmethod
    def token_counter(self) -> BaseTokenCounter:
        r"""The token counter used by this strategy."""

    @property
    @abstractmethod
    def token_limit(self) -> int:
        r"""The token budget for the generated context."""

    @abstractmethod
    def create_context(
        self,
        records: List[ContextRecord],
    ) -> Tuple[List[OpenAIMessage], int]:
        r"""Builds the conversational context from the given records.

        How records are selected and how the token count is enforced is up
        to the implementation. The order of the output messages must match
        the order of the input records.

        Args:
            records (List[ContextRecord]): A list of context records from
                which to generate the context.

        Returns:
            Tuple[List[OpenAIMessage], int]: A tuple containing the
                constructed context in OpenAIMessage format and the total
                token count.
        """
|
||||
|
||||
|
||||
class AgentMemory(MemoryBlock, ABC):
    r"""A specialization of :obj:`MemoryBlock` designed to be plugged
    directly into an agent. Two abstract hooks, :obj:`retrieve` and
    :obj:`get_context_creator`, drive model-context generation from the
    records held in this memory.
    """

    @abstractmethod
    def retrieve(self) -> List[ContextRecord]:
        r"""Fetches the records used to build the model context.

        Returns:
            List[ContextRecord]: A record list for creating model context.
        """

    @abstractmethod
    def get_context_creator(self) -> BaseContextCreator:
        r"""Returns the context creator attached to this memory.

        Returns:
            BaseContextCreator: A model context creator.
        """

    def get_context(self) -> Tuple[List[OpenAIMessage], int]:
        r"""Builds a properly sized chat context for the agent from the
        memory.

        Returns:
            (List[OpenAIMessage], int): A tuple containing the constructed
                context in OpenAIMessage format and the total token count.
        """
        # Delegate retrieval to the subclass, then hand the records to the
        # configured context creator.
        creator = self.get_context_creator()
        return creator.create_context(self.retrieve())
|
||||
@@ -1,21 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .chat_history_block import ChatHistoryBlock
|
||||
from .vectordb_block import VectorDBBlock
|
||||
|
||||
__all__ = [
|
||||
'ChatHistoryBlock',
|
||||
'VectorDBBlock',
|
||||
]
|
||||
@@ -1,115 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from camel.memories.base import MemoryBlock
|
||||
from camel.memories.records import ContextRecord, MemoryRecord
|
||||
from camel.storages import BaseKeyValueStorage, InMemoryKeyValueStorage
|
||||
from camel.types import OpenAIBackendRole
|
||||
|
||||
|
||||
class ChatHistoryBlock(MemoryBlock):
    r"""An implementation of the :obj:`MemoryBlock` abstract base class for
    maintaining a record of chat histories.

    This memory block helps manage conversation histories with a key-value
    storage backend, either provided by the user or using a default
    in-memory storage. It offers a windowed approach to retrieving chat
    histories, allowing users to specify how many recent messages they'd
    like to fetch.

    Args:
        storage (BaseKeyValueStorage, optional): A storage mechanism for
            storing chat history. If `None`, an :obj:`InMemoryKeyValueStorage`
            will be used. (default: :obj:`None`)
        keep_rate (float, optional): In historical messages, the score of the
            last message is 1.0, and with each step taken backward, the score
            of the message is multiplied by the `keep_rate`. Higher
            `keep_rate` leads to high possibility to keep history messages
            during context creation.

    Raises:
        ValueError: If `keep_rate` is outside the range [0, 1].
    """

    def __init__(
        self,
        storage: Optional[BaseKeyValueStorage] = None,
        keep_rate: float = 0.9,
    ) -> None:
        if keep_rate > 1 or keep_rate < 0:
            raise ValueError("`keep_rate` should be in [0,1]")
        self.storage = storage or InMemoryKeyValueStorage()
        self.keep_rate = keep_rate

    def retrieve(
        self,
        window_size: Optional[int] = None,
    ) -> List[ContextRecord]:
        r"""Retrieves records with a proper size for the agent from the
        memory based on the window size or fetches the entire chat history
        if no window size is specified.

        Args:
            window_size (int, optional): Specifies the number of recent chat
                messages to retrieve. If not provided, the entire chat
                history will be retrieved. (default: :obj:`None`)

        Returns:
            List[ContextRecord]: A list of retrieved records.

        Raises:
            ValueError: If `window_size` is negative.
        """
        if window_size is not None and window_size < 0:
            raise ValueError("`window_size` should be non-negative.")

        record_dicts = self.storage.load()
        if len(record_dicts) == 0:
            warnings.warn("The `ChatHistoryMemory` is empty.")
            return list()

        # BUGFIX: the previous slicing (`record_dicts[-window_size:]` with 0
        # mapped to `[0:]`) returned the ENTIRE history for
        # `window_size == 0` instead of an empty window.
        if window_size is None:
            recent_dicts = record_dicts
        elif window_size == 0:
            return list()
        else:
            recent_dicts = record_dicts[-window_size:]

        chat_records = [
            MemoryRecord.from_dict(record_dict)
            for record_dict in recent_dicts
        ]

        # We assume that, in the chat history memory, the closer the record
        # is to the current message, the more score it will be.
        output_records = []
        score = 1.0
        for record in reversed(chat_records):
            if record.role_at_backend == OpenAIBackendRole.SYSTEM:
                # System messages are always kept.
                output_records.append(
                    ContextRecord(memory_record=record, score=1.0)
                )
            else:
                # Other messages' score drops down gradually
                score *= self.keep_rate
                output_records.append(
                    ContextRecord(memory_record=record, score=score)
                )

        output_records.reverse()
        return output_records

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Writes memory records to the memory. Additionally, performs
        validation checks on the messages.

        Args:
            records (List[MemoryRecord]): Memory records to be added to the
                memory.
        """
        self.storage.save([record.to_dict() for record in records])

    def clear(self) -> None:
        r"""Clears all chat messages from the memory."""
        self.storage.clear()
|
||||
@@ -1,103 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from camel.embeddings import BaseEmbedding, OpenAIEmbedding
|
||||
from camel.memories.base import MemoryBlock
|
||||
from camel.memories.records import ContextRecord, MemoryRecord
|
||||
from camel.storages.vectordb_storages import (
|
||||
BaseVectorStorage,
|
||||
QdrantStorage,
|
||||
VectorDBQuery,
|
||||
VectorRecord,
|
||||
)
|
||||
|
||||
|
||||
class VectorDBBlock(MemoryBlock):
    r"""An implementation of the :obj:`MemoryBlock` abstract base class for
    maintaining and retrieving information using vector embeddings within a
    vector database.

    Args:
        storage (Optional[BaseVectorStorage], optional): The storage
            mechanism for the vector database. Defaults to in-memory
            :obj:`Qdrant` if not provided. (default: :obj:`None`)
        embedding (Optional[BaseEmbedding], optional): Embedding mechanism to
            convert chat messages into vector representations. Defaults to
            :obj:`OpenAiEmbedding` if not provided. (default: :obj:`None`)
    """

    def __init__(
        self,
        storage: Optional[BaseVectorStorage] = None,
        embedding: Optional[BaseEmbedding] = None,
    ) -> None:
        # The embedding fixes the vector dimensionality, which the default
        # storage needs at construction time — keep this order.
        self.embedding = embedding or OpenAIEmbedding()
        self.vector_dim = self.embedding.get_output_dim()
        self.storage = storage or QdrantStorage(vector_dim=self.vector_dim)

    def retrieve(
        self,
        keyword: str,
        limit: int = 3,
    ) -> List[ContextRecord]:
        r"""Retrieves similar records from the vector database based on the
        content of the keyword.

        Args:
            keyword (str): This string will be converted into a vector
                representation to query the database.
            limit (int, optional): The maximum number of similar messages to
                retrieve. (default: :obj:`3`).

        Returns:
            List[ContextRecord]: Memory records retrieved from the vector
                database, ranked by similarity to ``keyword``.
        """
        query_vector = self.embedding.embed(keyword)
        query = VectorDBQuery(query_vector=query_vector, top_k=limit)
        retrieved: List[ContextRecord] = []
        for result in self.storage.query(query):
            # Results lacking a payload cannot be turned back into records.
            if result.record.payload is None:
                continue
            retrieved.append(
                ContextRecord(
                    memory_record=MemoryRecord.from_dict(
                        result.record.payload
                    ),
                    score=result.similarity,
                )
            )
        return retrieved

    def write_records(self, records: List[MemoryRecord]) -> None:
        r"""Converts the provided chat messages into vector representations
        and writes them to the vector database.

        Args:
            records (List[MemoryRecord]): Memory records to be added to the
                memory.
        """
        vector_records = []
        for record in records:
            vector_records.append(
                VectorRecord(
                    vector=self.embedding.embed(record.message.content),
                    payload=record.to_dict(),
                    id=str(record.uuid),
                )
            )
        self.storage.add(vector_records)

    def clear(self) -> None:
        r"""Removes all records from the vector database memory."""
        self.storage.clear()
|
||||
@@ -1,19 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .score_based import ScoreBasedContextCreator
|
||||
|
||||
__all__ = [
|
||||
'ScoreBasedContextCreator',
|
||||
]
|
||||
@@ -1,142 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import List, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.memories.base import BaseContextCreator
|
||||
from camel.memories.records import ContextRecord
|
||||
from camel.messages import OpenAIMessage
|
||||
from camel.utils import BaseTokenCounter
|
||||
|
||||
|
||||
class _ContextUnit(BaseModel):
    r"""Internal bookkeeping entry used by :obj:`ScoreBasedContextCreator`,
    pairing a context record with its input position and token cost.
    """

    # Position of the record in the original input list; used to restore
    # input order after score-based pruning.
    idx: int
    # The wrapped context record (message plus relevance score).
    record: ContextRecord
    # Token count of the record's message, as measured by the token counter.
    num_tokens: int
|
||||
|
||||
|
||||
class ScoreBasedContextCreator(BaseContextCreator):
    r"""A default implementation of context creation strategy, which
    inherits from :obj:`BaseContextCreator`.

    The strategy assembles a conversational context from chat history
    records while keeping the total token count under a fixed limit. When
    the records do not fit, the lowest-scoring ones are dropped first.

    Args:
        token_counter (BaseTokenCounter): An instance responsible for
            counting tokens in a message.
        token_limit (int): The maximum number of tokens allowed in the
            generated context.
    """

    def __init__(
        self, token_counter: BaseTokenCounter, token_limit: int
    ) -> None:
        self._token_counter = token_counter
        self._token_limit = token_limit

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""The token counter used by this creator."""
        return self._token_counter

    @property
    def token_limit(self) -> int:
        r"""The token budget for the generated context."""
        return self._token_limit

    def create_context(
        self,
        records: List[ContextRecord],
    ) -> Tuple[List[OpenAIMessage], int]:
        r"""Creates conversational context from chat history while
        respecting token limits.

        Duplicated records (same UUID) are ignored; the first occurrence
        wins. If the total token count exceeds the limit, records are
        removed in ascending score order until the remainder fits.

        Args:
            records (List[ContextRecord]): A list of message records from
                which to generate the context.

        Returns:
            Tuple[List[OpenAIMessage], int]: A tuple containing the
                constructed context in OpenAIMessage format and the total
                token count.

        Raises:
            RuntimeError: If it's impossible to create a valid context
                without exceeding the token limit.
        """
        # Deduplicate by record UUID while preserving input order.
        seen_uuids = set()
        units: List[_ContextUnit] = []
        for position, record in enumerate(records):
            record_uuid = record.memory_record.uuid
            if record_uuid in seen_uuids:
                continue
            seen_uuids.add(record_uuid)
            token_count = self.token_counter.count_tokens_from_messages(
                [record.memory_record.to_openai_message()]
            )
            units.append(
                _ContextUnit(
                    idx=position,
                    record=record,
                    num_tokens=token_count,
                )
            )

        # TODO: optimize the process, may give information back to memory

        # Fast path: everything already fits within the budget.
        remaining_tokens = sum(unit.num_tokens for unit in units)
        if remaining_tokens <= self.token_limit:
            return self._create_output(units)

        # Drop the least important units first. The sort is stable, so
        # input order is preserved among equal scores.
        units_by_score = sorted(units, key=lambda unit: unit.record.score)

        cut_index = None
        for position, unit in enumerate(units_by_score):
            if unit.record.score == 1:
                # Score-1 units must never be dropped; reaching one means
                # the context cannot be made to fit.
                raise RuntimeError(
                    "Cannot create context: exceed token limit.",
                    remaining_tokens,
                )
            remaining_tokens -= unit.num_tokens
            if remaining_tokens <= self.token_limit:
                cut_index = position
                break
        if cut_index is None:
            raise RuntimeError(
                "Cannot create context: exceed token limit.",
                remaining_tokens,
            )
        return self._create_output(units_by_score[cut_index + 1 :])

    def _create_output(
        self, context_units: List[_ContextUnit]
    ) -> Tuple[List[OpenAIMessage], int]:
        r"""Helper method to generate output from context units.

        Restores the original input order, then returns the OpenAI messages
        together with their combined token count.
        """
        ordered_units = sorted(context_units, key=lambda unit: unit.idx)
        messages = [
            unit.record.memory_record.to_openai_message()
            for unit in ordered_units
        ]
        total_tokens = sum(unit.num_tokens for unit in ordered_units)
        return messages, total_tokens
|
||||
@@ -1,95 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import Any, ClassVar, Dict
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
|
||||
from camel.types import OpenAIBackendRole
|
||||
|
||||
|
||||
class MemoryRecord(BaseModel):
    r"""The basic message storing unit in the CAMEL memory system.

    Attributes:
        message (BaseMessage): The main content of the record.
        role_at_backend (OpenAIBackendRole): An enumeration value
            representing the role this message played at the OpenAI backend.
            Note that this value is different from the :obj:`RoleType` used
            in the CAMEL role playing system.
        uuid (UUID, optional): A universally unique identifier for this
            record, used to uniquely identify it in the memory system. If
            not given, it will be assigned with a random UUID.
        extra_info (Dict[str, str], optional): A dictionary of additional
            key-value pairs that provide more information. If not given, it
            will be an empty `Dict`.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    message: BaseMessage
    role_at_backend: OpenAIBackendRole
    uuid: UUID = Field(default_factory=uuid4)
    extra_info: Dict[str, str] = Field(default_factory=dict)

    # Registry used by `from_dict` to map a serialized "__class__" name back
    # to the concrete message type.
    _MESSAGE_TYPES: ClassVar[dict] = {
        "BaseMessage": BaseMessage,
        "FunctionCallingMessage": FunctionCallingMessage,
    }

    @classmethod
    def from_dict(cls, record_dict: Dict[str, Any]) -> "MemoryRecord":
        r"""Reconstruct a :obj:`MemoryRecord` from the input dict.

        Args:
            record_dict(Dict[str, Any]): A dict generated by :meth:`to_dict`.
        """
        message_data = record_dict["message"].copy()
        class_name = message_data.pop("__class__")
        reconstructed_message = cls._MESSAGE_TYPES[class_name](**message_data)
        return cls(
            uuid=UUID(record_dict["uuid"]),
            message=reconstructed_message,
            role_at_backend=record_dict["role_at_backend"],
            extra_info=record_dict["extra_info"],
        )

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the :obj:`MemoryRecord` to a dict for serialization
        purposes.
        """
        # The message's concrete class name is stored alongside its fields
        # so `from_dict` can rebuild the right type.
        message_dict: Dict[str, Any] = {
            "__class__": self.message.__class__.__name__
        }
        message_dict.update(asdict(self.message))
        return {
            "uuid": str(self.uuid),
            "message": message_dict,
            "role_at_backend": self.role_at_backend,
            "extra_info": self.extra_info,
        }

    def to_openai_message(self) -> OpenAIMessage:
        r"""Converts the record to an :obj:`OpenAIMessage` object."""
        return self.message.to_openai_message(self.role_at_backend)
|
||||
|
||||
|
||||
class ContextRecord(BaseModel):
    r"""The result of memory retrieving."""

    # The retrieved memory record.
    memory_record: MemoryRecord
    # Relevance score assigned by the retriever; higher-scoring records are
    # kept longer during score-based context pruning.
    score: float
|
||||
@@ -1,63 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Union
|
||||
|
||||
from camel.types import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from .conversion import (
|
||||
AlpacaItem,
|
||||
HermesFunctionFormatter,
|
||||
ShareGPTMessage,
|
||||
)
|
||||
from .conversion.conversation_models import (
|
||||
ShareGPTConversation,
|
||||
)
|
||||
from .conversion.sharegpt.function_call_formatter import (
|
||||
FunctionCallFormatter,
|
||||
)
|
||||
|
||||
OpenAISystemMessage = ChatCompletionSystemMessageParam
|
||||
OpenAIAssistantMessage = Union[
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
]
|
||||
OpenAIUserMessage = ChatCompletionUserMessageParam
|
||||
OpenAIToolMessageParam = ChatCompletionToolMessageParam
|
||||
|
||||
OpenAIMessage = ChatCompletionMessageParam
|
||||
|
||||
|
||||
from .base import BaseMessage # noqa: E402
|
||||
from .func_message import FunctionCallingMessage # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
'OpenAISystemMessage',
|
||||
'OpenAIAssistantMessage',
|
||||
'OpenAIUserMessage',
|
||||
'OpenAIToolMessageParam',
|
||||
'OpenAIMessage',
|
||||
'FunctionCallFormatter',
|
||||
'HermesFunctionFormatter',
|
||||
'ShareGPTConversation',
|
||||
'ShareGPTMessage',
|
||||
'BaseMessage',
|
||||
'FunctionCallingMessage',
|
||||
'AlpacaItem',
|
||||
]
|
||||
@@ -1,541 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.messages import (
|
||||
FunctionCallFormatter,
|
||||
HermesFunctionFormatter,
|
||||
OpenAIAssistantMessage,
|
||||
OpenAIMessage,
|
||||
OpenAISystemMessage,
|
||||
OpenAIUserMessage,
|
||||
)
|
||||
from camel.messages.conversion import ShareGPTMessage
|
||||
from camel.prompts import CodePrompt, TextPrompt
|
||||
from camel.types import (
|
||||
OpenAIBackendRole,
|
||||
OpenAIImageType,
|
||||
OpenAIVisionDetailType,
|
||||
RoleType,
|
||||
)
|
||||
from camel.utils import Constants
|
||||
|
||||
|
||||
@dataclass
class BaseMessage:
    r"""Base class for message objects used in CAMEL chat system.

    Args:
        role_name (str): The name of the user or assistant role.
        role_type (RoleType): The type of role, either :obj:`RoleType.
            ASSISTANT` or :obj:`RoleType.USER`.
        meta_dict (Optional[Dict[str, Any]]): Additional metadata dictionary
            for the message.
        content (str): The content of the message.
        video_bytes (Optional[bytes]): Optional bytes of a video associated
            with the message. (default: :obj:`None`)
        image_list (Optional[List[Image.Image]]): Optional list of PIL Image
            objects associated with the message. (default: :obj:`None`)
        image_detail (Literal["auto", "low", "high"]): Detail level of the
            images associated with the message. (default: :obj:`auto`)
        video_detail (Literal["auto", "low", "high"]): Detail level of the
            videos associated with the message. (default: :obj:`low`)
        parsed (Optional[Union[Type[BaseModel], dict]]): Optional object
            which is parsed from the content. (default: :obj:`None`)
    """

    # Required fields: who sent the message and its textual payload.
    role_name: str
    role_type: RoleType
    meta_dict: Optional[Dict[str, Any]]
    content: str

    # Optional media attachments and their rendering detail levels.
    video_bytes: Optional[bytes] = None
    image_list: Optional[List[Image.Image]] = None
    image_detail: Literal["auto", "low", "high"] = "auto"
    video_detail: Literal["auto", "low", "high"] = "low"
    # Structured object extracted from `content`, if any.
    parsed: Optional[Union[Type[BaseModel], dict]] = None
|
||||
|
||||
@classmethod
def make_user_message(
    cls,
    role_name: str,
    content: str,
    meta_dict: Optional[Dict[str, str]] = None,
    video_bytes: Optional[bytes] = None,
    image_list: Optional[List[Image.Image]] = None,
    image_detail: Union[
        OpenAIVisionDetailType, str
    ] = OpenAIVisionDetailType.AUTO,
    video_detail: Union[
        OpenAIVisionDetailType, str
    ] = OpenAIVisionDetailType.LOW,
) -> "BaseMessage":
    r"""Create a new user message.

    Args:
        role_name (str): The name of the user role.
        content (str): The content of the message.
        meta_dict (Optional[Dict[str, str]]): Additional metadata
            dictionary for the message.
        video_bytes (Optional[bytes]): Optional bytes of a video
            associated with the message.
        image_list (Optional[List[Image.Image]]): Optional list of PIL
            Image objects associated with the message.
        image_detail (Union[OpenAIVisionDetailType, str]): Detail level of
            the images associated with the message.
        video_detail (Union[OpenAIVisionDetailType, str]): Detail level of
            the videos associated with the message.

    Returns:
        BaseMessage: The new user message.
    """
    # Detail levels may arrive as enum members or raw strings; normalize
    # both through the enum before storing the plain string value.
    return cls(
        role_name=role_name,
        role_type=RoleType.USER,
        meta_dict=meta_dict,
        content=content,
        video_bytes=video_bytes,
        image_list=image_list,
        image_detail=OpenAIVisionDetailType(image_detail).value,
        video_detail=OpenAIVisionDetailType(video_detail).value,
    )
|
||||
|
||||
@classmethod
def make_assistant_message(
    cls,
    role_name: str,
    content: str,
    meta_dict: Optional[Dict[str, str]] = None,
    video_bytes: Optional[bytes] = None,
    image_list: Optional[List[Image.Image]] = None,
    image_detail: Union[
        OpenAIVisionDetailType, str
    ] = OpenAIVisionDetailType.AUTO,
    video_detail: Union[
        OpenAIVisionDetailType, str
    ] = OpenAIVisionDetailType.LOW,
) -> "BaseMessage":
    r"""Create a new assistant message.

    Args:
        role_name (str): The name of the assistant role.
        content (str): The content of the message.
        meta_dict (Optional[Dict[str, str]]): Additional metadata
            dictionary for the message.
        video_bytes (Optional[bytes]): Optional bytes of a video
            associated with the message.
        image_list (Optional[List[Image.Image]]): Optional list of PIL
            Image objects associated with the message.
        image_detail (Union[OpenAIVisionDetailType, str]): Detail level of
            the images associated with the message.
        video_detail (Union[OpenAIVisionDetailType, str]): Detail level of
            the videos associated with the message.

    Returns:
        BaseMessage: The new assistant message.
    """
    # Detail levels may arrive as enum members or raw strings; normalize
    # both through the enum before storing the plain string value.
    return cls(
        role_name=role_name,
        role_type=RoleType.ASSISTANT,
        meta_dict=meta_dict,
        content=content,
        video_bytes=video_bytes,
        image_list=image_list,
        image_detail=OpenAIVisionDetailType(image_detail).value,
        video_detail=OpenAIVisionDetailType(video_detail).value,
    )
|
||||
|
||||
def create_new_instance(self, content: str) -> "BaseMessage":
    r"""Clone this message, replacing only the content.

    Args:
        content (str): The new content value.

    Returns:
        BaseMessage: A fresh instance of the same concrete class with
            the role fields and metadata copied over.
    """
    message_cls = self.__class__
    return message_cls(
        role_name=self.role_name,
        role_type=self.role_type,
        meta_dict=self.meta_dict,
        content=content,
    )
|
||||
|
||||
def __add__(self, other: Any) -> Union["BaseMessage", Any]:
    r"""Addition operator override for :obj:`BaseMessage`.

    Concatenates this message's content with another message's content
    (or a plain string) and wraps the result in a new message instance.

    Args:
        other (Any): The value to be added with.

    Returns:
        Union[BaseMessage, Any]: The result of the addition.

    Raises:
        TypeError: If ``other`` is neither a :obj:`BaseMessage` nor a
            string.
    """
    if isinstance(other, BaseMessage):
        appended = other.content
    elif isinstance(other, str):
        appended = other
    else:
        raise TypeError(
            f"Unsupported operand type(s) for +: '{type(self)}' and "
            f"'{type(other)}'"
        )
    return self.create_new_instance(self.content + appended)
|
||||
|
||||
def __mul__(self, other: Any) -> Union["BaseMessage", Any]:
    r"""Multiplication operator override for :obj:`BaseMessage`.

    Repeats the message content ``other`` times, like ``str * int``.

    Args:
        other (Any): The value to be multiplied with.

    Returns:
        Union[BaseMessage, Any]: The result of the multiplication.

    Raises:
        TypeError: If ``other`` is not an integer.
    """
    if not isinstance(other, int):
        raise TypeError(
            f"Unsupported operand type(s) for *: '{type(self)}' and "
            f"'{type(other)}'"
        )
    return self.create_new_instance(self.content * other)
|
||||
|
||||
def __len__(self) -> int:
    r"""Length operator override for :obj:`BaseMessage`.

    Returns:
        int: Number of characters in the message content.
    """
    return len(self.content)
|
||||
|
||||
def __contains__(self, item: str) -> bool:
    r"""Contains operator override for :obj:`BaseMessage`.

    Enables ``item in message`` checks against the message content.

    Args:
        item (str): The item to check for containment.

    Returns:
        bool: :obj:`True` if the item is contained in the content,
            :obj:`False` otherwise.
    """
    return self.content.__contains__(item)
|
||||
|
||||
def extract_text_and_code_prompts(
    self,
) -> Tuple[List[TextPrompt], List[CodePrompt]]:
    r"""Extract text and code prompts from the message content.

    The content is scanned line by line: runs of plain text become
    :obj:`TextPrompt` entries, and fenced blocks (``` ... ```) become
    :obj:`CodePrompt` entries, with the text after the opening fence
    used as the code type.

    Returns:
        Tuple[List[TextPrompt], List[CodePrompt]]: A tuple containing a
            list of text prompts and a list of code prompts extracted
            from the content.
    """
    text_prompts: List[TextPrompt] = []
    code_prompts: List[CodePrompt] = []

    lines = self.content.split("\n")
    idx = 0
    start_idx = 0
    while idx < len(lines):
        # Consume plain text until an opening fence (or end of input).
        while idx < len(lines) and (
            not lines[idx].lstrip().startswith("```")
        ):
            idx += 1
        text = "\n".join(lines[start_idx:idx]).strip()
        text_prompts.append(TextPrompt(text))

        if idx >= len(lines):
            break

        # The remainder of the fence line names the language/code type.
        code_type = lines[idx].strip()[3:].strip()
        idx += 1
        start_idx = idx
        # Bug fix: guard against an unterminated code fence, which
        # previously raised IndexError; end of content now acts as an
        # implicit closing fence.
        while idx < len(lines) and (
            not lines[idx].lstrip().startswith("```")
        ):
            idx += 1
        code = "\n".join(lines[start_idx:idx]).strip()
        code_prompts.append(CodePrompt(code, code_type=code_type))

        idx += 1
        start_idx = idx

    return text_prompts, code_prompts
|
||||
|
||||
@classmethod
def from_sharegpt(
    cls,
    message: ShareGPTMessage,
    function_format: Optional[FunctionCallFormatter[Any, Any]] = None,
    role_mapping=None,
) -> "BaseMessage":
    r"""Convert ShareGPT message to BaseMessage or FunctionCallingMessage.
    Note tool calls and responses have an 'assistant' role in CAMEL

    Args:
        message (ShareGPTMessage): ShareGPT message to convert.
        function_format (FunctionCallFormatter, optional): Function call
            formatter to use. (default: :obj:`HermesFunctionFormatter()`.
        role_mapping (Dict[str, List[str, RoleType]], optional): Role
            mapping to use. Defaults to a CAMEL specific mapping.

    Returns:
        BaseMessage: Converted message. A :obj:`FunctionCallingMessage`
            is returned when the message carries exactly one tool call
            or a tool response; otherwise a plain :obj:`BaseMessage`.
    """
    # Imported here to avoid a circular import at module load time.
    from camel.messages import FunctionCallingMessage

    if role_mapping is None:
        # Default mapping: ShareGPT 'from' -> (CAMEL role name, RoleType).
        # Tool calls/responses are attributed to the assistant in CAMEL.
        role_mapping = {
            "system": ["system", RoleType.USER],
            "human": ["user", RoleType.USER],
            "gpt": ["assistant", RoleType.ASSISTANT],
            "tool": ["assistant", RoleType.ASSISTANT],
        }
    role_name, role_type = role_mapping[message.from_]

    if function_format is None:
        function_format = HermesFunctionFormatter()

    # Check if this is a function-related message
    if message.from_ == "gpt":
        func_info = function_format.extract_tool_calls(message.value)
        if (
            func_info and len(func_info) == 1
        ):  # TODO: Handle multiple tool calls
            # Including cleaned content is useful to
            # remind consumers of non-considered content
            clean_content = re.sub(
                r"<tool_call>.*?</tool_call>",
                "",
                message.value,
                flags=re.DOTALL,
            ).strip()

            return FunctionCallingMessage(
                role_name=role_name,
                role_type=role_type,
                meta_dict=None,
                content=clean_content,
                func_name=func_info[0].__dict__["name"],
                args=func_info[0].__dict__["arguments"],
            )
    elif message.from_ == "tool":
        # A tool message encodes a function *response*; content is left
        # empty and the payload lives in func_name/result.
        func_r_info = function_format.extract_tool_response(message.value)
        if func_r_info:
            return FunctionCallingMessage(
                role_name=role_name,
                role_type=role_type,
                meta_dict=None,
                content="",
                func_name=func_r_info.__dict__["name"],
                result=func_r_info.__dict__["content"],
            )

    # Regular message
    return cls(
        role_name=role_name,
        role_type=role_type,
        meta_dict=None,
        content=message.value,
    )
|
||||
|
||||
def to_sharegpt(
    self,
    function_format: Optional[FunctionCallFormatter] = None,
) -> ShareGPTMessage:
    r"""Convert BaseMessage to ShareGPT message

    Args:
        function_format (FunctionCallFormatter): Function call formatter
            to use. Defaults to Hermes.
    """

    if function_format is None:
        function_format = HermesFunctionFormatter()

    # Map CAMEL role info onto the ShareGPT 'from' field.
    if self.role_type == RoleType.USER:
        sender = "system" if self.role_name == "system" else "human"
    else:  # RoleType.ASSISTANT
        sender = "gpt"

    # Function conversion code in FunctionCallingMessage
    return ShareGPTMessage(from_=sender, value=self.content)  # type: ignore[call-arg]
|
||||
|
||||
def to_openai_message(
    self,
    role_at_backend: OpenAIBackendRole,
) -> OpenAIMessage:
    r"""Converts the message to an :obj:`OpenAIMessage` object.

    Args:
        role_at_backend (OpenAIBackendRole): The role of the message in
            OpenAI chat system.

    Returns:
        OpenAIMessage: The converted :obj:`OpenAIMessage` object.

    Raises:
        ValueError: If ``role_at_backend`` is not one of SYSTEM, USER,
            or ASSISTANT.
    """
    # Dispatch table instead of an if/elif chain.
    converters = {
        OpenAIBackendRole.SYSTEM: self.to_openai_system_message,
        OpenAIBackendRole.USER: self.to_openai_user_message,
        OpenAIBackendRole.ASSISTANT: self.to_openai_assistant_message,
    }
    converter = converters.get(role_at_backend)
    if converter is None:
        raise ValueError(f"Unsupported role: {role_at_backend}.")
    return converter()
|
||||
|
||||
def to_openai_system_message(self) -> OpenAISystemMessage:
    r"""Converts the message to an :obj:`OpenAISystemMessage` object.

    Returns:
        OpenAISystemMessage: A dict with role ``"system"`` carrying this
            message's content.
    """
    system_message: OpenAISystemMessage = {
        "role": "system",
        "content": self.content,
    }
    return system_message
|
||||
|
||||
def to_openai_user_message(self) -> OpenAIUserMessage:
    r"""Converts the message to an :obj:`OpenAIUserMessage` object.

    When the message carries images and/or video bytes, the result is a
    multi-part content list (text part plus base64 ``image_url`` parts);
    a text-only message is returned as a plain string content.

    Returns:
        OpenAIUserMessage: The converted :obj:`OpenAIUserMessage` object.
    """
    # Multi-part payload: always starts with the text part.
    # (NOTE(review): "hybird" is a pre-existing misspelling of "hybrid".)
    hybird_content: List[Any] = []
    hybird_content.append(
        {
            "type": "text",
            "text": self.content,
        }
    )
    if self.image_list and len(self.image_list) > 0:
        for image in self.image_list:
            # PIL images created in memory may have format=None; the
            # format is required to build the data-URL media type.
            if image.format is None:
                raise ValueError(
                    f"Image's `format` is `None`, please "
                    f"transform the `PIL.Image.Image` to one of "
                    f"following supported formats, such as "
                    f"{list(OpenAIImageType)}"
                )

            image_type: str = image.format.lower()
            if image_type not in OpenAIImageType:
                raise ValueError(
                    f"Image type {image.format} "
                    f"is not supported by OpenAI vision model"
                )
            # Re-encode the image to base64 for a data URL.
            with io.BytesIO() as buffer:
                image.save(fp=buffer, format=image.format)
                encoded_image = base64.b64encode(buffer.getvalue()).decode(
                    "utf-8"
                )
            image_prefix = f"data:image/{image_type};base64,"
            hybird_content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{image_prefix}{encoded_image}",
                        "detail": self.image_detail,
                    },
                }
            )

    if self.video_bytes:
        # Imported lazily so imageio is only required for video messages.
        import imageio.v3 as iio

        base64Frames: List[str] = []
        frame_count = 0
        # read video bytes
        video = iio.imiter(
            self.video_bytes, plugin=Constants.VIDEO_DEFAULT_PLUG_PYAV
        )

        for frame in video:
            frame_count += 1
            # Sample every Nth frame rather than encoding all of them.
            if (
                frame_count % Constants.VIDEO_IMAGE_EXTRACTION_INTERVAL
                == 0
            ):
                # convert frame to numpy array
                frame_array = np.asarray(frame)
                frame_image = Image.fromarray(frame_array)

                # Get the dimensions of the frame
                width, height = frame_image.size

                # resize the frame to the default image size,
                # preserving the aspect ratio
                new_width = Constants.VIDEO_DEFAULT_IMAGE_SIZE
                aspect_ratio = width / height
                new_height = int(new_width / aspect_ratio)
                resized_img = frame_image.resize((new_width, new_height))

                # encode the image to base64 (frames are always JPEG)
                with io.BytesIO() as buffer:
                    image_format = OpenAIImageType.JPEG.value
                    image_format = image_format.upper()
                    resized_img.save(fp=buffer, format=image_format)
                    encoded_image = base64.b64encode(
                        buffer.getvalue()
                    ).decode("utf-8")

                base64Frames.append(encoded_image)

        for encoded_image in base64Frames:
            item = {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encoded_image}",
                    "detail": self.video_detail,
                },
            }

            hybird_content.append(item)

    # More than one part means media was attached: use the list form.
    if len(hybird_content) > 1:
        return {
            "role": "user",
            "content": hybird_content,
        }
    # This return just for str message
    else:
        return {
            "role": "user",
            "content": self.content,
        }
|
||||
|
||||
def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
    r"""Converts the message to an :obj:`OpenAIAssistantMessage` object.

    Returns:
        OpenAIAssistantMessage: A dict with role ``"assistant"`` carrying
            this message's content.
    """
    assistant_message: OpenAIAssistantMessage = {
        "role": "assistant",
        "content": self.content,
    }
    return assistant_message
|
||||
|
||||
def to_dict(self) -> Dict:
    r"""Serialize the message to a plain dictionary.

    Returns:
        dict: Contains ``role_name``, ``role_type`` (the enum member
            name), any metadata entries flattened in, and ``content``.
    """
    result: Dict = {
        "role_name": self.role_name,
        "role_type": self.role_type.name,
    }
    result.update(self.meta_dict or {})
    result["content"] = self.content
    return result
|
||||
@@ -1,31 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .alpaca import AlpacaItem
|
||||
from .conversation_models import (
|
||||
ShareGPTConversation,
|
||||
ShareGPTMessage,
|
||||
ToolCall,
|
||||
ToolResponse,
|
||||
)
|
||||
from .sharegpt import HermesFunctionFormatter
|
||||
|
||||
# Public API of the conversion subpackage, re-exported for convenience.
__all__ = [
    'ShareGPTMessage',
    'ShareGPTConversation',
    'HermesFunctionFormatter',
    'AlpacaItem',
    'ToolCall',
    'ToolResponse',
]
|
||||
@@ -1,122 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class AlpacaItem(BaseModel):
    r"""Represents an instruction-response item in the Alpaca format.

    Appropriate for both cases where input field is empty, or populated.
    Provides parsing from string format using the class method from_string().

    Args:
        instruction (str): The instruction/question/prompt
        input (str): Input context or examples (put empty string if none)
        output (str): The response/answer to the instruction
    """

    instruction: str = Field(description="The instruction/question/prompt")
    input: str = Field(
        description="Optional context or input for the task."
        " For example, when the instruction is \"Summarize the "
        "following article\", the input is the article."
    )
    output: str = Field(description="The response/answer to the instruction")

    @field_validator('instruction', 'output')
    def no_section_markers(cls, value: str) -> str:
        r"""Ensures fields don't contain section markers like '###
        Response:'
        """
        # Markers would corrupt to_string()/from_string() round-tripping.
        if (
            '### Response' in value
            or '### Instruction' in value
            or '### Input' in value
        ):
            raise ValueError("Field cannot contain section markers")
        return value.strip()

    @classmethod
    def from_string(cls, text: str) -> "AlpacaItem":
        r"""Creates an AlpacaItem from a formatted string.

        Args:
            text: String in either of these formats:
                With input:
                ### Instruction:
                {instruction}
                ### Input:
                {input}
                ### Response:
                {response}

                Without input:
                ### Instruction:
                {instruction}
                ### Response:
                {response}

        Returns:
            AlpacaItem: Parsed instance

        Raises:
            ValueError: text doesn't match expected format or sections missing
        """
        # Strip and standardize newlines
        text = text.strip().replace('\r\n', '\n')

        # Try to extract sections using regex; each section runs until
        # the next '###' header or end of string.
        instruction_match = re.search(
            r'###\s*Instruction:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )
        input_match = re.search(
            r'###\s*Input:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )
        response_match = re.search(
            r'###\s*Response:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )

        # Input section is optional; instruction and response are not.
        if not instruction_match or not response_match:
            raise ValueError(
                "Text must contain '### Instruction:'"
                " and '### Response:' sections"
            )

        return cls(
            instruction=instruction_match.group(1).strip(),
            input=input_match.group(1).strip() if input_match else "",
            output=response_match.group(1).strip(),
        )

    def to_string(self) -> str:
        r"""Converts the AlpacaItem to its string representation.

        Returns:
            str: Formatted string representation with sections markers
        """
        # Always emits the '### Input:' section, even when input is "".
        return "\n".join(
            [
                "### Instruction:",
                self.instruction,
                "",
                "### Input:",
                self.input,
                "",
                "### Response:",
                self.output,
            ]
        )
|
||||
@@ -1,178 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTMessage(BaseModel):
    r"""A single message in ShareGPT format with enhanced validation"""

    # Serialized as "from"; the trailing underscore avoids the Python
    # keyword clash, with the JSON field name provided via the alias.
    from_: Literal["human", "gpt", "system", "tool"] = Field(
        alias="from", description="The role of the message sender"
    )
    value: str = Field(
        min_length=0,
        max_length=100000,
        description="The content of the message",
    )

    model_config = {
        # Allow constructing with either `from_=` or the `from` alias.
        "populate_by_name": True,
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {"from": "human", "value": "What's the weather like today?"}
            ]
        },
    }
|
||||
|
||||
|
||||
class ShareGPTConversation(RootModel):
    r"""A full conversation in ShareGPT format with validation"""

    root: List[ShareGPTMessage]

    @model_validator(mode='after')
    def validate_conversation_flow(self) -> 'ShareGPTConversation':
        r"""Validate the conversation follows logical message order.

        Raises:
            ValueError: If the conversation is empty, does not start with
                a system or human message, a tool response does not follow
                a gpt message containing a tool call, or a gpt message does
                not follow a human or tool message.
        """
        messages = self.root

        if not messages:
            raise ValueError("Conversation cannot be empty")

        if messages[0].from_ not in ("system", "human"):
            raise ValueError(
                "Conversation must start with either system or human message"
            )

        # Validate message sequence
        for i in range(1, len(messages)):
            curr, prev = messages[i], messages[i - 1]

            if curr.from_ == "tool":
                # A tool response is only valid directly after an
                # assistant message that actually issued a tool call.
                if prev.from_ != "gpt" or "<tool_call>" not in prev.value:
                    # Fix: corrected grammar in the error message
                    # ("an gpt" -> "a gpt").
                    raise ValueError(
                        f"Tool response at position {i} "
                        f"must follow a gpt message with a tool call"
                    )

            if curr.from_ == "gpt" and prev.from_ not in (
                "human",
                "tool",
            ):
                raise ValueError(
                    f"Assistant message at position {i} "
                    f"must follow a human or tool message"
                )

        return self

    def model_dump(self, **kwargs):
        # NOTE: deliberately returns the raw message list (kwargs are
        # ignored) so serialization yields the plain ShareGPT array form.
        return self.root

    def __iter__(self):
        # Iterate directly over the underlying message list.
        return iter(self.root)
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    r"""Represents a single tool/function call with validation"""

    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool to call",
    )
    arguments: Dict[str, Any] = Field(
        description="The arguments to pass to the tool"
    )

    @field_validator('arguments')
    @classmethod
    def validate_arguments(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        r"""Validate argument structure and content.

        Raises:
            ValueError: If the arguments cannot be serialized to JSON.
        """

        # Try to serialize arguments to ensure they're JSON-compatible
        try:
            json.dumps(v)
        except (TypeError, ValueError):
            raise ValueError("Arguments must be JSON-serializable")

        return v

    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "arguments": {"city": "London", "units": "celsius"},
                }
            ]
        },
    }
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
    r"""Represents a tool/function response with validation. This is a
    base class and default implementation for tool responses, for the purpose
    of converting between different formats.
    """

    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool that was called",
    )
    content: Any = Field(
        description="The response content from the tool."
        " Must be JSON serializable literal or object"
    )

    @field_validator('content')
    @classmethod
    def validate_content(cls, v: Any) -> Any:
        r"""Validate response content structure.

        Note: annotated as ``Any`` to match the ``content`` field (the
        previous ``Dict[str, Any]`` annotation was inaccurate).

        Raises:
            ValueError: If the content cannot be serialized to JSON.
        """

        # Ensure content is JSON-serializable
        try:
            json.dumps(v)
        except (TypeError, ValueError):
            raise ValueError("Response content must be JSON-serializable")

        return v

    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "content": {
                        "temperature": 20,
                        "conditions": "sunny",
                        "humidity": 65,
                    },
                }
            ]
        },
    }
|
||||
@@ -1,20 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
|
||||
from .hermes import HermesFunctionFormatter
|
||||
|
||||
# Public API of the sharegpt conversion subpackage.
__all__ = [
    'HermesFunctionFormatter',
]
|
||||
@@ -1,49 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, TypeVar
|
||||
|
||||
from camel.messages.conversion import (
|
||||
ToolCall,
|
||||
ToolResponse,
|
||||
)
|
||||
|
||||
# Covariant type variables bound to the base tool-call/response models, so a
# formatter declared over a subclass still satisfies FunctionCallFormatter.
CallT = TypeVar('CallT', bound=ToolCall, covariant=True)
ResponseT = TypeVar('ResponseT', bound=ToolResponse, covariant=True)
|
||||
|
||||
|
||||
class FunctionCallFormatter(ABC, Generic[CallT, ResponseT]):
    r"""Abstract base class for function calling formats.

    Concrete subclasses define how tool calls and tool responses are
    embedded in and extracted from plain message strings.
    """

    @abstractmethod
    def extract_tool_calls(self, message: str) -> List[CallT]:
        r"""Extract function call info from a message string.

        Returns:
            List[CallT]: All tool calls found in the message (may be empty).
        """
        pass

    @abstractmethod
    def extract_tool_response(self, message: str) -> Optional[ResponseT]:
        r"""Extract function response info from a message string.

        Returns:
            Optional[ResponseT]: The parsed response, or None when the
                message contains no recognizable tool response.
        """
        pass

    @abstractmethod
    def format_tool_call(
        self, content: str, func_name: str, args: Dict[str, Any]
    ) -> str:
        r"""Format a function call into a message string."""
        pass

    @abstractmethod
    def format_tool_response(self, func_name: str, result: Any) -> str:
        r"""Format a function response into a message string."""
        pass
|
||||
@@ -1,19 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .hermes_function_formatter import HermesFunctionFormatter
|
||||
|
||||
# Public API of the hermes formatter subpackage.
__all__ = [
    'HermesFunctionFormatter',
]
|
||||
@@ -1,128 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from camel.messages.conversion import (
|
||||
ToolCall,
|
||||
ToolResponse,
|
||||
)
|
||||
from camel.messages.conversion.sharegpt.function_call_formatter import (
|
||||
FunctionCallFormatter,
|
||||
)
|
||||
|
||||
|
||||
class HermesToolResponse(ToolResponse):
    r"""Represents a single tool/function response with validation.

    Hermes-specific marker subclass; inherits all fields and validation
    from :obj:`ToolResponse`. (Docstring fixed: previously described a
    "call" rather than a response.)
    """

    pass
|
||||
|
||||
|
||||
class HermesToolCall(ToolCall):
    r"""Represents a single tool/function call with validation.

    Hermes-specific marker subclass; inherits all fields and validation
    from :obj:`ToolCall`.
    """

    pass
|
||||
|
||||
|
||||
class HermesFunctionFormatter(
    FunctionCallFormatter[HermesToolCall, HermesToolResponse]
):
    r"""Hermes-style function calling format implementation with validation.

    Tool calls and responses are embedded in messages as JSON wrapped in
    ``<tool_call>...</tool_call>`` / ``<tool_response>...</tool_response>``
    tags.
    """

    @staticmethod
    def _parse_payload(payload: str) -> Dict[str, Any]:
        r"""Parse a tag payload as JSON, tolerating legacy str(dict) output.

        Fix: the previous unconditional ``payload.replace("'", '"')``
        corrupted any value containing an apostrophe. Proper JSON is now
        parsed directly; the quote-swap remains only as a fallback for
        legacy single-quoted payloads.
        """
        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            return json.loads(payload.replace("'", '"'))

    def extract_tool_calls(self, message: str) -> List[HermesToolCall]:
        r"""Extracts all tool calls from the provided message string.

        Args:
            message (str): The input message string containing potential tool
                calls.

        Returns:
            List[HermesToolCall]: A list of parsed HermesToolCall objects.
        """
        tool_calls = []
        pattern = r"<tool_call>\s*({.*?})\s*</tool_call>"
        matches = re.finditer(pattern, message, re.DOTALL)

        for match in matches:
            try:
                call_dict = self._parse_payload(match.group(1))
                tool_calls.append(HermesToolCall.model_validate(call_dict))
            except Exception as e:
                # Best effort: skip malformed calls rather than failing.
                print(f"Warning: Failed to parse tool call: {e}")
                continue

        return tool_calls

    def extract_tool_response(
        self, message: str
    ) -> Optional[HermesToolResponse]:
        r"""Extracts a single tool response from the provided message string.

        Args:
            message (str): The input message string containing a potential
                tool response.

        Returns:
            Optional[HermesToolResponse]: A parsed HermesToolResponse object,
                or None if no valid response is found.
        """
        pattern = r"<tool_response>\s*({.*?})\s*</tool_response>"
        match = re.search(pattern, message, re.DOTALL)

        if match:
            try:
                response_dict = self._parse_payload(match.group(1))
                return HermesToolResponse.model_validate(response_dict)
            except Exception as e:
                print(f"Warning: Failed to parse tool response: {e}")
                return None
        return None

    def format_tool_call(
        self, content: str, func_name: str, args: Dict[str, Any]
    ) -> str:
        r"""Formats a tool call message with the given content, function name,
        and arguments.

        Args:
            content (str): The content or message to be included in the tool
                call.
            func_name (str): The name of the function being called.
            args (Dict[str, Any]): A dictionary of arguments to be passed to
                the function.

        Returns:
            str: A formatted string representing the tool call in Hermes
                format.
        """
        # Fix: emit valid JSON (json.dumps) instead of str(dict), whose
        # single quotes broke round-tripping through extract_tool_calls
        # for values containing apostrophes.
        tool_call_json = json.dumps({"name": func_name, "arguments": args})
        return f"{content}\n<tool_call>\n{tool_call_json}\n</tool_call>"

    def format_tool_response(self, func_name: str, result: Any) -> str:
        r"""Formats a tool response message with the given function name and
        result.

        Args:
            func_name (str): The name of the function whose result is being
                returned.
            result (Any): The result to be included in the tool response.

        Returns:
            str: A formatted string representing the tool response in Hermes
                format.
        """
        # Fix: emit valid JSON for the same reason as format_tool_call.
        response_json = json.dumps({"name": func_name, "content": result})
        return f"<tool_response>\n{response_json}\n</tool_response>"
|
||||
@@ -1,163 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from camel.messages import (
|
||||
BaseMessage,
|
||||
HermesFunctionFormatter,
|
||||
OpenAIAssistantMessage,
|
||||
OpenAIMessage,
|
||||
OpenAIToolMessageParam,
|
||||
)
|
||||
from camel.messages.conversion import (
|
||||
ShareGPTMessage,
|
||||
ToolCall,
|
||||
ToolResponse,
|
||||
)
|
||||
from camel.messages.conversion.sharegpt.function_call_formatter import (
|
||||
FunctionCallFormatter,
|
||||
)
|
||||
from camel.types import OpenAIBackendRole
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallingMessage(BaseMessage):
|
||||
r"""Class for message objects used specifically for
|
||||
function-related messages.
|
||||
|
||||
Args:
|
||||
func_name (Optional[str]): The name of the function used.
|
||||
(default: :obj:`None`)
|
||||
args (Optional[Dict]): The dictionary of arguments passed to the
|
||||
function. (default: :obj:`None`)
|
||||
result (Optional[Any]): The result of function execution.
|
||||
(default: :obj:`None`)
|
||||
tool_call_id (Optional[str]): The ID of the tool call, if available.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
func_name: Optional[str] = None
|
||||
args: Optional[Dict] = None
|
||||
result: Optional[Any] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
|
||||
def to_openai_message(
|
||||
self,
|
||||
role_at_backend: OpenAIBackendRole,
|
||||
) -> OpenAIMessage:
|
||||
r"""Converts the message to an :obj:`OpenAIMessage` object.
|
||||
|
||||
Args:
|
||||
role_at_backend (OpenAIBackendRole): The role of the message in
|
||||
OpenAI chat system.
|
||||
|
||||
Returns:
|
||||
OpenAIMessage: The converted :obj:`OpenAIMessage` object.
|
||||
"""
|
||||
if role_at_backend == OpenAIBackendRole.ASSISTANT:
|
||||
return self.to_openai_assistant_message()
|
||||
elif role_at_backend == OpenAIBackendRole.FUNCTION:
|
||||
return self.to_openai_tool_message()
|
||||
else:
|
||||
raise ValueError(f"Unsupported role: {role_at_backend}.")
|
||||
|
||||
def to_sharegpt(
|
||||
self,
|
||||
function_format: Optional[
|
||||
FunctionCallFormatter[ToolCall, ToolResponse]
|
||||
] = None,
|
||||
) -> ShareGPTMessage:
|
||||
r"""Convert FunctionCallingMessage to ShareGPT message.
|
||||
|
||||
Args:
|
||||
function_format (FunctionCallFormatter[ToolCall, ToolResponse],
|
||||
optional): The function formatter to use. Defaults to None.
|
||||
"""
|
||||
|
||||
if function_format is None:
|
||||
function_format = HermesFunctionFormatter()
|
||||
# The role of the message is an unreliable indicator of whether
|
||||
# it is a function call or response, so use result
|
||||
if self.result is None:
|
||||
# This is a function call
|
||||
# TODO: split the incoming types to be more specific
|
||||
# and remove the type ignores
|
||||
content = function_format.format_tool_call(
|
||||
self.content or "", # type: ignore[arg-type]
|
||||
self.func_name, # type: ignore[arg-type]
|
||||
self.args, # type: ignore[arg-type]
|
||||
)
|
||||
return ShareGPTMessage(from_="gpt", value=content) # type: ignore[call-arg]
|
||||
else:
|
||||
# This is a function response
|
||||
# TODO: Allow for more flexible setting of tool role,
|
||||
# optionally to be the same as assistant messages
|
||||
content = function_format.format_tool_response(
|
||||
self.func_name, # type: ignore[arg-type]
|
||||
self.result, # type: ignore[arg-type]
|
||||
)
|
||||
return ShareGPTMessage(from_="tool", value=content) # type: ignore[call-arg]
|
||||
|
||||
def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
|
||||
r"""Converts the message to an :obj:`OpenAIAssistantMessage` object.
|
||||
|
||||
Returns:
|
||||
OpenAIAssistantMessage: The converted :obj:`OpenAIAssistantMessage`
|
||||
object.
|
||||
"""
|
||||
if (not self.func_name) or (self.args is None):
|
||||
raise ValueError(
|
||||
"Invalid request for converting into assistant message"
|
||||
" due to missing function name or arguments."
|
||||
)
|
||||
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": self.content or "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": self.tool_call_id or "null",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.func_name,
|
||||
"arguments": json.dumps(self.args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def to_openai_tool_message(self) -> OpenAIToolMessageParam:
|
||||
r"""Converts the message to an :obj:`OpenAIToolMessageParam` object
|
||||
with the role being "tool".
|
||||
|
||||
Returns:
|
||||
OpenAIToolMessageParam: The converted
|
||||
:obj:`OpenAIToolMessageParam` object with its role being
|
||||
"tool".
|
||||
"""
|
||||
if not self.func_name:
|
||||
raise ValueError(
|
||||
"Invalid request for converting into function message"
|
||||
" due to missing function name."
|
||||
)
|
||||
|
||||
result_content = str(self.result)
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": result_content,
|
||||
"tool_call_id": self.tool_call_id or "null",
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .anthropic_model import AnthropicModel
|
||||
from .azure_openai_model import AzureOpenAIModel
|
||||
from .base_model import BaseModelBackend
|
||||
from .cohere_model import CohereModel
|
||||
from .deepseek_model import DeepSeekModel
|
||||
from .gemini_model import GeminiModel
|
||||
from .groq_model import GroqModel
|
||||
from .litellm_model import LiteLLMModel
|
||||
from .mistral_model import MistralModel
|
||||
from .model_factory import ModelFactory
|
||||
from .model_manager import ModelManager, ModelProcessingError
|
||||
from .nemotron_model import NemotronModel
|
||||
from .nvidia_model import NvidiaModel
|
||||
from .ollama_model import OllamaModel
|
||||
from .openai_audio_models import OpenAIAudioModels
|
||||
from .openai_compatible_model import OpenAICompatibleModel
|
||||
from .openai_model import OpenAIModel
|
||||
from .qwen_model import QwenModel
|
||||
from .reka_model import RekaModel
|
||||
from .samba_model import SambaModel
|
||||
from .stub_model import StubModel
|
||||
from .togetherai_model import TogetherAIModel
|
||||
from .vllm_model import VLLMModel
|
||||
from .yi_model import YiModel
|
||||
from .zhipuai_model import ZhipuAIModel
|
||||
from .fish_audio_model import FishAudioModel
|
||||
|
||||
__all__ = [
|
||||
'BaseModelBackend',
|
||||
'OpenAIModel',
|
||||
'AzureOpenAIModel',
|
||||
'AnthropicModel',
|
||||
'MistralModel',
|
||||
'GroqModel',
|
||||
'StubModel',
|
||||
'ZhipuAIModel',
|
||||
'CohereModel',
|
||||
'ModelFactory',
|
||||
'ModelManager',
|
||||
'LiteLLMModel',
|
||||
'OpenAIAudioModels',
|
||||
'NemotronModel',
|
||||
'NvidiaModel',
|
||||
'OllamaModel',
|
||||
'VLLMModel',
|
||||
'GeminiModel',
|
||||
'OpenAICompatibleModel',
|
||||
'RekaModel',
|
||||
'SambaModel',
|
||||
'TogetherAIModel',
|
||||
'YiModel',
|
||||
'QwenModel',
|
||||
'ModelProcessingError',
|
||||
'DeepSeekModel',
|
||||
]
|
||||
@@ -1,167 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig
|
||||
from camel.messages import OpenAIMessage
|
||||
from camel.models.base_model import BaseModelBackend
|
||||
from camel.types import ChatCompletion, ModelType
|
||||
from camel.utils import (
|
||||
AnthropicTokenCounter,
|
||||
BaseTokenCounter,
|
||||
api_keys_required,
|
||||
dependencies_required,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicModel(BaseModelBackend):
|
||||
r"""Anthropic API in a unified BaseModelBackend interface.
|
||||
|
||||
Args:
|
||||
model_type (Union[ModelType, str]): Model for which a backend is
|
||||
created, one of CLAUDE_* series.
|
||||
model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
|
||||
that will be fed into Anthropic.messages.create(). If
|
||||
:obj:`None`, :obj:`AnthropicConfig().as_dict()` will be used.
|
||||
(default::obj:`None`)
|
||||
api_key (Optional[str], optional): The API key for authenticating with
|
||||
the Anthropic service. (default: :obj:`None`)
|
||||
url (Optional[str], optional): The url to the Anthropic service.
|
||||
(default: :obj:`None`)
|
||||
token_counter (Optional[BaseTokenCounter], optional): Token counter to
|
||||
use for the model. If not provided, :obj:`AnthropicTokenCounter`
|
||||
will be used. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
@dependencies_required('anthropic')
|
||||
def __init__(
|
||||
self,
|
||||
model_type: Union[ModelType, str],
|
||||
model_config_dict: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
token_counter: Optional[BaseTokenCounter] = None,
|
||||
) -> None:
|
||||
from anthropic import Anthropic
|
||||
|
||||
if model_config_dict is None:
|
||||
model_config_dict = AnthropicConfig().as_dict()
|
||||
api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
||||
url = url or os.environ.get("ANTHROPIC_API_BASE_URL")
|
||||
super().__init__(
|
||||
model_type, model_config_dict, api_key, url, token_counter
|
||||
)
|
||||
self.client = Anthropic(api_key=self._api_key, base_url=self._url)
|
||||
|
||||
def _convert_response_from_anthropic_to_openai(self, response):
|
||||
# openai ^1.0.0 format, reference openai/types/chat/chat_completion.py
|
||||
obj = ChatCompletion.construct(
|
||||
id=None,
|
||||
choices=[
|
||||
dict(
|
||||
index=0,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"content": response.content[0].text,
|
||||
},
|
||||
finish_reason=response.stop_reason,
|
||||
)
|
||||
],
|
||||
created=None,
|
||||
model=response.model,
|
||||
object="chat.completion",
|
||||
)
|
||||
return obj
|
||||
|
||||
@property
|
||||
def token_counter(self) -> BaseTokenCounter:
|
||||
r"""Initialize the token counter for the model backend.
|
||||
|
||||
Returns:
|
||||
BaseTokenCounter: The token counter following the model's
|
||||
tokenization style.
|
||||
"""
|
||||
if not self._token_counter:
|
||||
self._token_counter = AnthropicTokenCounter()
|
||||
return self._token_counter
|
||||
|
||||
def count_tokens_from_prompt(self, prompt: str) -> int:
|
||||
r"""Count the number of tokens from a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt string.
|
||||
|
||||
Returns:
|
||||
int: The number of tokens in the prompt.
|
||||
"""
|
||||
return self.client.count_tokens(prompt)
|
||||
|
||||
@api_keys_required("ANTHROPIC_API_KEY")
|
||||
def run(
|
||||
self,
|
||||
messages: List[OpenAIMessage],
|
||||
):
|
||||
r"""Run inference of Anthropic chat completion.
|
||||
|
||||
Args:
|
||||
messages (List[OpenAIMessage]): Message list with the chat history
|
||||
in OpenAI API format.
|
||||
|
||||
Returns:
|
||||
ChatCompletion: Response in the OpenAI API format.
|
||||
"""
|
||||
from anthropic import NOT_GIVEN
|
||||
|
||||
if messages[0]["role"] == "system":
|
||||
sys_msg = str(messages.pop(0)["content"])
|
||||
else:
|
||||
sys_msg = NOT_GIVEN # type: ignore[assignment]
|
||||
response = self.client.messages.create(
|
||||
model=self.model_type,
|
||||
system=sys_msg,
|
||||
messages=messages, # type: ignore[arg-type]
|
||||
**self.model_config_dict,
|
||||
)
|
||||
|
||||
# format response to openai format
|
||||
response = self._convert_response_from_anthropic_to_openai(response)
|
||||
|
||||
return response
|
||||
|
||||
def check_model_config(self):
|
||||
r"""Check whether the model configuration is valid for anthropic
|
||||
model backends.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model configuration dictionary contains any
|
||||
unexpected arguments to OpenAI API, or it does not contain
|
||||
:obj:`model_path` or :obj:`server_url`.
|
||||
"""
|
||||
for param in self.model_config_dict:
|
||||
if param not in ANTHROPIC_API_PARAMS:
|
||||
raise ValueError(
|
||||
f"Unexpected argument `{param}` is "
|
||||
"input into Anthropic model backend."
|
||||
)
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
r"""Returns whether the model is in stream mode, which sends partial
|
||||
results each time.
|
||||
|
||||
Returns:
|
||||
bool: Whether the model is in stream mode.
|
||||
"""
|
||||
return self.model_config_dict.get("stream", False)
|
||||
@@ -1,155 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from openai import AzureOpenAI, Stream
|
||||
|
||||
from camel.configs import OPENAI_API_PARAMS, ChatGPTConfig
|
||||
from camel.messages import OpenAIMessage
|
||||
from camel.models.base_model import BaseModelBackend
|
||||
from camel.types import (
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
ModelType,
|
||||
)
|
||||
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
|
||||
|
||||
|
||||
class AzureOpenAIModel(BaseModelBackend):
|
||||
r"""Azure OpenAI API in a unified BaseModelBackend interface.
|
||||
|
||||
Args:
|
||||
model_type (Union[ModelType, str]): Model for which a backend is
|
||||
created, one of GPT_* series.
|
||||
model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
|
||||
that will be fed into:obj:`openai.ChatCompletion.create()`. If
|
||||
:obj:`None`, :obj:`ChatGPTConfig().as_dict()` will be used.
|
||||
(default: :obj:`None`)
|
||||
api_key (Optional[str], optional): The API key for authenticating with
|
||||
the OpenAI service. (default: :obj:`None`)
|
||||
url (Optional[str], optional): The url to the OpenAI service.
|
||||
(default: :obj:`None`)
|
||||
api_version (Optional[str], optional): The api version for the model.
|
||||
(default: :obj:`None`)
|
||||
azure_deployment_name (Optional[str], optional): The deployment name
|
||||
you chose when you deployed an azure model. (default: :obj:`None`)
|
||||
token_counter (Optional[BaseTokenCounter], optional): Token counter to
|
||||
use for the model. If not provided, :obj:`OpenAITokenCounter`
|
||||
will be used. (default: :obj:`None`)
|
||||
|
||||
References:
|
||||
https://learn.microsoft.com/en-us/azure/ai-services/openai/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: Union[ModelType, str],
|
||||
model_config_dict: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
token_counter: Optional[BaseTokenCounter] = None,
|
||||
api_version: Optional[str] = None,
|
||||
azure_deployment_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if model_config_dict is None:
|
||||
model_config_dict = ChatGPTConfig().as_dict()
|
||||
api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
url = url or os.environ.get("AZURE_OPENAI_BASE_URL")
|
||||
super().__init__(
|
||||
model_type, model_config_dict, api_key, url, token_counter
|
||||
)
|
||||
|
||||
self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
|
||||
self.azure_deployment_name = azure_deployment_name or os.environ.get(
|
||||
"AZURE_DEPLOYMENT_NAME"
|
||||
)
|
||||
if self.api_version is None:
|
||||
raise ValueError(
|
||||
"Must provide either the `api_version` argument "
|
||||
"or `AZURE_API_VERSION` environment variable."
|
||||
)
|
||||
if self.azure_deployment_name is None:
|
||||
raise ValueError(
|
||||
"Must provide either the `azure_deployment_name` argument "
|
||||
"or `AZURE_DEPLOYMENT_NAME` environment variable."
|
||||
)
|
||||
|
||||
self._client = AzureOpenAI(
|
||||
azure_endpoint=str(self._url),
|
||||
azure_deployment=self.azure_deployment_name,
|
||||
api_version=self.api_version,
|
||||
api_key=self._api_key,
|
||||
timeout=60,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
@property
|
||||
def token_counter(self) -> BaseTokenCounter:
|
||||
r"""Initialize the token counter for the model backend.
|
||||
|
||||
Returns:
|
||||
BaseTokenCounter: The token counter following the model's
|
||||
tokenization style.
|
||||
"""
|
||||
if not self._token_counter:
|
||||
self._token_counter = OpenAITokenCounter(self.model_type)
|
||||
return self._token_counter
|
||||
|
||||
@api_keys_required("AZURE_OPENAI_API_KEY", "AZURE_API_VERSION")
|
||||
def run(
|
||||
self,
|
||||
messages: List[OpenAIMessage],
|
||||
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
|
||||
r"""Runs inference of Azure OpenAI chat completion.
|
||||
|
||||
Args:
|
||||
messages (List[OpenAIMessage]): Message list with the chat history
|
||||
in OpenAI API format.
|
||||
|
||||
Returns:
|
||||
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
|
||||
`ChatCompletion` in the non-stream mode, or
|
||||
`Stream[ChatCompletionChunk]` in the stream mode.
|
||||
"""
|
||||
response = self._client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.azure_deployment_name, # type:ignore[arg-type]
|
||||
**self.model_config_dict,
|
||||
)
|
||||
return response
|
||||
|
||||
def check_model_config(self):
|
||||
r"""Check whether the model configuration contains any
|
||||
unexpected arguments to Azure OpenAI API.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model configuration dictionary contains any
|
||||
unexpected arguments to Azure OpenAI API.
|
||||
"""
|
||||
for param in self.model_config_dict:
|
||||
if param not in OPENAI_API_PARAMS:
|
||||
raise ValueError(
|
||||
f"Unexpected argument `{param}` is "
|
||||
"input into Azure OpenAI model backend."
|
||||
)
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
r"""Returns whether the model is in stream mode,
|
||||
which sends partial results each time.
|
||||
Returns:
|
||||
bool: Whether the model is in stream mode.
|
||||
"""
|
||||
return self.model_config_dict.get("stream", False)
|
||||
@@ -1,140 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from openai import Stream
|
||||
|
||||
from camel.messages import OpenAIMessage
|
||||
from camel.types import (
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
ModelType,
|
||||
UnifiedModelType,
|
||||
)
|
||||
from camel.utils import BaseTokenCounter
|
||||
|
||||
|
||||
class BaseModelBackend(ABC):
|
||||
r"""Base class for different model backends.
|
||||
It may be OpenAI API, a local LLM, a stub for unit tests, etc.
|
||||
|
||||
Args:
|
||||
model_type (Union[ModelType, str]): Model for which a backend is
|
||||
created.
|
||||
model_config_dict (Optional[Dict[str, Any]], optional): A config
|
||||
dictionary. (default: :obj:`{}`)
|
||||
api_key (Optional[str], optional): The API key for authenticating
|
||||
with the model service. (default: :obj:`None`)
|
||||
url (Optional[str], optional): The url to the model service.
|
||||
(default: :obj:`None`)
|
||||
token_counter (Optional[BaseTokenCounter], optional): Token
|
||||
counter to use for the model. If not provided,
|
||||
:obj:`OpenAITokenCounter` will be used. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: Union[ModelType, str],
|
||||
model_config_dict: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
token_counter: Optional[BaseTokenCounter] = None,
|
||||
) -> None:
|
||||
self.model_type: UnifiedModelType = UnifiedModelType(model_type)
|
||||
if model_config_dict is None:
|
||||
model_config_dict = {}
|
||||
self.model_config_dict = model_config_dict
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._token_counter = token_counter
|
||||
self.check_model_config()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def token_counter(self) -> BaseTokenCounter:
|
||||
r"""Initialize the token counter for the model backend.
|
||||
|
||||
Returns:
|
||||
BaseTokenCounter: The token counter following the model's
|
||||
tokenization style.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
messages: List[OpenAIMessage],
|
||||
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
|
||||
r"""Runs the query to the backend model.
|
||||
|
||||
Args:
|
||||
messages (List[OpenAIMessage]): Message list with the chat history
|
||||
in OpenAI API format.
|
||||
|
||||
Returns:
|
||||
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
|
||||
`ChatCompletion` in the non-stream mode, or
|
||||
`Stream[ChatCompletionChunk]` in the stream mode.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def check_model_config(self):
|
||||
r"""Check whether the input model configuration contains unexpected
|
||||
arguments
|
||||
|
||||
Raises:
|
||||
ValueError: If the model configuration dictionary contains any
|
||||
unexpected argument for this model class.
|
||||
"""
|
||||
pass
|
||||
|
||||
def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
|
||||
r"""Count the number of tokens in the messages using the specific
|
||||
tokenizer.
|
||||
|
||||
Args:
|
||||
messages (List[Dict]): message list with the chat history
|
||||
in OpenAI API format.
|
||||
|
||||
Returns:
|
||||
int: Number of tokens in the messages.
|
||||
"""
|
||||
return self.token_counter.count_tokens_from_messages(messages)
|
||||
|
||||
@property
|
||||
def token_limit(self) -> int:
|
||||
r"""Returns the maximum token limit for a given model.
|
||||
|
||||
This method retrieves the maximum token limit either from the
|
||||
`model_config_dict` or from the model's default token limit.
|
||||
|
||||
Returns:
|
||||
int: The maximum token limit for the given model.
|
||||
"""
|
||||
return (
|
||||
self.model_config_dict.get("max_tokens")
|
||||
or self.model_type.token_limit
|
||||
)
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
r"""Returns whether the model is in stream mode, which sends partial
|
||||
results each time.
|
||||
|
||||
Returns:
|
||||
bool: Whether the model is in stream mode.
|
||||
"""
|
||||
return False
|
||||
@@ -1,282 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cohere.types import ChatMessageV2, ChatResponse
|
||||
|
||||
from camel.configs import COHERE_API_PARAMS, CohereConfig
|
||||
from camel.messages import OpenAIMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.types import ChatCompletion, ModelType
|
||||
from camel.utils import (
|
||||
BaseTokenCounter,
|
||||
OpenAITokenCounter,
|
||||
api_keys_required,
|
||||
)
|
||||
|
||||
try:
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import LLMEvent, record
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
LLMEvent = None
|
||||
|
||||
|
||||
class CohereModel(BaseModelBackend):
|
||||
r"""Cohere API in a unified BaseModelBackend interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: Union[ModelType, str],
|
||||
model_config_dict: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
token_counter: Optional[BaseTokenCounter] = None,
|
||||
):
|
||||
import cohere
|
||||
|
||||
if model_config_dict is None:
|
||||
model_config_dict = CohereConfig().as_dict()
|
||||
|
||||
api_key = api_key or os.environ.get("COHERE_API_KEY")
|
||||
url = url or os.environ.get("COHERE_API_BASE_URL")
|
||||
super().__init__(
|
||||
model_type, model_config_dict, api_key, url, token_counter
|
||||
)
|
||||
self._client = cohere.ClientV2(api_key=self._api_key)
|
||||
|
||||
def _to_openai_response(self, response: 'ChatResponse') -> ChatCompletion:
|
||||
if response.usage and response.usage.tokens:
|
||||
input_tokens = response.usage.tokens.input_tokens or 0
|
||||
output_tokens = response.usage.tokens.output_tokens or 0
|
||||
usage = {
|
||||
"prompt_tokens": input_tokens,
|
||||
"completion_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
}
|
||||
else:
|
||||
usage = {}
|
||||
|
||||
tool_calls = response.message.tool_calls
|
||||
choices = []
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
openai_tool_calls = [
|
||||
dict(
|
||||
id=tool_call.id,
|
||||
function={
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
}
|
||||
if tool_call.function
|
||||
else {},
|
||||
type=tool_call.type,
|
||||
)
|
||||
]
|
||||
|
||||
choice = dict(
|
||||
index=None,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"content": response.message.tool_plan,
|
||||
"tool_calls": openai_tool_calls,
|
||||
},
|
||||
finish_reason=response.finish_reason
|
||||
if response.finish_reason
|
||||
else None,
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
else:
|
||||
openai_tool_calls = None
|
||||
|
||||
choice = dict(
|
||||
index=None,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"content": response.message.content[0].text, # type: ignore[union-attr,index]
|
||||
"tool_calls": openai_tool_calls,
|
||||
},
|
||||
finish_reason=response.finish_reason
|
||||
if response.finish_reason
|
||||
else None,
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
obj = ChatCompletion.construct(
|
||||
id=response.id,
|
||||
choices=choices,
|
||||
created=None,
|
||||
model=self.model_type,
|
||||
object="chat.completion",
|
||||
usage=usage,
|
||||
)
|
||||
return obj
|
||||
|
||||
def _to_cohere_chatmessage(
|
||||
self, messages: List[OpenAIMessage]
|
||||
) -> List["ChatMessageV2"]:
|
||||
from cohere.types import ToolCallV2Function
|
||||
from cohere.types.chat_message_v2 import (
|
||||
AssistantChatMessageV2,
|
||||
SystemChatMessageV2,
|
||||
ToolCallV2,
|
||||
ToolChatMessageV2,
|
||||
UserChatMessageV2,
|
||||
)
|
||||
|
||||
tool_call_id = None
|
||||
new_messages = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
function_call = msg.get("function_call")
|
||||
|
||||
if role == "user":
|
||||
new_message = UserChatMessageV2(role="user", content=content) # type: ignore[arg-type]
|
||||
elif role in {"tool", "function"}:
|
||||
new_message = ToolChatMessageV2(
|
||||
role="tool",
|
||||
tool_call_id=tool_call_id, # type: ignore[arg-type]
|
||||
content=content, # type: ignore[assignment,arg-type]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not function_call:
|
||||
new_message = AssistantChatMessageV2( # type: ignore[assignment]
|
||||
role="assistant",
|
||||
content=content, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
arguments = function_call.get("arguments") # type: ignore[attr-defined]
|
||||
arguments_dict = ast.literal_eval(arguments)
|
||||
arguments_json = json.dumps(arguments_dict)
|
||||
|
||||
assis_tool_call_id = str(uuid.uuid4())
|
||||
tool_call_id = assis_tool_call_id
|
||||
new_message = AssistantChatMessageV2( # type: ignore[assignment]
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ToolCallV2(
|
||||
id=assis_tool_call_id,
|
||||
type="function",
|
||||
function=ToolCallV2Function(
|
||||
name=function_call.get("name"), # type: ignore[attr-defined]
|
||||
arguments=arguments_json, # type: ignore[attr-defined]
|
||||
),
|
||||
)
|
||||
],
|
||||
content=content, # type: ignore[arg-type]
|
||||
)
|
||||
elif role == "system":
|
||||
new_message = SystemChatMessageV2( # type: ignore[assignment]
|
||||
role="system",
|
||||
content=content, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message role: {role}")
|
||||
|
||||
new_messages.append(new_message)
|
||||
return new_messages # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
def token_counter(self) -> BaseTokenCounter:
|
||||
r"""Initialize the token counter for the model backend.
|
||||
|
||||
Returns:
|
||||
BaseTokenCounter: The token counter following the model's
|
||||
tokenization style.
|
||||
"""
|
||||
if not self._token_counter:
|
||||
self._token_counter = OpenAITokenCounter(
|
||||
model=ModelType.GPT_4O_MINI
|
||||
)
|
||||
return self._token_counter
|
||||
|
||||
@api_keys_required("COHERE_API_KEY")
def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:
    r"""Runs inference of Cohere chat completion.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.

    Returns:
        ChatCompletion.
    """
    # Imported lazily to avoid a hard dependency at module import time.
    from cohere.core.api_error import ApiError

    chat_history = self._to_cohere_chatmessage(messages)

    try:
        raw_response = self._client.chat(
            messages=chat_history,
            model=self.model_type,
            **self.model_config_dict,
        )
    except ApiError as e:
        # Surface Cohere-specific failure details, then propagate.
        logging.error(f"Cohere API Error: {e.status_code}")
        logging.error(f"Error body: {e.body}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error when calling Cohere API: {e!s}")
        raise

    openai_response = self._to_openai_response(raw_response)

    # Record the call with AgentOps when its tracking hooks are available.
    if LLMEvent:
        event = LLMEvent(
            thread_id=openai_response.id,
            prompt=" ".join(
                [message.get("content") for message in messages]  # type: ignore[misc]
            ),
            prompt_tokens=openai_response.usage.prompt_tokens,  # type: ignore[union-attr]
            completion=openai_response.choices[0].message.content,
            completion_tokens=openai_response.usage.completion_tokens,  # type: ignore[union-attr]
            model=self.model_type,
        )
        record(event)

    return openai_response
|
||||
|
||||
def check_model_config(self):
    r"""Check whether the model configuration contains any unexpected
    arguments to Cohere API.

    Raises:
        ValueError: If the model configuration dictionary contains any
            unexpected arguments to Cohere API.
    """
    # Collect parameters Cohere does not accept; report the first one
    # found (iteration order matches the config dict).
    unexpected = [
        param
        for param in self.model_config_dict
        if param not in COHERE_API_PARAMS
    ]
    if unexpected:
        param = unexpected[0]
        raise ValueError(
            f"Unexpected argument `{param}` is "
            "input into Cohere model backend."
        )
|
||||
|
||||
@property
def stream(self) -> bool:
    r"""Returns whether the model is in stream mode, which sends partial
    results each time. Streaming is currently not supported, so this is
    always ``False``.

    Returns:
        bool: Whether the model is in stream mode.
    """
    return False
|
||||
@@ -1,225 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
|
||||
from camel.configs import DEEPSEEK_API_PARAMS, DeepSeekConfig
|
||||
from camel.logger import get_logger
|
||||
from camel.messages import OpenAIMessage
|
||||
from camel.models.base_model import BaseModelBackend
|
||||
from camel.types import (
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
ModelType,
|
||||
)
|
||||
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
|
||||
from retry import retry
|
||||
import json
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekModel(BaseModelBackend):
    r"""DeepSeek API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into :obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`DeepSeekConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the DeepSeek service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the DeepSeek service.
            (default: :obj:`https://api.deepseek.com`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter`
            will be used. (default: :obj:`None`)

    References:
        https://api-docs.deepseek.com/
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        # Fall back to the default DeepSeek configuration when none is given.
        if model_config_dict is None:
            model_config_dict = DeepSeekConfig().as_dict()
        # Explicit arguments take precedence over environment variables.
        api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
        url = url or os.environ.get(
            "DEEPSEEK_API_BASE_URL",
            "https://api.deepseek.com",
        )
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )

        # DeepSeek exposes an OpenAI-compatible API, so the OpenAI client
        # is reused with a custom base URL.
        self._client = OpenAI(
            timeout=180,
            max_retries=3,
            api_key=self._api_key,
            base_url=self._url,
        )

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        # Lazily create and cache the counter; GPT-4o-mini tokenization is
        # used as the counting approximation.
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(
                model=ModelType.GPT_4O_MINI
            )
        return self._token_counter

    # Retried because DeepSeek responses occasionally fail to parse; waits
    # 10 seconds between attempts.
    @retry((ValueError, TypeError, json.decoder.JSONDecodeError), delay=10, logger=logger)
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of DeepSeek chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        # deepseek reasoner has limitations
        # reference: https://api-docs.deepseek.com/guides/reasoning_model#api-parameters
        if self.model_type in [
            ModelType.DEEPSEEK_REASONER,
        ]:
            import re

            logger.warning(
                "You are using a DeepSeek Reasoner model, "
                "which has certain limitations, reference: "
                "`https://api-docs.deepseek.com/guides/reasoning_model#api-parameters`"
            )

            # Check and remove unsupported parameters and reset the fixed
            # parameters
            # NOTE: this mutates self.model_config_dict in place, so the
            # removal persists across subsequent calls.
            unsupported_keys = [
                "temperature",
                "top_p",
                "presence_penalty",
                "frequency_penalty",
                "logprobs",
                "top_logprobs",
                "tools",
            ]
            for key in unsupported_keys:
                if key in self.model_config_dict:
                    del self.model_config_dict[key]
            # Remove thinking content from messages before sending to API
            # This ensures only the final response is sent, excluding
            # intermediate thought processes
            messages = [
                {  # type: ignore[misc]
                    **msg,
                    'content': re.sub(
                        r'<think>.*?</think>',
                        '',
                        msg['content'],  # type: ignore[arg-type]
                        flags=re.DOTALL,
                    ).strip(),
                }
                for msg in messages
            ]

        response = self._client.chat.completions.create(
            messages=messages,
            model=self.model_type,
            **self.model_config_dict,
        )

        # Handle reasoning content with <think> tags at the beginning
        # (opt-in via the GET_REASONING_CONTENT environment variable).
        if (
            self.model_type
            in [
                ModelType.DEEPSEEK_REASONER,
            ]
            and os.environ.get("GET_REASONING_CONTENT", "false").lower()
            == "true"
        ):
            reasoning_content = response.choices[0].message.reasoning_content
            combined_content = (
                f"<think>\n{reasoning_content}\n</think>\n"
                if reasoning_content
                else ""
            ) + response.choices[0].message.content

            # Rebuild the response so the reasoning is prepended to the
            # visible content inside <think> tags.
            response = ChatCompletion.construct(
                id=response.id,
                choices=[
                    dict(
                        index=response.choices[0].index,
                        message={
                            "role": response.choices[0].message.role,
                            "content": combined_content,
                            "tool_calls": None,
                        },
                        finish_reason=response.choices[0].finish_reason
                        if response.choices[0].finish_reason
                        else None,
                    )
                ],
                created=response.created,
                model=response.model,
                object="chat.completion",
                usage=response.usage,
            )

        return response

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to DeepSeek API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to DeepSeek API.
        """
        for param in self.model_config_dict:
            if param not in DEEPSEEK_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into DeepSeek model backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get("stream", False)
|
||||
@@ -1,147 +0,0 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class FishAudioModel:
    r"""Provides access to FishAudio's Text-to-Speech (TTS) and
    Speech-to-Text (STT) models.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        r"""Initialize an instance of FishAudioModel.

        Args:
            api_key (Optional[str]): API key for FishAudio service. If not
                provided, the environment variable `FISHAUDIO_API_KEY` will be
                used.
            url (Optional[str]): Base URL for FishAudio API. If not provided,
                the environment variable `FISHAUDIO_API_BASE_URL` will be
                used (default: `https://api.fish.audio`).
        """
        # Imported lazily so this module can be imported without the SDK
        # installed.
        from fish_audio_sdk import Session

        self._api_key = api_key or os.environ.get("FISHAUDIO_API_KEY")
        self._url = url or os.environ.get(
            "FISHAUDIO_API_BASE_URL", "https://api.fish.audio"
        )
        self.session = Session(apikey=self._api_key, base_url=self._url)

    def text_to_speech(
        self,
        input: str,
        storage_path: str,
        reference_id: Optional[str] = None,
        reference_audio: Optional[str] = None,
        reference_audio_text: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Convert text to speech and save the output to a file.

        Args:
            input (str): The text to convert to speech.
            storage_path (str): The file path where the resulting speech will
                be saved.
            reference_id (Optional[str]): An optional reference ID to
                associate with the request. (default: :obj:`None`)
            reference_audio (Optional[str]): Path to an audio file for
                reference speech. (default: :obj:`None`)
            reference_audio_text (Optional[str]): Text for the reference audio.
                (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the TTS request.

        Raises:
            FileNotFoundError: If the reference audio file cannot be found.
            ValueError: If `reference_audio` is given without
                `reference_audio_text`.
        """
        from fish_audio_sdk import ReferenceAudio, TTSRequest

        # Ensure the destination directory exists; exist_ok avoids a
        # check-then-create race when it is created concurrently.
        directory = os.path.dirname(storage_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        if not reference_audio:
            # Plain TTS request, optionally tied to a stored reference ID.
            with open(storage_path, "wb") as f:
                for chunk in self.session.tts(
                    TTSRequest(reference_id=reference_id, text=input, **kwargs)
                ):
                    f.write(chunk)
        else:
            # Validate the reference inputs before opening any files.
            if not os.path.exists(reference_audio):
                raise FileNotFoundError(
                    f"Reference audio file not found: {reference_audio}"
                )
            if not reference_audio_text:
                raise ValueError("reference_audio_text should be provided")
            with open(reference_audio, "rb") as audio_file:
                with open(storage_path, "wb") as f:
                    for chunk in self.session.tts(
                        TTSRequest(
                            text=input,
                            references=[
                                ReferenceAudio(
                                    audio=audio_file.read(),
                                    text=reference_audio_text,
                                )
                            ],
                            **kwargs,
                        )
                    ):
                        f.write(chunk)

    def speech_to_text(
        self,
        audio_file_path: str,
        language: Optional[str] = None,
        ignore_timestamps: Optional[bool] = None,
        **kwargs: Any,
    ) -> str:
        r"""Convert speech to text from an audio file.

        Args:
            audio_file_path (str): The path to the audio file to transcribe.
            language (Optional[str]): The language of the audio. (default:
                :obj:`None`)
            ignore_timestamps (Optional[bool]): Whether to ignore timestamps.
                (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the STT request.

        Returns:
            str: The transcribed text from the audio.

        Raises:
            FileNotFoundError: If the audio file cannot be found.
        """
        from fish_audio_sdk import ASRRequest

        if not os.path.exists(audio_file_path):
            raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

        with open(audio_file_path, "rb") as audio_file:
            audio_data = audio_file.read()

        response = self.session.asr(
            ASRRequest(
                audio=audio_data,
                language=language,
                ignore_timestamps=ignore_timestamps,
                **kwargs,
            )
        )
        return response.text
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user