mirror of
https://github.com/camel-ai/owl.git
synced 2025-12-26 10:07:51 +08:00
initial update
This commit is contained in:
commit 62da328e7b
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
.dist
deep-swarm/data
deep-swarm/tmp
deep-swarm/.env
deep-swarm/utils/__pycache__/
45
README.md
Normal file
@@ -0,0 +1,45 @@
# DeepSwarm

## Overview

DeepSwarm is a multi-agent framework based on [camel](https://github.com/camel-ai/camel/). It achieves state-of-the-art performance among open-source frameworks on the [GAIA](https://huggingface.co/datasets/gaia-benchmark/GAIA) benchmark.

## Quickstart

It is recommended to run the code in a Linux environment.
To get started, follow these steps:

1. **Clone the GitHub repository:**

   ```bash
   $ git clone xxx
   ```

2. **Set up the Python Environment:**

   ```bash
   $ conda create -n deepswarm python=3.11
   $ conda activate deepswarm
   ```

3. **Install Dependencies:**

   ```bash
   $ pip install -r requirements.txt
   ```

4. **Set API Keys:** We use `dotenv` to manage API keys. Please copy the `.env.example` file to `.env` and fill in the necessary API keys.

5. **Run the Demo Code:**

   ```bash
   $ python run.py
   ```

## Reproduce the Results on GAIA

We provide a script to reproduce the results on GAIA. Check the `run_gaia_roleplaying.py` file and run the following command:

```bash
$ python run_gaia_roleplaying.py
```
25
deep-swarm/.env_template
Normal file
@@ -0,0 +1,25 @@

# OpenAI API
OPENAI_API_KEY=""

# Hugging Face API (https://huggingface.co/join)
HF_TOKEN=""

# Qwen API (https://help.aliyun.com/document_detail/611472.html)
QWEN_API_KEY=""


#===========================================
# Tools & Services API
#===========================================

# Google Search API (https://developers.google.com/custom-search/v1/overview)
GOOGLE_API_KEY=""
SEARCH_ENGINE_ID=""

# Chunkr API (https://chunkr.ai/)
CHUNKR_API_KEY=""

# Firecrawl API (https://www.firecrawl.dev/)
FIRECRAWL_API_KEY=""
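A minimal sketch of consuming this template with `python-dotenv` (the package the README's step 4 refers to), assuming the template has been copied to `.env` and filled in:

```python
# Sketch: load the API keys defined in the template above.
# Assumes the template was copied to `.env` in the working directory.
import os

from dotenv import load_dotenv

load_dotenv()  # reads `.env` and exports its entries into os.environ

openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
    raise RuntimeError("OPENAI_API_KEY is missing; check your .env file.")
```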
25
deep-swarm/camel/__init__.py
Normal file
@@ -0,0 +1,25 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from camel.logger import disable_logging, enable_logging, set_log_level

__version__ = '0.2.11'

__all__ = [
    '__version__',
    'camel',
    'disable_logging',
    'enable_logging',
    'set_log_level',
]
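A short sketch of the logging toggles this module exports. That `set_log_level` accepts standard `logging` levels is an assumption based on its name; the diff itself does not show its signature:

```python
# Sketch of the package-level logging controls exported above.
import logging

import camel
from camel import enable_logging, set_log_level

enable_logging()
set_log_level(logging.DEBUG)  # assumed to accept stdlib logging levels
print(camel.__version__)      # '0.2.11' in this commit
```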
BIN
deep-swarm/camel/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/__pycache__/generators.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/__pycache__/human.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/__pycache__/logger.cpython-311.pyc
Normal file
Binary file not shown.
44
deep-swarm/camel/agents/__init__.py
Normal file
@@ -0,0 +1,44 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseAgent
from .chat_agent import ChatAgent
from .critic_agent import CriticAgent
from .embodied_agent import EmbodiedAgent
from .knowledge_graph_agent import KnowledgeGraphAgent
from .role_assignment_agent import RoleAssignmentAgent
from .search_agent import SearchAgent
from .task_agent import (
    TaskCreationAgent,
    TaskPlannerAgent,
    TaskPrioritizationAgent,
    TaskSpecifyAgent,
)
from .tool_agents.base import BaseToolAgent
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent

__all__ = [
    'BaseAgent',
    'ChatAgent',
    'TaskSpecifyAgent',
    'TaskPlannerAgent',
    'TaskCreationAgent',
    'TaskPrioritizationAgent',
    'CriticAgent',
    'BaseToolAgent',
    'HuggingFaceToolAgent',
    'EmbodiedAgent',
    'RoleAssignmentAgent',
    'SearchAgent',
    'KnowledgeGraphAgent',
]
BIN
deep-swarm/camel/agents/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/agents/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/agents/__pycache__/chat_agent.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/agents/__pycache__/critic_agent.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
deep-swarm/camel/agents/__pycache__/search_agent.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/agents/__pycache__/task_agent.cpython-311.pyc
Normal file
Binary file not shown.
29
deep-swarm/camel/agents/base.py
Normal file
@@ -0,0 +1,29 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any


class BaseAgent(ABC):
    r"""An abstract base class for all CAMEL agents."""

    @abstractmethod
    def reset(self, *args: Any, **kwargs: Any) -> Any:
        r"""Resets the agent to its initial state."""
        pass

    @abstractmethod
    def step(self, *args: Any, **kwargs: Any) -> Any:
        r"""Performs a single step of the agent."""
        pass
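Since `BaseAgent` only demands `reset` and `step`, a minimal concrete subclass might look like this (an illustrative `EchoAgent`, not part of the commit):

```python
from typing import Any

from camel.agents import BaseAgent


class EchoAgent(BaseAgent):
    """Illustrative agent: records inputs and echoes the last one."""

    def __init__(self) -> None:
        self.history: list[str] = []

    def reset(self, *args: Any, **kwargs: Any) -> None:
        # Return to the initial state by clearing accumulated inputs.
        self.history.clear()

    def step(self, message: str, *args: Any, **kwargs: Any) -> str:
        # One step: store the message and echo it back.
        self.history.append(message)
        return message


agent = EchoAgent()
print(agent.step("hello"))  # -> "hello"
agent.reset()
```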
1423
deep-swarm/camel/agents/chat_agent.py
Normal file
File diff suppressed because it is too large.
202
deep-swarm/camel/agents/critic_agent.py
Normal file
@@ -0,0 +1,202 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
import warnings
from typing import Any, Dict, Optional, Sequence

from colorama import Fore

from camel.agents.chat_agent import ChatAgent
from camel.memories import AgentMemory
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import get_first_int, print_text_animated

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="CriticAgent")
class CriticAgent(ChatAgent):
    r"""A class for the critic agent that assists in selecting an option.

    Args:
        system_message (BaseMessage): The system message for the critic
            agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        memory (AgentMemory, optional): The agent memory for managing chat
            messages. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`6`)
        retry_attempts (int, optional): The number of retry attempts if the
            critic fails to return a valid option. (default: :obj:`2`)
        verbose (bool, optional): Whether to print the critic's messages.
            (default: :obj:`False`)
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        memory: Optional[AgentMemory] = None,
        message_window_size: int = 6,
        retry_attempts: int = 2,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        super().__init__(
            system_message,
            model=model,
            memory=memory,
            message_window_size=message_window_size,
        )
        self.options_dict: Dict[str, str] = dict()
        self.retry_attempts = retry_attempts
        self.verbose = verbose
        self.logger_color = logger_color

    def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
        r"""Flattens the options to the critic.

        Args:
            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.

        Returns:
            str: A string containing the flattened options to the critic.
        """
        options = [message.content for message in messages]
        flatten_options = (
            f"> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for index, option in enumerate(options):
            flatten_options += f"Option {index + 1}:\n{option}\n\n"
            self.options_dict[str(index + 1)] = option
        format = (
            f"Please first enter your choice ([1-{len(self.options_dict)}]) "
            "and then your explanation and comparison: "
        )
        return flatten_options + format

    def get_option(self, input_message: BaseMessage) -> str:
        r"""Gets the option selected by the critic.

        Args:
            input_message (BaseMessage): A `BaseMessage` object representing
                the input message.

        Returns:
            str: The option selected by the critic.
        """
        # TODO: Add support for editing options by the critic.
        msg_content = input_message.content
        i = 0
        while i < self.retry_attempts:
            critic_response = self.step(input_message)

            if critic_response.msgs is None or len(critic_response.msgs) == 0:
                raise RuntimeError("Got None critic messages.")
            if critic_response.terminated:
                raise RuntimeError("Critic step failed.")

            critic_msg = critic_response.msg
            if self.verbose:
                print_text_animated(
                    self.logger_color + "\n> Critic response: "
                    f"\x1b[3m{critic_msg.content}\x1b[0m\n"
                )
            choice = self.parse_critic(critic_msg)

            if choice in self.options_dict:
                return self.options_dict[choice]
            else:
                input_message = BaseMessage(
                    role_name=input_message.role_name,
                    role_type=input_message.role_type,
                    meta_dict=input_message.meta_dict,
                    content="> Invalid choice. Please choose again.\n"
                    + msg_content,
                )
            i += 1
        warnings.warn(
            "Critic failed to get a valid option "
            f"after {self.retry_attempts} attempts. "
            "Returning a random option."
        )
        return random.choice(list(self.options_dict.values()))

    def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
        r"""Parses the critic's message and extracts the choice.

        Args:
            critic_msg (BaseMessage): A `BaseMessage` object representing the
                critic's response.

        Returns:
            Optional[str]: The critic's choice as a string, or None if the
                message could not be parsed.
        """
        choice = str(get_first_int(critic_msg.content))
        return choice

    def reduce_step(
        self,
        input_messages: Sequence[BaseMessage],
    ) -> ChatAgentResponse:
        r"""Performs one step of the conversation by flattening options to the
        critic, getting the option, and parsing the choice.

        Args:
            input_messages (Sequence[BaseMessage]): A list of BaseMessage
                objects.

        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object that includes the
                critic's choice.
        """
        meta_chat_message = BaseMessage(
            role_name=input_messages[0].role_name,
            role_type=input_messages[0].role_type,
            meta_dict=input_messages[0].meta_dict,
            content="",
        )

        flatten_options = self.flatten_options(input_messages)
        if self.verbose:
            print_text_animated(
                self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
            )
        input_msg = meta_chat_message.create_new_instance(flatten_options)

        option = self.get_option(input_msg)
        output_msg = meta_chat_message.create_new_instance(option)

        # TODO: The return `info` can be improved.
        return ChatAgentResponse(
            msgs=[output_msg],
            terminated=False,
            info={},
        )
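A usage sketch for `reduce_step`, built only from the constructor and `BaseMessage` fields shown above. It assumes a default model backend is configured (e.g. `OPENAI_API_KEY` set) and that `RoleType.CRITIC` exists in `camel.types`:

```python
from camel.agents import CriticAgent
from camel.messages import BaseMessage
from camel.types import RoleType  # RoleType.CRITIC assumed available

critic_sys = BaseMessage(
    role_name="Critic",
    role_type=RoleType.CRITIC,
    meta_dict=None,
    content="You pick the most helpful option.",
)
critic = CriticAgent(system_message=critic_sys, verbose=True)

# Two candidate proposals from an assistant; reduce_step flattens them,
# asks the critic to choose, and returns the chosen option as a message.
options = [
    BaseMessage(
        role_name="Assistant",
        role_type=RoleType.ASSISTANT,
        meta_dict=None,
        content=text,
    )
    for text in ("Plan A: search the web.", "Plan B: read the local docs.")
]
response = critic.reduce_step(options)
print(response.msg.content)  # the option the critic selected
```

If the critic fails to produce a valid index after `retry_attempts` tries, the code above falls back to a random option and emits a warning, so callers should treat the choice as best-effort.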
303
deep-swarm/camel/agents/deductive_reasoner_agent.py
Normal file
@@ -0,0 +1,303 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, List, Optional, Union

from camel.agents.chat_agent import ChatAgent
from camel.logger import get_logger
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType

logger = get_logger(__name__)

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="DeductiveReasonerAgent")
class DeductiveReasonerAgent(ChatAgent):
    r"""An agent responsible for deductive reasoning. Model of deductive
    reasoning:
        - L: A ⊕ C -> q * B
        - A represents the known starting state.
        - B represents the known target state.
        - C represents the conditions required to transition from A to B.
        - Q represents the quality or effectiveness of the transition from
          A to B.
        - L represents the path or process from A to B.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Insight Agent",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def deduce_conditions_and_quality(
        self,
        starting_state: str,
        target_state: str,
        role_descriptions_dict: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Union[List[str], Dict[str, str]]]:
        r"""Derives the conditions and quality from the starting state and the
        target state based on the model of the deductive reasoning and the
        knowledge base. It can optionally consider the roles involved in the
        scenario, which allows tailoring the output more closely to the AI
        agent's environment.

        Args:
            starting_state (str): The initial or starting state from which
                conditions are deduced.
            target_state (str): The target state of the task.
            role_descriptions_dict (Optional[Dict[str, str]], optional): A
                dictionary describing the roles involved in the scenario. This
                is optional and can be used to provide a context for the
                CAMEL's role-playing, enabling the generation of more relevant
                and tailored conditions and quality assessments. This could be
                generated using a `RoleAssignmentAgent()` or defined manually
                by the user. (default: :obj:`None`)

        Returns:
            Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
                extracted data from the message. The dictionary contains three
                keys:
                - 'conditions': A dictionary where each key is a condition ID
                  and each value is the corresponding condition text.
                - 'labels': A list of label strings extracted from the
                  message.
                - 'evaluate_quality': A string of quality assessment extracted
                  from the message.
        """
        self.reset()

        deduce_prompt = """You are a deductive reasoner. You are tasked to
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
CONTENT to help you complete the TASK.
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
fill in the BLANKs, and DO NOT alter or modify any other part of the template.

===== MODELING OF DEDUCTIVE REASONING =====
You are tasked with understanding a mathematical model based on the components
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
- $A$ represents the known starting state.
- $B$ represents the known target state.
- $C$ represents the conditions required to transition from $A$ to $B$.
- $Q$ represents the quality or effectiveness of the transition from $A$ to
$B$.
- $L$ represents the path or process from $A$ to $B$.

===== THOUGHT OF DEDUCTIVE REASONING =====
1. Define the Parameters of A and B:
    - Characterization: Before delving into transitions, thoroughly understand
    the nature and boundaries of both $A$ and $B$. This includes the type,
    properties, constraints, and possible interactions between the two.
    - Contrast and Compare: Highlight the similarities and differences between
    $A$ and $B$. This comparative analysis will give an insight into what
    needs changing and what remains constant.
2. Historical & Empirical Analysis:
    - Previous Transitions according to the Knowledge Base of GPT: (if
    applicable) Extract conditions and patterns from the historical instances
    where a similar transition from a state comparable to $A$ moved towards
    $B$.
    - Scientific Principles: (if applicable) Consider the underlying
    scientific principles governing or related to the states and their
    transition. For example, if $A$ and $B$ are physical states, laws of
    physics might apply.
3. Logical Deduction of Conditions ($C$):
    - Direct Path Analysis: What are the immediate and direct conditions
    required to move from $A$ to $B$?
    - Intermediate States: Are there states between $A$ and $B$ that must be
    traversed or can be used to make the transition smoother or more
    efficient? If yes, what is the content?
    - Constraints & Limitations: Identify potential barriers or restrictions
    in moving from $A$ to $B$. These can be external (e.g., environmental
    factors) or internal (properties of $A$ or $B$).
    - Resource and Information Analysis: What resources and information are
    required for the transition? This could be time, entity, factor, code
    language, software platform, unknowns, etc.
    - External Influences: Consider socio-economic, political, or
    environmental factors (if applicable) that could influence the transition
    conditions.
    - Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
    no matter how unconventional they might seem. Utilize analogies,
    metaphors, or brainstorming techniques to envision possible conditions or
    paths from $A$ to $B$.
    - The conditions $C$ should be multiple but in one sentence. And each
    condition should be concerned with one aspect/entity.
4. Entity/Label Recognition of Conditions ($C$):
    - Identify and categorize entities of Conditions ($C$) such as the names,
    locations, dates, specific technical terms or contextual parameters that
    might be associated with events, innovations post-2022.
    - The output of the entities/labels will be used as tags or labels for
    semantic similarity searches. The entities/labels may be the words, or
    phrases, each of them should contain valuable, high information entropy
    information, and should be independent.
    - Ensure that the identified entities are formatted in a manner suitable
    for database indexing and retrieval. Organize the entities into
    categories, and combine the category with its instance into a continuous
    phrase, without using colons or other separators.
    - Format these entities for database indexing: output the category rather
    than its instance/content into a continuous phrase. For example, instead
    of "Jan. 02", identify it as "Event time".
5. Quality Assessment ($Q$):
    - Efficiency: How efficient is the transition from $A$ to $B$, which
    measures the resources used versus the desired outcome?
    - Effectiveness: Did the transition achieve the desired outcome or was the
    target state achieved as intended?
    - Safety & Risks: Assess any risks associated with the transition and the
    measures to mitigate them.
    - Feedback Mechanisms: Incorporate feedback loops to continuously monitor
    and adjust the quality of transition, making it more adaptive.
6. Iterative Evaluation:
    - Test & Refine: Based on the initially deduced conditions and assessed
    quality, iterate the process to refine and optimize the transition. This
    might involve tweaking conditions, employing different paths, or changing
    resources.
    - Feedback Integration: Use feedback to make improvements and increase the
    quality of the transition.
7. Real-world scenarios often present challenges that may not be captured by
models and frameworks. While using the model, maintain an adaptive mindset:
    - Scenario Exploration: Continuously imagine various possible scenarios,
    both positive and negative, to prepare for unexpected events.
    - Flexibility: Be prepared to modify conditions ($C$) or alter the path/
    process ($L$) if unforeseen challenges arise.
    - Feedback Integration: Rapidly integrate feedback from actual
    implementations to adjust the model's application, ensuring relevancy and
    effectiveness.

===== TASK =====
Given the starting state $A$ and the target state $B$, assuming that a path
$L$ always exists between $A$ and $B$, how can one deduce or identify the
necessary conditions $C$ and the quality $Q$ of the transition?

===== STARTING STATE $A$ =====
{starting_state}

===== TARGET STATE $B$ =====
{target_state}

{role_with_description_prompt}
===== ANSWER TEMPLATE =====
- Characterization and comparison of $A$ and $B$:\n<BLANK>
- Historical & Empirical Analysis:\n<BLANK>/None
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
    condition <NUM>:
        <BLANK>.
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
square brackets)
- Quality Assessment ($Q$) (do not use symbols):
    <BLANK>.
- Iterative Evaluation:\n<BLANK>/None"""

        if role_descriptions_dict is not None:
            role_names = role_descriptions_dict.keys()
            role_with_description_prompt = (
                "===== ROLES WITH DESCRIPTIONS =====\n"
                + "\n".join(
                    f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
                    for role_name in role_names
                )
                + "\n\n"
            )
        else:
            role_with_description_prompt = ""
        deduce_prompt = TextPrompt(deduce_prompt)

        deduce = deduce_prompt.format(
            starting_state=starting_state,
            target_state=target_state,
            role_with_description_prompt=role_with_description_prompt,
        )

        conditions_and_quality_generation_msg = BaseMessage.make_user_message(
            role_name="Deductive Reasoner", content=deduce
        )

        response = self.step(
            input_message=conditions_and_quality_generation_msg
        )

        if response.terminated:
            raise RuntimeError(
                "Deduction failed. Error:\n" + f"{response.info}"
            )
        msg: BaseMessage = response.msg
        logger.info(f"Message content:\n{msg.content}")

        # Extract the conditions from the message
        conditions_dict = {
            f"condition {i}": cdt.replace("<", "")
            .replace(">", "")
            .strip()
            .strip('\n')
            for i, cdt in re.findall(
                r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
                msg.content,
                re.DOTALL,
            )
        }

        # Extract the labels from the message
        labels = [
            label.strip().strip('\n').strip("\"'")
            for label in re.findall(
                r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
                msg.content,
                re.DOTALL,
            )[0].split(",")
        ]

        # Extract the quality from the message
        quality = next(
            q.strip().strip('\n')
            for q in re.findall(
                r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
                r"\n(.+?)- Iterative",
                msg.content,
                re.DOTALL,
            )
        )

        # Convert them into JSON format
        conditions_and_quality_json: Dict[
            str, Union[List[str], Dict[str, str]]
        ] = {}
        conditions_and_quality_json["conditions"] = conditions_dict
        conditions_and_quality_json["labels"] = labels
        conditions_and_quality_json["evaluate_quality"] = quality

        return conditions_and_quality_json
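A usage sketch for `deduce_conditions_and_quality` (the states are illustrative; a configured default model backend is assumed). The returned keys mirror the parsing code above:

```python
from camel.agents.deductive_reasoner_agent import DeductiveReasonerAgent

reasoner = DeductiveReasonerAgent()
result = reasoner.deduce_conditions_and_quality(
    starting_state="A web page listing raw benchmark scores",
    target_state="A cleaned CSV of scores suitable for plotting",
)

# Keys match the extraction code above.
print(result["conditions"])        # {'condition 1': ..., ...}
print(result["labels"])            # ['...', ...]
print(result["evaluate_quality"])  # free-text quality assessment
```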
201
deep-swarm/camel/agents/embodied_agent.py
Normal file
@@ -0,0 +1,201 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional

from colorama import Fore

from camel.agents.chat_agent import ChatAgent
from camel.agents.tool_agents.base import BaseToolAgent
from camel.interpreters import (
    BaseInterpreter,
    InternalPythonInterpreter,
    SubprocessInterpreter,
)
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import print_text_animated

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="EmbodiedAgent")
class EmbodiedAgent(ChatAgent):
    r"""Class for managing conversations of CAMEL Embodied Agents.

    Args:
        system_message (BaseMessage): The system message for the chat agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        tool_agents (List[BaseToolAgent], optional): The tool agents to use in
            the embodied agent. (default: :obj:`None`)
        code_interpreter (BaseInterpreter, optional): The code interpreter to
            execute codes. If `code_interpreter` and `tool_agents` are both
            `None`, default to `SubprocessInterpreter`. If `code_interpreter`
            is `None` and `tool_agents` is not `None`, default to
            `InternalPythonInterpreter`. (default: :obj:`None`)
        verbose (bool, optional): Whether to print the critic's messages.
            (default: :obj:`False`)
        logger_color (Any): The color of the logger displayed to the user.
            (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        message_window_size: Optional[int] = None,
        tool_agents: Optional[List[BaseToolAgent]] = None,
        code_interpreter: Optional[BaseInterpreter] = None,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.tool_agents = tool_agents
        self.code_interpreter: BaseInterpreter
        if code_interpreter is not None:
            self.code_interpreter = code_interpreter
        elif self.tool_agents:
            self.code_interpreter = InternalPythonInterpreter()
        else:
            self.code_interpreter = SubprocessInterpreter()

        if self.tool_agents:
            system_message = self._set_tool_agents(system_message)
        self.verbose = verbose
        self.logger_color = logger_color
        super().__init__(
            system_message=system_message,
            model=model,
            message_window_size=message_window_size,
        )

    def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
        action_space_prompt = self._get_tool_agents_prompt()
        result_message = system_message.create_new_instance(
            content=system_message.content.format(
                action_space=action_space_prompt
            )
        )
        if self.tool_agents is not None:
            self.code_interpreter.update_action_space(
                {tool.name: tool for tool in self.tool_agents}
            )
        return result_message

    def _get_tool_agents_prompt(self) -> str:
        r"""Returns the action space prompt.

        Returns:
            str: The action space prompt.
        """
        if self.tool_agents is not None:
            return "\n".join(
                [
                    f"*** {tool.name} ***:\n {tool.description}"
                    for tool in self.tool_agents
                ]
            )
        else:
            return ""

    def get_tool_agent_names(self) -> List[str]:
        r"""Returns the names of tool agents.

        Returns:
            List[str]: The names of tool agents.
        """
        if self.tool_agents is not None:
            return [tool.name for tool in self.tool_agents]
        else:
            return []

    # ruff: noqa: E501
    def step(self, input_message: BaseMessage) -> ChatAgentResponse:  # type: ignore[override]
        r"""Performs a step in the conversation.

        Args:
            input_message (BaseMessage): The input message.

        Returns:
            ChatAgentResponse: A struct containing the output messages,
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.
        """
        response = super().step(input_message)

        if response.msgs is None or len(response.msgs) == 0:
            raise RuntimeError("Got None output messages.")
        if response.terminated:
            raise RuntimeError(f"{self.__class__.__name__} step failed.")

        # NOTE: Only single output messages are supported
        explanations, codes = response.msg.extract_text_and_code_prompts()

        if self.verbose:
            for explanation, code in zip(explanations, codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanation}"
                )
                print_text_animated(self.logger_color + f"> Code:\n{code}")

            if len(explanations) > len(codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanations[-1]}"
                )

        content = response.msg.content

        if codes is not None:
            try:
                content = "\n> Executed Results:\n"
                for block_idx, code in enumerate(codes):
                    executed_output = self.code_interpreter.run(
                        code, code.code_type
                    )
                    content += (
                        f"Executing code block {block_idx}: {{\n"
                        + executed_output
                        + "}\n"
                    )
            except InterruptedError as e:
                content = (
                    f"\n> Running code failed: {e}\n"
                    "Please regenerate the code."
                )

        # TODO: Handle errors
        content = input_message.content + f"\n> Embodied Actions:\n{content}"
        message = BaseMessage(
            input_message.role_name,
            input_message.role_type,
            input_message.meta_dict,
            content,
        )
        return ChatAgentResponse(
            msgs=[message],
            terminated=response.terminated,
            info=response.info,
        )
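A construction sketch for `EmbodiedAgent`. Per the `__init__` logic above, with no `tool_agents` and no `code_interpreter` it falls back to `SubprocessInterpreter`, so any generated code actually runs on your machine; treat this as a sandboxing concern. A default model backend is assumed:

```python
from camel.agents import EmbodiedAgent
from camel.messages import BaseMessage
from camel.types import RoleType

sys_msg = BaseMessage(
    role_name="Coder",
    role_type=RoleType.ASSISTANT,
    meta_dict=None,
    content="You write and execute small Python snippets to answer questions.",
)
# No tool_agents and no code_interpreter: the SubprocessInterpreter path.
agent = EmbodiedAgent(system_message=sys_msg, verbose=True)

user_msg = BaseMessage.make_user_message(
    role_name="User",
    content="Compute the sum of the first 10 squares and print it.",
)
response = agent.step(user_msg)
print(response.msg.content)  # includes "> Embodied Actions:" with results
```

Note that a system message containing an `{action_space}` placeholder is only required when `tool_agents` is supplied, since `_set_tool_agents` formats that field into the message.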
259
deep-swarm/camel/agents/knowledge_graph_agent.py
Normal file
@@ -0,0 +1,259 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    from unstructured.documents.elements import Element

from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.storages.graph_storages.graph_element import (
    GraphElement,
    Node,
    Relationship,
)
from camel.types import RoleType

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


text_prompt = """
You are tasked with extracting nodes and relationships from given content and
structuring them into Node and Relationship objects. Here's the outline of
what you need to do:

Content Extraction:
You should be able to process input content and identify entities mentioned
within it.
Entities can be any noun phrases or concepts that represent distinct entities
in the context of the given content.

Node Extraction:
For each identified entity, you should create a Node object.
Each Node object should have a unique identifier (id) and a type (type).
Additional properties associated with the node can also be extracted and
stored.

Relationship Extraction:
You should identify relationships between entities mentioned in the content.
For each relationship, create a Relationship object.
A Relationship object should have a subject (subj) and an object (obj) which
are Node objects representing the entities involved in the relationship.
Each relationship should also have a type (type), and additional properties if
applicable.

Output Formatting:
The extracted nodes and relationships should be formatted as instances of the
provided Node and Relationship classes.
Ensure that the extracted data adheres to the structure defined by the classes.
Output the structured data in a format that can be easily validated against
the provided code.

Instructions for you:
Read the provided content thoroughly.
Identify distinct entities mentioned in the content and categorize them as
nodes.
Determine relationships between these entities and represent them as directed
relationships.
Provide the extracted nodes and relationships in the specified format below.
Example for you:

Example Content:
"John works at XYZ Corporation. He is a software engineer. The company is
located in New York City."

Expected Output:

Nodes:

Node(id='John', type='Person')
Node(id='XYZ Corporation', type='Organization')
Node(id='New York City', type='Location')

Relationships:

Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
Corporation', type='Organization'), type='WorksAt')
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
type='Location'), type='ResidesIn')

===== TASK =====
Please extract nodes and relationships from the given content and structure
them into Node and Relationship objects.

{task}
"""


@track_agent(name="KnowledgeGraphAgent")
class KnowledgeGraphAgent(ChatAgent):
    r"""An agent that can extract node and relationship information for
    different entities from given `Element` content.

    Attributes:
        task_prompt (TextPrompt): A prompt for the agent to extract node and
            relationship information for different entities.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        r"""Initialize the `KnowledgeGraphAgent`.

        Args:
            model (BaseModelBackend, optional): The model backend to use for
                generating responses. (default: :obj:`OpenAIModel` with
                `GPT_4O_MINI`)
        """
        system_message = BaseMessage(
            role_name="Graphify",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="Your mission is to transform unstructured content "
            "into structured graph data. Extract nodes and relationships with "
            "precision, and let the connections unfold. Your graphs will "
            "illuminate the hidden connections within the chaos of "
            "information.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        element: "Element",
        parse_graph_elements: bool = False,
    ) -> Union[str, GraphElement]:
        r"""Run the agent to extract node and relationship information.

        Args:
            element (Element): The input element.
            parse_graph_elements (bool, optional): Whether to parse into
                `GraphElement`. Defaults to `False`.

        Returns:
            Union[str, GraphElement]: The extracted node and relationship
                information. If `parse_graph_elements` is `True` then return
                `GraphElement`, else return `str`.
        """
        self.reset()
        self.element = element

        knowledge_graph_prompt = TextPrompt(text_prompt)
        knowledge_graph_generation = knowledge_graph_prompt.format(
            task=str(element)
        )

        knowledge_graph_generation_msg = BaseMessage.make_user_message(
            role_name="Graphify", content=knowledge_graph_generation
        )

        response = self.step(input_message=knowledge_graph_generation_msg)

        content = response.msg.content

        if parse_graph_elements:
            content = self._parse_graph_elements(content)

        return content

    def _validate_node(self, node: Node) -> bool:
        r"""Validate if the object is a valid Node.

        Args:
            node (Node): Object to be validated.

        Returns:
            bool: True if the object is a valid Node, False otherwise.
        """
        return (
            isinstance(node, Node)
            and isinstance(node.id, (str, int))
            and isinstance(node.type, str)
        )

    def _validate_relationship(self, relationship: Relationship) -> bool:
        r"""Validate if the object is a valid Relationship.

        Args:
            relationship (Relationship): Object to be validated.

        Returns:
            bool: True if the object is a valid Relationship, False otherwise.
        """
        return (
            isinstance(relationship, Relationship)
            and self._validate_node(relationship.subj)
            and self._validate_node(relationship.obj)
            and isinstance(relationship.type, str)
        )

    def _parse_graph_elements(self, input_string: str) -> GraphElement:
        r"""Parses graph elements from given content.

        Args:
            input_string (str): The input content.

        Returns:
            GraphElement: The parsed graph elements.
        """
        import re

        # Regular expressions to extract nodes and relationships
        node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
        rel_pattern = (
            r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
            r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
        )

        nodes = {}
        relationships = []

        # Extract nodes
        for match in re.finditer(node_pattern, input_string):
            id, type = match.groups()
            properties = {'source': 'agent_created'}
            if id not in nodes:
                node = Node(id=id, type=type, properties=properties)
                if self._validate_node(node):
                    nodes[id] = node

        # Extract relationships
        for match in re.finditer(rel_pattern, input_string):
            subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
            properties = {'source': 'agent_created'}
            if subj_id in nodes and obj_id in nodes:
                subj = nodes[subj_id]
                obj = nodes[obj_id]
                relationship = Relationship(
                    subj=subj, obj=obj, type=rel_type, properties=properties
                )
                if self._validate_relationship(relationship):
                    relationships.append(relationship)

        return GraphElement(
            nodes=list(nodes.values()),
            relationships=relationships,
            source=self.element,
        )
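A usage sketch for running the agent over an `unstructured` element. The `Text` element class is an assumption about the `unstructured` API (any `Element` subclass should do), and a default model backend is assumed:

```python
from unstructured.documents.elements import Text  # assumed Element subclass

from camel.agents import KnowledgeGraphAgent

kg_agent = KnowledgeGraphAgent()
element = Text(text="Alice works at Acme Corp, which is based in Berlin.")

# parse_graph_elements=True runs the regex parser above and returns a
# GraphElement holding validated Node and Relationship objects.
graph = kg_agent.run(element, parse_graph_elements=True)
print(graph.nodes)
print(graph.relationships)
```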
141
deep-swarm/camel/agents/role_assignment_agent.py
Normal file
@@ -0,0 +1,141 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, Optional, Union

from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="RoleAssignmentAgent")
class RoleAssignmentAgent(ChatAgent):
    r"""An agent that generates role names based on the task prompt.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)

    Attributes:
        role_assignment_prompt (TextPrompt): A prompt for the agent to
            generate role names.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Role Assigner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        num_roles: int = 2,
    ) -> Dict[str, str]:
        r"""Generate role names based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt
                for the task based on which the roles are to be generated.
            num_roles (int, optional): The number of roles to generate.
                (default: :obj:`2`)

        Returns:
            Dict[str, str]: A dictionary mapping role names to their
                descriptions.
        """
        self.reset()

        expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
            f"Domain expert {i + 1}: <BLANK>\n"
            f"Associated competencies, characteristics, duties "
            f"and workflows: <BLANK>. End."
            for i in range(num_roles or 0)
        )
        role_assignment_generation_prompt = TextPrompt(
            "You are a role assignment agent, and you're in charge of "
            + "recruiting {num_roles} experts for the following task."
            + "\n===== TASK =====\n {task}\n\n"
            + "Identify the domain experts you'd recruit and detail their "
            + "associated competencies, characteristics, duties and workflows "
            + "to complete the task.\n "
            + "Your answer MUST adhere to the format of ANSWER PROMPT, and "
            + "ONLY answer the BLANKs.\n"
            + expert_prompt
        )
        role_assignment_generation = role_assignment_generation_prompt.format(
            num_roles=num_roles, task=task_prompt
        )

        role_assignment_generation_msg = BaseMessage.make_user_message(
            role_name="Role Assigner", content=role_assignment_generation
        )

        response = self.step(input_message=role_assignment_generation_msg)

        msg = response.msg  # type: BaseMessage
        terminated = response.terminated

        # Distribute the output completions into role names and descriptions
        role_names = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Domain expert \d+: (.+?)\nAssociated competencies,",
                msg.content,
                re.DOTALL,
            )
        ]
        role_descriptions = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Associated competencies, characteristics, "
                r"duties and workflows: (.+?) End.",
                msg.content,
                re.DOTALL,
            )
        ]

        if len(role_names) != num_roles or len(role_descriptions) != num_roles:
            raise RuntimeError(
                "Got none or insufficient information about roles."
            )
        if terminated:
            raise RuntimeError("Role assignment failed.")

        role_descriptions_dict = {
            role_name: description
            for role_name, description in zip(role_names, role_descriptions)
        }

        return role_descriptions_dict
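A usage sketch for `run` (the task is illustrative; a configured default model backend is assumed). It returns the role-name-to-description mapping built at the end of the method above:

```python
from camel.agents import RoleAssignmentAgent

role_agent = RoleAssignmentAgent()
roles = role_agent.run(
    task_prompt="Design a data pipeline that ingests and ranks news articles",
    num_roles=3,
)
# Keys are role names, values are the "competencies, characteristics,
# duties and workflows" text parsed from the model's answer.
for name, description in roles.items():
    print(f"{name}: {description[:80]}...")
```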
133
deep-swarm/camel/agents/search_agent.py
Normal file
133
deep-swarm/camel/agents/search_agent.py
Normal file
@ -0,0 +1,133 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Optional
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.types import RoleType
|
||||
from camel.utils import create_chunks
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="SearchAgent")
|
||||
class SearchAgent(ChatAgent):
    r"""An agent that summarizes text based on a query and evaluates the
    relevance of an answer.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Assistant",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful assistant.",
        )
        super().__init__(system_message, model=model)

    def summarize_text(self, text: str, query: str) -> str:
        r"""Summarize the information from the text, based on the query.

        Args:
            text (str): Text to summarize.
            query (str): What information you want.

        Returns:
            str: Strings with information.
        """
        self.reset()

        summary_prompt = TextPrompt(
            '''Gather information from this text that is relevant to the
question, but do not directly answer the question.\nquestion:
{query}\ntext '''
        )
        summary_prompt = summary_prompt.format(query=query)
        # Max length of each chunk
        max_len = 3000
        results = ""
        chunks = create_chunks(text, max_len)
        # Summarize
        for i, chunk in enumerate(chunks, start=1):
            prompt = summary_prompt + str(i) + ": " + chunk
            user_msg = BaseMessage.make_user_message(
                role_name="User",
                content=prompt,
            )
            result = self.step(user_msg).msg.content
            results += result + "\n"

        # Final summarization
        final_prompt = TextPrompt(
            '''Here are some summarized texts which were split from one text.
Use the information to answer the question. If you can't find the answer,
you must answer "I can not find the answer to the query" and
explain why.\n Query:\n{query}.\n\nText:\n'''
        )
        final_prompt = final_prompt.format(query=query)
        prompt = final_prompt + results

        user_msg = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.step(user_msg).msg.content

        return response

    def continue_search(self, query: str, answer: str) -> bool:
        r"""Ask whether to continue the search or not based on the provided
        answer.

        Args:
            query (str): The question.
            answer (str): The answer to the question.

        Returns:
            bool: `True` if the user wants to continue the search, `False`
                otherwise.
        """
        prompt = TextPrompt(
            "Do you think the ANSWER can answer the QUERY? "
            "Use only 'yes' or 'no' to answer.\n"
            "===== QUERY =====\n{query}\n\n"
            "===== ANSWER =====\n{answer}"
        )
        prompt = prompt.format(query=query, answer=answer)
        user_msg = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.step(user_msg).msg.content
        if "yes" in str(response).lower():
            return False
        return True
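A minimal usage sketch for `SearchAgent` (not part of the committed file), assuming the default OpenAI backend is configured via `OPENAI_API_KEY` and that `SearchAgent` is exported from `camel.agents`; `article.txt` and the query are placeholders:

```python
from camel.agents import SearchAgent  # assumed export path

agent = SearchAgent()  # falls back to the default model backend

long_text = open("article.txt").read()  # hypothetical input document
query = "What dataset was used in the experiments?"

# Chunks the text, summarizes each chunk, then synthesizes an answer.
answer = agent.summarize_text(long_text, query)

# `continue_search` returns True when the answer does NOT satisfy the query.
if agent.continue_search(query, answer):
    print("Answer insufficient; keep searching.")
else:
    print(answer)
```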
410
deep-swarm/camel/agents/task_agent.py
Normal file
@ -0,0 +1,410 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union

from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType
from camel.utils import get_task_list

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="TaskSpecifyAgent")
class TaskSpecifyAgent(ChatAgent):
    r"""An agent that specifies a given task prompt by prompting the user to
    provide more details.

    Attributes:
        DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
        task_specify_prompt (TextPrompt): The prompt for specifying the task.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        task_type (TaskType, optional): The type of task for which to generate
            a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
        task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
            specifying the task. (default: :obj:`None`)
        word_limit (int, optional): The word limit for the task prompt.
            (default: :obj:`50`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    DEFAULT_WORD_LIMIT = 50

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        task_type: TaskType = TaskType.AI_SOCIETY,
        task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
        word_limit: int = DEFAULT_WORD_LIMIT,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_specify_prompt: Union[str, TextPrompt]
        if task_specify_prompt is None:
            task_specify_prompt_template = (
                PromptTemplateGenerator().get_task_specify_prompt(task_type)
            )

            self.task_specify_prompt = task_specify_prompt_template.format(
                word_limit=word_limit
            )
        else:
            self.task_specify_prompt = TextPrompt(task_specify_prompt)

        system_message = BaseMessage(
            role_name="Task Specifier",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You can make a task more specific.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        meta_dict: Optional[Dict[str, Any]] = None,
    ) -> TextPrompt:
        r"""Specify the given task prompt by providing more details.

        Args:
            task_prompt (Union[str, TextPrompt]): The original task
                prompt.
            meta_dict (Dict[str, Any], optional): A dictionary containing
                additional information to include in the prompt.
                (default: :obj:`None`)

        Returns:
            TextPrompt: The specified task prompt.
        """
        self.reset()
        task_specify_prompt = self.task_specify_prompt.format(task=task_prompt)

        if meta_dict is not None:
            task_specify_prompt = task_specify_prompt.format(**meta_dict)
        task_msg = BaseMessage.make_user_message(
            role_name="Task Specifier", content=task_specify_prompt
        )
        specifier_response = self.step(task_msg)

        if specifier_response.terminated:
            raise RuntimeError("Task specification failed.")
        if len(specifier_response.msgs) == 0:
            raise RuntimeError("Got no specification message.")

        specified_task_msg = specifier_response.msgs[0]

        return TextPrompt(specified_task_msg.content)
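A usage sketch for `TaskSpecifyAgent` (illustrative, not from the commit); the `AI_SOCIETY` template is assumed here to expect assistant and user roles via `meta_dict`:

```python
from camel.agents import TaskSpecifyAgent  # assumed export path
from camel.types import TaskType

agent = TaskSpecifyAgent(task_type=TaskType.AI_SOCIETY, word_limit=50)

# Vague task in, concrete task out (as a TextPrompt).
specified = agent.run(
    "Improving stage presence and performance skills",
    meta_dict=dict(assistant_role="Musician", user_role="Student"),
)
print(specified)
```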


@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
    r"""An agent that helps divide a task into subtasks based on the input
    task prompt.

    Attributes:
        task_planner_prompt (TextPrompt): A prompt for the agent to divide
            the task into subtasks.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_planner_prompt = TextPrompt(
            "Divide this task into subtasks: {task}. Be concise."
        )
        system_message = BaseMessage(
            role_name="Task Planner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task planner.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
    ) -> TextPrompt:
        r"""Generate subtasks based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task to
                be divided into subtasks.

        Returns:
            TextPrompt: A prompt for the subtasks generated by the agent.
        """
        # TODO: Maybe include roles information.
        self.reset()
        task_planner_prompt = self.task_planner_prompt.format(task=task_prompt)

        task_msg = BaseMessage.make_user_message(
            role_name="Task Planner", content=task_planner_prompt
        )

        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task planning failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task planning message.")

        sub_tasks_msg = task_response.msgs[0]
        return TextPrompt(sub_tasks_msg.content)
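A corresponding sketch for `TaskPlannerAgent`, which needs no task type or meta dict; the task string is made up for illustration:

```python
from camel.agents import TaskPlannerAgent  # assumed export path

planner = TaskPlannerAgent()
subtasks = planner.run("Write a survey on multi-agent LLM frameworks")
print(subtasks)  # a single TextPrompt listing the generated subtasks
```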


@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
    r"""An agent that helps create new tasks based on the objective
    and last completed task. Compared to :obj:`TaskPlannerAgent`,
    it's still a task planner, but it has more context information
    like last task and incomplete task list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_creation_prompt (TextPrompt): A prompt for the agent to
            create new tasks.

    Args:
        role_name (str): The role name of the Agent to create the task.
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        max_task_num (int, optional): The maximum number of planned
            tasks in one round. (default: :obj:`3`)
    """

    def __init__(
        self,
        role_name: str,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
        max_task_num: Optional[int] = 3,
    ) -> None:
        task_creation_prompt = TextPrompt(
            """Create a new task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The newly created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:

#. First Task
#. Second Task
#. Third Task

You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make a task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
        )

        self.task_creation_prompt = task_creation_prompt.format(
            objective=objective, role_name=role_name, max_task_num=max_task_num
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Creator",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task creator.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Generate subtasks based on the previous task results and
        incomplete task list.

        Args:
            task_list (List[str]): The completed or in-progress
                tasks which should not overlap with new created tasks.

        Returns:
            List[str]: The new task list generated by the Agent.
        """

        if len(task_list) > 0:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=task_list
            )
        else:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=""
            )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Creator", content=task_creation_prompt
        )
        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task creation failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task creation message.")

        sub_tasks_msg = task_response.msgs[0]
        return get_task_list(sub_tasks_msg.content)
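A sketch of one planning round with `TaskCreationAgent`; the objective and seed task list below are hypothetical:

```python
from camel.agents import TaskCreationAgent  # assumed export path

creator = TaskCreationAgent(
    role_name="Researcher",
    objective="Reproduce the GAIA benchmark results",
    max_task_num=3,
)

# Previously solved or in-progress tasks are passed in so that the
# newly created tasks do not overlap with them.
new_tasks = creator.run(task_list=["Set up the evaluation environment"])
print(new_tasks)  # up to three new task strings
```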


@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
    r"""An agent that helps re-prioritize the task list and
    returns a numbered prioritized list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_prioritization_prompt (TextPrompt): A prompt for the agent to
            prioritize tasks.

    Args:
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        task_prioritization_prompt = TextPrompt(
            """Prioritize the following tasks: {task_list}.
Consider your ultimate objective: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:

#. First task
#. Second task

The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
        )

        self.task_prioritization_prompt = task_prioritization_prompt.format(
            objective=objective
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Prioritizer",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task prioritizer.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Prioritize the task list given the agent objective.

        Args:
            task_list (List[str]): The unprioritized tasks of the agent.

        Returns:
            List[str]: The new prioritized task list generated by the Agent.
        """
        task_prioritization_prompt = self.task_prioritization_prompt.format(
            task_list=task_list
        )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Prioritizer", content=task_prioritization_prompt
        )

        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task prioritization failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task prioritization message.")

        sub_tasks_msg = task_response.msgs[0]
        return get_task_list(sub_tasks_msg.content)
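`TaskPrioritizationAgent` closes the BabyAGI-style loop by re-ordering whatever `TaskCreationAgent` produced; again the objective and task strings are placeholders:

```python
from camel.agents import TaskPrioritizationAgent  # assumed export path

prioritizer = TaskPrioritizationAgent(
    objective="Reproduce the GAIA benchmark results"
)
ordered = prioritizer.run(
    task_list=["Write the report", "Download the dataset", "Run the agent"]
)
print(ordered)  # same tasks, sorted from highest to lowest priority
```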
20
deep-swarm/camel/agents/tool_agents/__init__.py
Normal file
@ -0,0 +1,20 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseToolAgent
from .hugging_face_tool_agent import HuggingFaceToolAgent

__all__ = [
    'BaseToolAgent',
    'HuggingFaceToolAgent',
]
39
deep-swarm/camel/agents/tool_agents/base.py
Normal file
@ -0,0 +1,39 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.agents import BaseAgent


class BaseToolAgent(BaseAgent):
    r"""Creates a :obj:`BaseToolAgent` object with the specified name and
    description.

    Args:
        name (str): The name of the tool agent.
        description (str): The description of the tool agent.
    """

    def __init__(self, name: str, description: str) -> None:
        self.name = name
        self.description = description

    def reset(self) -> None:
        r"""Resets the agent to its initial state."""
        pass

    def step(self) -> None:
        r"""Performs a single step of the agent."""
        pass

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
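Since `BaseToolAgent` only stores a name and description, a subclass mainly supplies its own `step` logic. A toy subclass sketch (hypothetical; the widened `step` signature mirrors what `HuggingFaceToolAgent` below does with `*args`):

```python
from camel.agents.tool_agents.base import BaseToolAgent


class EchoToolAgent(BaseToolAgent):
    r"""A trivial tool agent that returns its input unchanged."""

    def __init__(self) -> None:
        super().__init__(
            name="echo",
            description="Returns whatever text it is given.",
        )

    def step(self, text: str) -> str:  # widened signature for illustration
        return text


agent = EchoToolAgent()
print(agent)  # -> "echo: Returns whatever text it is given."
```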
206
deep-swarm/camel/agents/tool_agents/hugging_face_tool_agent.py
Normal file
@ -0,0 +1,206 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Optional

from camel.agents.tool_agents.base import BaseToolAgent


# flake8: noqa :E501
class HuggingFaceToolAgent(BaseToolAgent):
    r"""Tool agent for calling HuggingFace models. This agent is a wrapper
    around agents from the `transformers` library. For more information
    about the available models, please see the `transformers` documentation
    at https://huggingface.co/docs/transformers/transformers_agents.

    Args:
        name (str): The name of the agent.
        *args (Any): Additional positional arguments to pass to the underlying
            Agent class.
        remote (bool, optional): Flag indicating whether to run the agent
            remotely. (default: :obj:`True`)
        **kwargs (Any): Additional keyword arguments to pass to the underlying
            Agent class.
    """

    def __init__(
        self,
        name: str,
        *args: Any,
        remote: bool = True,
        **kwargs: Any,
    ) -> None:
        try:
            # TODO: Support other tool agents
            import transformers
            from packaging import version

            if version.parse(transformers.__version__) < version.parse(
                "4.31.0"
            ):
                raise ValueError(
                    "The version of \"transformers\" package should be >= 4.31.0"
                )

            from transformers.tools import OpenAiAgent
            from transformers.tools.agent_types import AgentImage
        except (ImportError, ValueError):
            raise ValueError(
                "Could not import transformers tool agents. "
                "Please set up the environment with "
                "pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
            )
        self.agent_image_type = AgentImage
        self.agent = OpenAiAgent(*args, **kwargs)
        description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
- Text question answering: given a long text and a question, answer the question in the text
- Unconditional image captioning: caption the image!
- Image question answering: given an image, answer a question on this image
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
- Text to speech: convert text to speech
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
- Text summarization: summarize a long text in one or a few sentences
- Translation: translate the text into a given language
- Text downloading: download a text from a web URL
- Text to image: generate an image according to a prompt, leveraging stable diffusion
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- Text to video: generate a small video according to a prompt

Here are some python code examples of what you can do with this agent:

Single execution (step) mode, using the step() method of the agent:
```
# Text to image
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")

# Text to image -> Image transformation
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
sea_add_island_image.save("./sea_add_island_image.png")

# If you'd like to keep a state across executions or to pass non-text objects to the agent,
# you can do so by specifying variables that you would like the agent to use. For example,
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
picture = {name}.step("Generate a picture of rivers and lakes.")
picture.save("./picture.png")
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
updated_picture.save("./updated_picture.png")

capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
capybara_sea_image.save("./capybara_sea_image.png")

# Document question answering
answer = {name}.step(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
print(answer)


# Text to image
boat_image = {name}.step("Generate an image of a boat in the water")
boat_image.save("./boat_image.png")

# Unconditional image captioning
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
print(boat_image_caption)

# Text to image -> Unconditional image captioning -> Text to speech
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")

# Text downloading
document = {name}.step("Download the text from http://hf.co")
print(document)

# Text summarization
summary = {name}.step("Summarize the following text: `document`", document=document)
print(summary)

# Text downloading -> Text summarization -> Text to speech
audio = {name}.step("Read out loud the summary of http://hf.co")
```

Chat-based execution (chat) mode, using the chat() method of the agent:
```
# Clean the chat history
{name}.reset()

# Text to image
capybara_image = {name}.chat("Show me an image of a capybara")
capybara_image.save("./capybara_image.png")

# Image transformation
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
transformed_capybara_image.save("./transformed_capybara_image.png")

# Image segmentation
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
```
"""
        super(HuggingFaceToolAgent, self).__init__(name, description)
        self.remote = remote

    def reset(self) -> None:
        r"""Resets the chat history of the agent."""
        self.agent.prepare_for_new_chat()

    def step(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in single execution mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.run(*args, remote=remote, **kwargs)
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

    def chat(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in a chat conversation mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.chat(*args, remote=remote, **kwargs)
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output
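An instantiation sketch; the extra keyword arguments are forwarded verbatim to `transformers`' `OpenAiAgent`, so the `model` and `api_key` names below follow that class and are assumptions rather than part of this wrapper:

```python
from camel.agents.tool_agents import HuggingFaceToolAgent

hf_agent = HuggingFaceToolAgent(
    "hf_agent",        # name used inside the generated description
    model="gpt-4",     # forwarded to OpenAiAgent (assumed kwarg)
    api_key="sk-...",  # placeholder key, forwarded to OpenAiAgent
)

image = hf_agent.step("Draw me a picture of rivers and lakes.")
image.save("./rivers_and_lakes.png")  # AgentImage is unwrapped to a raw image
```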
17
deep-swarm/camel/benchmarks/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .base import BaseBenchmark
|
||||
|
||||
__all__ = ["BaseBenchmark"]
152
deep-swarm/camel/benchmarks/base.py
Normal file
@ -0,0 +1,152 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

from camel.agents import ChatAgent

logger = logging.getLogger(__name__)


class BaseBenchmark(ABC):
    r"""Base class for benchmarks.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (str): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. (default: :obj:`1`)
    """

    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. (default: :obj:`1`)
        """
        self.name = name
        self.data_dir = Path(data_dir)
        self.processes = processes
        self.save_to = save_to
        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )
        self._data: Dict[str, List[Dict[str, Any]]] = dict()
        self._results: List[Dict[str, Any]] = []

    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @property
    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data.

        Returns:
            List[Dict[str, Any]]: The training data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["train"]

    @property
    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data.

        Returns:
            List[Dict[str, Any]]: The validation data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["valid"]

    @property
    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data.

        Returns:
            List[Dict[str, Any]]: The test data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["test"]

    @abstractmethod
    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @property
    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results
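A minimal concrete benchmark sketch, assuming `ChatAgent.step` accepts a user `BaseMessage` and returns a response exposing `msgs`; the data and the substring scoring rule are toy placeholders:

```python
from typing import Literal, Optional

from camel.agents import ChatAgent
from camel.benchmarks import BaseBenchmark
from camel.messages import BaseMessage


class ToyBenchmark(BaseBenchmark):
    r"""A benchmark with one hard-coded question, for illustration only."""

    def download(self) -> "ToyBenchmark":
        return self  # nothing to fetch

    def load(self, force_download: bool = False) -> "ToyBenchmark":
        self._data = {
            "train": [],
            "valid": [],
            "test": [{"question": "What is 2 + 2?", "answer": "4"}],
        }
        return self

    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "ToyBenchmark":
        for item in getattr(self, on):  # train/valid/test properties auto-load
            msg = BaseMessage.make_user_message(
                role_name="User", content=item["question"]
            )
            reply = agent.step(msg).msgs[0].content
            self._results.append(
                {"question": item["question"], "correct": item["answer"] in reply}
            )
        return self
```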
34
deep-swarm/camel/bots/__init__.py
Normal file
@ -0,0 +1,34 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .discord_app import DiscordApp
from .slack.models import (
    SlackAppMentionEventBody,
    SlackAppMentionEventProfile,
    SlackAuthProfile,
    SlackEventBody,
    SlackEventProfile,
)
from .slack.slack_app import SlackApp
from .telegram_bot import TelegramBot

__all__ = [
    'DiscordApp',
    'SlackApp',
    'SlackAppMentionEventBody',
    'SlackAppMentionEventProfile',
    'SlackAuthProfile',
    'SlackEventBody',
    'SlackEventProfile',
    'TelegramBot',
]
138
deep-swarm/camel/bots/discord_app.py
Normal file
@ -0,0 +1,138 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, List, Optional

from camel.utils import dependencies_required

if TYPE_CHECKING:
    from discord import Message

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DiscordApp:
    r"""A class representing a Discord app that uses the `discord.py` library
    to interact with Discord servers.

    This bot can respond to messages in specific channels and only reacts to
    messages that mention the bot.

    Attributes:
        channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
            provided, the bot will only respond to messages in these channels.
        token (Optional[str]): The Discord bot token used for authentication.
    """

    @dependencies_required('discord')
    def __init__(
        self,
        channel_ids: Optional[List[int]] = None,
        token: Optional[str] = None,
    ) -> None:
        r"""Initialize the DiscordApp instance by setting up the Discord
        client and event handlers.

        Args:
            channel_ids (Optional[List[int]]): A list of allowed channel IDs.
                The bot will only respond to messages in these channels if
                provided.
            token (Optional[str]): The Discord bot token for authentication.
                If not provided, the token will be retrieved from the
                environment variable `DISCORD_TOKEN`.

        Raises:
            ValueError: If the `DISCORD_TOKEN` is not found in environment
                variables.
        """
        self.token = token or os.getenv('DISCORD_TOKEN')
        self.channel_ids = channel_ids

        if not self.token:
            raise ValueError(
                "`DISCORD_TOKEN` not found in environment variables. Get it"
                " here: `https://discord.com/developers/applications`."
            )

        import discord

        intents = discord.Intents.default()
        intents.message_content = True
        self._client = discord.Client(intents=intents)

        # Register event handlers
        self._client.event(self.on_ready)
        self._client.event(self.on_message)

    async def start(self):
        r"""Asynchronously start the Discord bot using its token.

        This method starts the bot and logs into Discord asynchronously using
        the provided token. It should be awaited when used in an async
        environment.
        """
        await self._client.start(self.token)

    def run(self) -> None:
        r"""Start the Discord bot using its token.

        This method starts the bot and logs into Discord synchronously using
        the provided token. It blocks execution and keeps the bot running.
        """
        self._client.run(self.token)  # type: ignore[arg-type]

    async def on_ready(self) -> None:
        r"""Event handler that is called when the bot has successfully
        connected to the Discord server.

        When the bot is ready and logged into Discord, it logs a message
        displaying the bot's username.
        """
        logger.info(f'We have logged in as {self._client.user}')

    async def on_message(self, message: 'Message') -> None:
        r"""Event handler for processing incoming messages.

        This method is called whenever a new message is received by the bot.
        It will ignore messages sent by the bot itself, only respond to
        messages in allowed channels (if specified), and only to messages
        that mention the bot.

        Args:
            message (discord.Message): The message object received from
                Discord.
        """
        # If the message author is the bot itself,
        # do not respond to this message
        if message.author == self._client.user:
            return

        # If allowed channel IDs are provided,
        # only respond to messages in those channels
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return

        # Only respond to messages that mention the bot
        if not self._client.user or not self._client.user.mentioned_in(
            message
        ):
            return

        logger.info(f"Received message: {message.content}")

    @property
    def client(self):
        return self._client
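A run sketch; the channel ID is a placeholder and the token is read from `DISCORD_TOKEN` when not passed explicitly:

```python
from camel.bots import DiscordApp

app = DiscordApp(channel_ids=[123456789012345678])  # hypothetical channel
app.run()  # blocks; the bot only reacts to mentions in the listed channels
```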
30
deep-swarm/camel/bots/slack/__init__.py
Normal file
@ -0,0 +1,30 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .models import (
    SlackAppMentionEventBody,
    SlackAppMentionEventProfile,
    SlackAuthProfile,
    SlackEventBody,
    SlackEventProfile,
)
from .slack_app import SlackApp

__all__ = [
    'SlackApp',
    'SlackAppMentionEventBody',
    'SlackAppMentionEventProfile',
    'SlackAuthProfile',
    'SlackEventBody',
    'SlackEventProfile',
]
158
deep-swarm/camel/bots/slack/models.py
Normal file
@ -0,0 +1,158 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional

from pydantic import BaseModel


class SlackAuthProfile(BaseModel):
    r"""Represents the authorization profile within a Slack event.

    Events will contain a single, compact authorizations field that shows one
    installation of your app that the event is visible to.
    In other words, lists of authorizations will be truncated to one element.

    If there's more than one installing party that your app is keeping track
    of, it's best not to rely on the single party listed in authorizations to
    be any particular one.

    To get a full list of who can see events, call the
    apps.event.authorizations.list method after obtaining an app-level token.
    Read more on the changes in the references below; they have taken effect
    for existing apps as of February 24, 2021.

    References:

    - https://api.slack.com/apis/events-api#authorizations
    - https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
    """

    enterprise_id: Optional[str] = None
    """The ID of the enterprise associated with the authorization."""

    team_id: str
    """The ID of the team associated with the authorization."""

    user_id: str
    """The ID of the user associated with the authorization."""

    is_bot: bool
    """Whether the authorized user is a bot."""

    is_enterprise_install: bool
    """Whether the authorization is for an enterprise installation."""


class SlackEventProfile(BaseModel):
    r"""Represents the detailed profile of a Slack event, including user,
    message, and context data.
    """

    user: str
    """The ID of the user associated with the event."""

    type: str
    """The type of the event (e.g., 'message')."""

    ts: str
    """A timestamp representing when the event was triggered."""

    thread_ts: Optional[str] = None
    """The timestamp of the parent message in a thread."""

    client_msg_id: str
    """A unique ID generated by the client for the message (if available)."""

    text: str
    """The message content text."""

    team: str
    """The ID of the team that the event is associated with."""

    blocks: list
    """The list of message blocks, providing structured information."""

    channel: str
    """The ID of the Slack channel where the event happened."""

    event_ts: str
    """The event-specific timestamp when it occurred."""

    channel_type: Optional[str]
    """The type of Slack channel (e.g., 'channel', 'im')."""


class SlackEventBody(BaseModel):
    r"""Represents the entire body of a Slack event, including the event
    profile, authorization, and context.
    """

    token: str
    """The token to verify the source of the event."""

    team_id: str
    """The ID of the team where the event is happening."""

    context_team_id: Optional[str]
    """The team ID for the shared channel context, if applicable."""

    context_enterprise_id: Optional[str] = None
    """The enterprise ID for the shared channel context, if applicable."""

    api_app_id: str
    """The unique identifier for the Slack app that received the event."""

    event: SlackEventProfile
    """A detailed profile of the event."""

    type: str
    """The overall type of event received (e.g., 'event_callback')."""

    event_id: str
    """A unique identifier assigned to this event by Slack."""

    event_time: int
    """The timestamp (in seconds) representing when the event was
    triggered."""

    authorizations: Optional[list[SlackAuthProfile]] = None
    """An optional list of authorizations that describe which installation can
    see the event."""

    is_ext_shared_channel: bool
    """Indicates if the event is part of a shared channel between different
    organizations."""

    event_context: str
    """A unique string representing the context of the event."""


class SlackAppMentionEventProfile(SlackEventProfile):
    r"""Represents the detailed profile of a Slack event where the app was
    mentioned in a message.
    """

    channel_type: Optional[str] = None
    """The type of Slack channel. It's None for app mentions."""


class SlackAppMentionEventBody(SlackEventBody):
    r"""Represents the entire body of a Slack event where the app was
    mentioned in a message.
    """

    context_team_id: Optional[str] = None
    """The team ID for the shared channel context. It's None for app
    mentions."""

    event: SlackAppMentionEventProfile
    """A detailed profile of the event."""
255
deep-swarm/camel/bots/slack/slack_app.py
Normal file
@ -0,0 +1,255 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, Optional

from slack_sdk.oauth.installation_store.async_installation_store import (
    AsyncInstallationStore,
)
from starlette import requests, responses

from camel.bots.slack.models import (
    SlackAppMentionEventBody,
    SlackAppMentionEventProfile,
    SlackEventBody,
    SlackEventProfile,
)
from camel.utils import dependencies_required

if TYPE_CHECKING:
    from slack_bolt.context.async_context import AsyncBoltContext
    from slack_bolt.context.say.async_say import AsyncSay
    from slack_sdk.web.async_client import AsyncWebClient

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SlackApp:
    r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.

    This class is responsible for initializing and managing the Slack
    application by setting up event handlers, running the app server, and
    handling events such as messages and mentions from Slack.

    Args:
        token (Optional[str]): Slack API token for authentication.
        scopes (Optional[str]): Slack app scopes for permissions.
        signing_secret (Optional[str]): Signing secret for verifying Slack
            requests.
        client_id (Optional[str]): Slack app client ID.
        client_secret (Optional[str]): Slack app client secret.
        redirect_uri_path (str): The URI path for OAuth redirect, defaults to
            "/slack/oauth_redirect".
        installation_store (Optional[AsyncInstallationStore]): The installation
            store for handling OAuth installations.
    """

    @dependencies_required('slack_bolt')
    def __init__(
        self,
        token: Optional[str] = None,
        scopes: Optional[str] = None,
        signing_secret: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri_path: str = "/slack/oauth_redirect",
        installation_store: Optional[AsyncInstallationStore] = None,
    ) -> None:
        r"""Initializes the SlackApp instance by setting up the Slack Bolt app
        and configuring event handlers and OAuth settings.

        Args:
            token (Optional[str]): The Slack API token.
            scopes (Optional[str]): The scopes for Slack app permissions.
            signing_secret (Optional[str]): The signing secret for verifying
                requests.
            client_id (Optional[str]): The Slack app client ID.
            client_secret (Optional[str]): The Slack app client secret.
            redirect_uri_path (str): The URI path for handling OAuth redirects
                (default is "/slack/oauth_redirect").
            installation_store (Optional[AsyncInstallationStore]): An optional
                installation store for OAuth installations.
        """
        from slack_bolt.adapter.starlette.async_handler import (
            AsyncSlackRequestHandler,
        )
        from slack_bolt.app.async_app import AsyncApp
        from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings

        self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
        self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
        self.signing_secret: Optional[str] = signing_secret or os.getenv(
            "SLACK_SIGNING_SECRET"
        )
        self.client_id: Optional[str] = client_id or os.getenv(
            "SLACK_CLIENT_ID"
        )
        self.client_secret: Optional[str] = client_secret or os.getenv(
            "SLACK_CLIENT_SECRET"
        )

        if not all([self.token, self.scopes, self.signing_secret]):
            raise ValueError(
                "`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
                "environment variables must be set. Get them here: "
                "`https://api.slack.com/apps`."
            )

        # Set up OAuth settings if client ID and secret are provided
        if self.client_id and self.client_secret:
            self._app = AsyncApp(
                oauth_settings=AsyncOAuthSettings(
                    client_id=self.client_id,
                    client_secret=self.client_secret,
                    scopes=self.scopes,
                    redirect_uri_path=redirect_uri_path,
                ),
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )
        else:
            # Initialize Slack Bolt AsyncApp with settings
            self._app = AsyncApp(
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )

        self._handler = AsyncSlackRequestHandler(self._app)
        self.setup_handlers()

    def setup_handlers(self) -> None:
        r"""Sets up the event handlers for Slack events, such as `app_mention`
        and `message`.

        This method registers the `app_mention` and `on_message` event
        handlers with the Slack Bolt app to respond to Slack events.
        """
        self._app.event("app_mention")(self.app_mention)
        self._app.event("message")(self.on_message)

    def run(
        self,
        port: int = 3000,
        path: str = "/slack/events",
        host: Optional[str] = None,
    ) -> None:
        r"""Starts the Slack Bolt app server to listen for incoming Slack
        events.

        Args:
            port (int): The port on which the server should run (default is
                3000).
            path (str): The endpoint path for receiving Slack events (default
                is "/slack/events").
            host (Optional[str]): The hostname to bind the server (default is
                None).
        """
        self._app.start(port=port, path=path, host=host)

    async def handle_request(
        self, request: requests.Request
    ) -> responses.Response:
        r"""Handles incoming requests from Slack through the request handler.

        Args:
            request (Request): A Starlette request object representing the
                incoming request.

        Returns:
            The response generated by the Slack Bolt handler.
        """
        return await self._handler.handle(request)

    async def app_mention(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `app_mention` events.

        This method is triggered when someone mentions the app in Slack.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the app mention.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        event_profile = SlackAppMentionEventProfile(**event)
        event_body = SlackAppMentionEventBody(**body)

        logger.info(f"app_mention, context: {context}")
        logger.info(f"app_mention, client: {client}")
        logger.info(f"app_mention, event_profile: {event_profile}")
        logger.info(f"app_mention, event_body: {event_body}")
        logger.info(f"app_mention, say: {say}")

    async def on_message(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `message` events.

        This method is triggered when the app receives a message in Slack.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the message.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        await context.ack()

        event_profile = SlackEventProfile(**event)
        event_body = SlackEventBody(**body)

        logger.info(f"on_message, context: {context}")
        logger.info(f"on_message, client: {client}")
        logger.info(f"on_message, event_profile: {event_profile}")
        logger.info(f"on_message, event_body: {event_body}")
        logger.info(f"on_message, say: {say}")

        logger.info(f"Received message: {event_profile.text}")

    def mention_me(
        self, context: "AsyncBoltContext", body: SlackEventBody
    ) -> bool:
        r"""Check if the bot is mentioned in the message.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            body (SlackEventBody): The body of the Slack event.

        Returns:
            bool: True if the bot is mentioned in the message, False
                otherwise.
        """
        message = body.event.text
        bot_user_id = context.bot_user_id
        mention = f"<@{bot_user_id}>"
        return mention in message
|
||||
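
A minimal usage sketch of the Slack bot above (hypothetical: the export name `SlackApp` and the constructor keywords are assumptions, since neither is pinned down by this hunk):

```python
# Hypothetical sketch: `SlackApp` is an assumed export name for the class
# above, and the token/signing-secret keywords are assumed to match the
# attributes used in the constructor.
from camel.bots import SlackApp

bot = SlackApp(token="xoxb-...", signing_secret="...")
bot.run(port=3000, path="/slack/events")  # blocks, serving Slack events
```
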
82
deep-swarm/camel/bots/telegram_bot.py
Normal file
@ -0,0 +1,82 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import TYPE_CHECKING, Optional

from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.utils import dependencies_required

# Conditionally import telebot types only for type checking
if TYPE_CHECKING:
    from telebot.types import (  # type: ignore[import-untyped]
        Message,
    )


class TelegramBot:
    r"""Represents a Telegram bot that is powered by an agent.

    Attributes:
        chat_agent (ChatAgent): Chat agent that will power the bot.
        telegram_token (str, optional): The bot token.
    """

    @dependencies_required('telebot')
    def __init__(
        self,
        chat_agent: ChatAgent,
        telegram_token: Optional[str] = None,
    ) -> None:
        self.chat_agent = chat_agent

        if not telegram_token:
            self.token = os.getenv('TELEGRAM_TOKEN')
            if not self.token:
                raise ValueError(
                    "`TELEGRAM_TOKEN` not found in environment variables. "
                    "Get it from t.me/BotFather."
                )
        else:
            self.token = telegram_token

        import telebot  # type: ignore[import-untyped]

        self.bot = telebot.TeleBot(token=self.token)

        # Register the message handler within the constructor
        self.bot.message_handler(func=lambda message: True)(self.on_message)

    def run(self) -> None:
        r"""Start the Telegram bot."""
        print("Telegram bot is running...")
        self.bot.infinity_polling()

    def on_message(self, message: 'Message') -> None:
        r"""Handles incoming messages from the user.

        Args:
            message (types.Message): The incoming message object.
        """
        self.chat_agent.reset()

        if not message.text:
            return

        user_msg = BaseMessage.make_user_message(
            role_name="User", content=message.text
        )
        assistant_response = self.chat_agent.step(user_msg)

        self.bot.reply_to(message, assistant_response.msg.content)
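
A minimal end-to-end sketch for this bot (the exact `ChatAgent` construction below is an assumption about the camel API, not something this diff confirms; it assumes `TELEGRAM_TOKEN` is set in the environment):

```python
# Sketch: wire a ChatAgent into the bot and start polling Telegram.
from camel.agents import ChatAgent
from camel.bots.telegram_bot import TelegramBot
from camel.messages import BaseMessage

sys_msg = BaseMessage.make_assistant_message(
    role_name="Assistant", content="You are a helpful assistant."
)
agent = ChatAgent(system_message=sys_msg)  # assumed constructor signature
TelegramBot(agent).run()  # reads TELEGRAM_TOKEN from the environment
```
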
76
deep-swarm/camel/configs/__init__.py
Normal file
@ -0,0 +1,76 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .cohere_config import COHERE_API_PARAMS, CohereConfig
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
from .qwen_config import QWEN_API_PARAMS, QwenConfig
from .reka_config import REKA_API_PARAMS, RekaConfig
from .samba_config import (
    SAMBA_CLOUD_API_PARAMS,
    SAMBA_VERSE_API_PARAMS,
    SambaCloudAPIConfig,
    SambaVerseAPIConfig,
)
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
from .yi_config import YI_API_PARAMS, YiConfig
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig

__all__ = [
    'BaseConfig',
    'ChatGPTConfig',
    'OPENAI_API_PARAMS',
    'AnthropicConfig',
    'ANTHROPIC_API_PARAMS',
    'GROQ_API_PARAMS',
    'GroqConfig',
    'LiteLLMConfig',
    'LITELLM_API_PARAMS',
    'NvidiaConfig',
    'NVIDIA_API_PARAMS',
    'OllamaConfig',
    'OLLAMA_API_PARAMS',
    'ZhipuAIConfig',
    'ZHIPUAI_API_PARAMS',
    'GeminiConfig',
    'Gemini_API_PARAMS',
    'VLLMConfig',
    'VLLM_API_PARAMS',
    'MistralConfig',
    'MISTRAL_API_PARAMS',
    'RekaConfig',
    'REKA_API_PARAMS',
    'SambaVerseAPIConfig',
    'SAMBA_VERSE_API_PARAMS',
    'SambaCloudAPIConfig',
    'SAMBA_CLOUD_API_PARAMS',
    'TogetherAIConfig',
    'TOGETHERAI_API_PARAMS',
    'CohereConfig',
    'COHERE_API_PARAMS',
    'YiConfig',
    'YI_API_PARAMS',
    'QwenConfig',
    'QWEN_API_PARAMS',
    'DeepSeekConfig',
    'DEEPSEEK_API_PARAMS',
]
BIN
deep-swarm/camel/configs/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/configs/__pycache__/base_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/configs/__pycache__/groq_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/configs/__pycache__/qwen_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/configs/__pycache__/reka_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/configs/__pycache__/vllm_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/configs/__pycache__/yi_config.cpython-311.pyc
Normal file
Binary file not shown.
69
deep-swarm/camel/configs/anthropic_config.py
Normal file
@ -0,0 +1,69 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import List, Union

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class AnthropicConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Anthropic API.

    See: https://docs.anthropic.com/claude/reference/complete_post
    Args:
        max_tokens (int, optional): The maximum number of tokens to
            generate before stopping. Note that Anthropic models may stop
            before reaching this maximum. This parameter only specifies the
            absolute maximum number of tokens to generate.
            (default: :obj:`256`)
        stop_sequences (List[str], optional): Sequences that will cause the
            model to stop generating completion text. Anthropic models stop
            on "\n\nHuman:", and may include additional built-in stop sequences
            in the future. By providing the stop_sequences parameter, you may
            include additional strings that will cause the model to stop
            generating. (default: :obj:`NOT_GIVEN`)
        temperature (float, optional): Amount of randomness injected into the
            response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
            for analytical / multiple choice, and closer to 1 for creative
            and generative tasks.
            (default: :obj:`1`)
        top_p (float, optional): Use nucleus sampling. In nucleus sampling, we
            compute the cumulative distribution over all the options for each
            subsequent token in decreasing probability order and cut it off
            once it reaches a particular probability specified by `top_p`.
            You should either alter `temperature` or `top_p`,
            but not both.
            (default: :obj:`NOT_GIVEN`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low probability
            responses.
            (default: :obj:`NOT_GIVEN`)
        metadata: An object describing metadata about the request.
        stream (bool, optional): Whether to incrementally stream the response
            using server-sent events. (default: :obj:`False`)
    """

    max_tokens: int = 256
    stop_sequences: Union[List[str], NotGiven] = NOT_GIVEN
    temperature: float = 1
    top_p: Union[float, NotGiven] = NOT_GIVEN
    top_k: Union[int, NotGiven] = NOT_GIVEN
    metadata: NotGiven = NOT_GIVEN
    stream: bool = False


ANTHROPIC_API_PARAMS = {param for param in AnthropicConfig.model_fields.keys()}
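
For illustration, a small sketch of how the sentinel defaults above behave (names exactly as defined in this file):

```python
# Sketch: NOT_GIVEN marks parameters as "unset" (distinct from None), so
# only explicitly chosen values need to be forwarded to the Anthropic API.
from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig

config = AnthropicConfig(temperature=0.2, max_tokens=512)
payload = config.as_dict()  # as_dict() is inherited from BaseConfig
assert set(payload).issubset(ANTHROPIC_API_PARAMS)
```
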
89
deep-swarm/camel/configs/base_config.py
Normal file
@ -0,0 +1,89 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from abc import ABC
from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict, field_validator


class BaseConfig(ABC, BaseModel):
    r"""Base configuration class for all models.

    This class provides a common interface for all models, ensuring that all
    models have a consistent set of attributes and methods.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        # UserWarning: conflict with protected namespace "model_"
        protected_namespaces=(),
    )

    tools: Optional[List[Any]] = None
    """A list of tools the model may
    call. Currently, only functions are supported as a tool. Use this
    to provide a list of functions the model may generate JSON inputs
    for. A max of 128 functions are supported.
    """

    @field_validator("tools", mode="before")
    @classmethod
    def fields_type_checking(cls, tools):
        r"""Validate the type of tools in the configuration.

        This method ensures that the tools provided in the configuration are
        instances of `FunctionTool`. If any tool is not an instance of
        `FunctionTool`, it raises a ValueError.
        """
        if tools is not None:
            from camel.toolkits import FunctionTool

            for tool in tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        return tools

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()

        tools_schema = None
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        config_dict["tools"] = tools_schema
        return config_dict
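
Every concrete config below follows the same pattern: declare the backend's parameters as pydantic fields on a `BaseConfig` subclass, then derive the allowed parameter set from `model_fields`. A sketch with a hypothetical backend (`MyBackendConfig` is illustrative only):

```python
# Sketch of how a new backend config plugs into BaseConfig.
from camel.configs.base_config import BaseConfig


class MyBackendConfig(BaseConfig):  # hypothetical backend
    temperature: float = 0.2
    max_tokens: int = 1024


MY_BACKEND_API_PARAMS = {param for param in MyBackendConfig.model_fields.keys()}
```
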
76
deep-swarm/camel/configs/cohere_config.py
Normal file
@ -0,0 +1,76 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import List, Optional

from camel.configs.base_config import BaseConfig


class CohereConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cohere API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        documents (list, optional): A list of relevant documents that the
            model can cite to generate a more accurate reply. Each document is
            either a string or document object with content and metadata.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens the model
            will generate as part of the response. (default: :obj:`None`)
        stop_sequences (List[str], optional): A list of up to 5 strings that
            the model will use to stop generation. If the model generates a
            string that matches any of the strings in the list, it will stop
            generating tokens and return the generated text up to that point
            not including the stop sequence. (default: :obj:`None`)
        seed (int, optional): If specified, the backend will make a best
            effort to sample tokens deterministically, such that repeated
            requests with the same seed and parameters should return the same
            result. However, determinism cannot be totally guaranteed.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. The
            higher the value, the stronger a penalty is applied to previously
            present tokens, proportional to how many times they have already
            appeared in the prompt or prior generation. (default: :obj:`0.0`)
        presence_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. Similar
            to `frequency_penalty`, except that this penalty is applied
            equally to all tokens that have already appeared, regardless of
            their exact frequencies. (default: :obj:`0.0`)
        k (int, optional): Ensures only the top k most likely tokens are
            considered for generation at each step. Min value of `0`, max
            value of `500`. (default: :obj:`0`)
        p (float, optional): Ensures that only the most likely tokens, with
            total probability mass of `p`, are considered for generation at
            each step. If both k and p are enabled, `p` acts after `k`. Min
            value of `0.01`, max value of `0.99`. (default: :obj:`0.75`)
    """

    temperature: Optional[float] = 0.2
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    k: Optional[int] = 0
    p: Optional[float] = 0.75


COHERE_API_PARAMS = {param for param in CohereConfig.model_fields.keys()}
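
A small sketch of how the derived parameter set can be used (this is an illustration of the pattern, not code from the diff):

```python
# Sketch: filter arbitrary kwargs down to keys the Cohere chat endpoint
# accepts before assembling a request.
from camel.configs import COHERE_API_PARAMS, CohereConfig

config = CohereConfig(temperature=0.3, k=40)
kwargs = {
    key: val
    for key, val in config.as_dict().items()
    if key in COHERE_API_PARAMS
}
```
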
134
deep-swarm/camel/configs/deepseek_config.py
Normal file
@ -0,0 +1,134 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from __future__ import annotations

from typing import Any, Optional, Sequence, Type, Union

from pydantic import BaseModel

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class DeepSeekConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    DeepSeek API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`1.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        stream (bool, optional): If set, partial message deltas will be sent.
            Tokens will be sent as data-only server-sent events (SSE) as
            they become available, with the stream terminated by a
            data: [DONE] message. (default: :obj:`False`)
        stop (Union[str, list[str]], optional): Up to 16 sequences where
            the API will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens that can
            be generated in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on whether they
            appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on their existing
            frequency in the text so far, decreasing the model's likelihood
            to repeat the same line verbatim. (default: :obj:`0`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use
            this to provide a list of functions the model may generate JSON
            inputs for. A max of 128 functions are supported.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. "none" means the model
            will not call any tool and instead generates a message. "auto"
            means the model can pick between generating a message or calling
            one or more tools. "required" means the model must call one or
            more tools. Specifying a particular tool via
            {"type": "function", "function": {"name": "my_function"}} forces
            the model to call that tool. "none" is the default when no tools
            are present. "auto" is the default if tools are present.
            (default: :obj:`"auto"`)
        logprobs (bool, optional): Whether to return log probabilities of
            the output tokens or not. If true, returns the log probabilities
            of each output token returned in the content of message.
            (default: :obj:`False`)
        top_logprobs (int, optional): An integer between 0 and 20 specifying
            the number of most likely tokens to return at each token
            position, each with an associated log probability. logprobs
            must be set to true if this parameter is used.
            (default: :obj:`None`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`. (default:
            :obj:`True`)
    """

    temperature: float = 0.2  # deepseek default: 1.0
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    tool_choice: Optional[Union[dict[str, str], str]] = None
    logprobs: bool = False
    top_logprobs: Optional[int] = None

    def __init__(self, include_usage: bool = True, **kwargs):
        super().__init__(**kwargs)
        # Only set stream_options when stream is True
        # Otherwise, it will raise error when calling the API
        if self.stream:
            self.stream_options = {"include_usage": include_usage}

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        # Note: the `tools` entry is replaced with NOT_GIVEN here, so the
        # schemas collected above are validated but not included in the dump.
        config_dict["tools"] = NOT_GIVEN
        return config_dict


DEEPSEEK_API_PARAMS = {param for param in DeepSeekConfig.model_fields.keys()}
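
An illustrative sketch of the behavior of the `as_dict()` override above (not code from the diff):

```python
# Sketch: any tools attached to DeepSeekConfig are type-checked, but
# as_dict() replaces the `tools` entry with NOT_GIVEN, so tool schemas
# are not forwarded in the request payload.
from camel.configs import DeepSeekConfig
from camel.types import NOT_GIVEN

config = DeepSeekConfig(temperature=0.2)
assert config.as_dict()["tools"] is NOT_GIVEN
```
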
114
deep-swarm/camel/configs/gemini_config.py
Normal file
@ -0,0 +1,114 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from __future__ import annotations

from typing import Any, Optional, Sequence, Type, Union

from pydantic import BaseModel

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class GeminiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Gemini API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str]] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        # Note: as in DeepSeekConfig, the `tools` entry is replaced with
        # NOT_GIVEN, so the schemas collected above are not included.
        config_dict["tools"] = NOT_GIVEN
        return config_dict


Gemini_API_PARAMS = {param for param in GeminiConfig.model_fields.keys()}
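
A sketch of JSON mode per the docstring's caveat (illustrative, not code from the diff):

```python
# Sketch: enable JSON mode; note the docstring's warning that the prompt
# itself must also instruct the model to emit JSON.
from camel.configs import GeminiConfig

config = GeminiConfig(
    temperature=0.0,
    response_format={"type": "json_object"},
)
```
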
104
deep-swarm/camel/configs/groq_config.py
Normal file
@ -0,0 +1,104 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Optional, Sequence, Union

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class GroqConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.

    Reference: https://console.groq.com/docs/openai

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = "auto"


GROQ_API_PARAMS = {param for param in GroqConfig.model_fields.keys()}
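
For illustration (not code from the diff), since GroqConfig mirrors the OpenAI parameter surface, its dump can be passed wherever OpenAI-style chat kwargs are expected:

```python
# Sketch: build a Groq config and confirm the derived parameter set.
from camel.configs import GROQ_API_PARAMS, GroqConfig

config = GroqConfig(temperature=0.2, max_tokens=256)
assert "tool_choice" in GROQ_API_PARAMS  # defaults to "auto" above
```
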
97
deep-swarm/camel/configs/litellm_config.py
Normal file
@ -0,0 +1,97 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import List, Optional, Union

from camel.configs.base_config import BaseConfig


class LiteLLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    LiteLLM API.

    Args:
        timeout (Optional[Union[float, str]], optional): Request timeout.
            (default: None)
        temperature (Optional[float], optional): Temperature parameter for
            controlling randomness. (default: None)
        top_p (Optional[float], optional): Top-p parameter for nucleus
            sampling. (default: None)
        n (Optional[int], optional): Number of completions to generate.
            (default: None)
        stream (Optional[bool], optional): Whether to return a streaming
            response. (default: None)
        stream_options (Optional[dict], optional): Options for the streaming
            response. (default: None)
        stop (Optional[Union[str, List[str]]], optional): Sequences where the
            API will stop generating further tokens. (default: None)
        max_tokens (Optional[int], optional): Maximum number of tokens to
            generate. (default: None)
        presence_penalty (Optional[float], optional): Penalize new tokens
            based on their existence in the text so far. (default: None)
        frequency_penalty (Optional[float], optional): Penalize new tokens
            based on their frequency in the text so far. (default: None)
        logit_bias (Optional[dict], optional): Modify the probability of
            specific tokens appearing in the completion. (default: None)
        user (Optional[str], optional): A unique identifier representing the
            end-user. (default: None)
        response_format (Optional[dict], optional): Response format
            parameters. (default: None)
        seed (Optional[int], optional): Random seed. (default: None)
        tools (Optional[List], optional): List of tools. (default: None)
        tool_choice (Optional[Union[str, dict]], optional): Tool choice
            parameters. (default: None)
        logprobs (Optional[bool], optional): Whether to return log
            probabilities of the output tokens. (default: None)
        top_logprobs (Optional[int], optional): Number of most likely tokens
            to return at each token position. (default: None)
        deployment_id (Optional[str], optional): Deployment ID. (default: None)
        extra_headers (Optional[dict], optional): Additional headers for the
            request. (default: None)
        api_version (Optional[str], optional): API version. (default: None)
        mock_response (Optional[str], optional): Mock completion response for
            testing or debugging. (default: None)
        custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
            provider. (default: None)
        max_retries (Optional[int], optional): Maximum number of retries.
            (default: None)
    """

    timeout: Optional[Union[float, str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stream_options: Optional[dict] = None
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    user: Optional[str] = None
    response_format: Optional[dict] = None
    seed: Optional[int] = None
    tool_choice: Optional[Union[str, dict]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    deployment_id: Optional[str] = None
    extra_headers: Optional[dict] = None
    api_version: Optional[str] = None
    mock_response: Optional[str] = None
    custom_llm_provider: Optional[str] = None
    max_retries: Optional[int] = None


LITELLM_API_PARAMS = {param for param in LiteLLMConfig.model_fields.keys()}
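
Since every LiteLLM field defaults to `None`, dropping unset keys leaves exactly the parameters the caller chose. A sketch (illustrative, not code from the diff):

```python
# Sketch: keep only explicitly set parameters for the LiteLLM call.
from camel.configs import LiteLLMConfig

config = LiteLLMConfig(temperature=0.5, max_retries=3)
params = {
    key: val for key, val in config.as_dict().items() if val is not None
}
```
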
79
deep-swarm/camel/configs/mistral_config.py
Normal file
@ -0,0 +1,79 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any, Dict, Optional, Union

from pydantic import field_validator

from camel.configs.base_config import BaseConfig


class MistralConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Mistral API.

    reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195

    #TODO: Support stream mode

    Args:
        temperature (Optional[float], optional): the temperature to use for
            sampling, e.g. 0.5. Defaults to None.
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. Defaults to None.
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. Defaults to None.
        stop (Optional[Union[str, list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list.
        random_seed (Optional[int], optional): the random seed to use for
            sampling, e.g. 42. Defaults to None.
        safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
            Defaults to False.
        response_format (Union[Dict[str, str], ResponseFormat], optional):
            format of the response. Defaults to None.
        tool_choice (str, optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"any"` means the
            model must call one or more tools. :obj:`"auto"` is the default
            value.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    random_seed: Optional[int] = None
    safe_prompt: bool = False
    response_format: Optional[Union[Dict[str, str], Any]] = None
    tool_choice: Optional[str] = "auto"

    @field_validator("response_format", mode="before")
    @classmethod
    def fields_type_checking(cls, response_format):
        if response_format and not isinstance(response_format, dict):
            from mistralai.models import ResponseFormat

            if not isinstance(response_format, ResponseFormat):
                raise ValueError(
                    f"The response_format {response_format} should be an "
                    "instance of `mistralai.models.ResponseFormat`."
                )
        return response_format


MISTRAL_API_PARAMS = {param for param in MistralConfig.model_fields.keys()}
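
A sketch of the validator's behavior above (illustrative, not code from the diff):

```python
# Sketch: a plain dict passes the response_format validator unchanged;
# only non-dict values are checked against mistralai's ResponseFormat.
from camel.configs import MistralConfig

config = MistralConfig(
    temperature=0.5,
    response_format={"type": "json_object"},
)
```
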
70
deep-swarm/camel/configs/nvidia_config.py
Normal file
@ -0,0 +1,70 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import List, Optional, Union

from pydantic import Field

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class NvidiaConfig(BaseConfig):
    r"""Configuration class for NVIDIA API models.

    This class defines the configuration parameters for NVIDIA's language
    models, including temperature, sampling parameters, and response format
    settings.

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls randomness in the response.
            Higher values make output more random, lower values make it more
            deterministic. Range: [0.0, 2.0]. (default: :obj:`0.7`)
        top_p (float, optional): Controls diversity via nucleus sampling.
            Range: [0.0, 1.0]. (default: :obj:`0.95`)
        presence_penalty (float, optional): Penalizes new tokens based on
            whether they appear in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        frequency_penalty (float, optional): Penalizes new tokens based on
            their frequency in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        max_tokens (Union[int, NotGiven], optional): Maximum number of tokens
            to generate. If not provided, model will use its default maximum.
            (default: :obj:`NOT_GIVEN`)
        seed (Optional[int], optional): Random seed for deterministic sampling.
            (default: :obj:`None`)
        tools (Optional[List[Dict]], optional): List of tools available to the
            model. This includes tools such as a text editor, a calculator, or
            a search engine. (default: :obj:`None`)
        tool_choice (Optional[str], optional): Tool choice configuration.
            (default: :obj:`None`)
        stop (Optional[List[str]], optional): List of stop sequences.
            (default: :obj:`None`)
    """

    stream: bool = Field(default=False)
    temperature: float = Field(default=0.7)
    top_p: float = Field(default=0.95)
    presence_penalty: float = Field(default=0.0)
    frequency_penalty: float = Field(default=0.0)
    max_tokens: Union[int, NotGiven] = Field(default=NOT_GIVEN)
    seed: Optional[int] = Field(default=None)
    tool_choice: Optional[str] = Field(default=None)
    stop: Optional[List[str]] = Field(default=None)


NVIDIA_API_PARAMS = {param for param in NvidiaConfig.model_fields.keys()}
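
For illustration (not code from the diff), the pydantic `Field(default=...)` declarations above behave like plain defaults, so overriding individual parameters works as in the other configs:

```python
# Sketch: override a few NVIDIA sampling parameters.
from camel.configs import NvidiaConfig

config = NvidiaConfig(temperature=0.3, seed=42, stop=["\n\n"])
```
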
82
deep-swarm/camel/configs/ollama_config.py
Normal file
@ -0,0 +1,82 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Sequence, Union

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class OllamaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.

    Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
    """

    temperature: float = 0.2
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0


OLLAMA_API_PARAMS = {param for param in OllamaConfig.model_fields.keys()}
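
For illustration (not code from the diff), the same OpenAI-style sampling knobs apply to locally served Ollama models through the OpenAI-compatible endpoint:

```python
# Sketch: configure sampling for a local Ollama model.
from camel.configs import OllamaConfig

config = OllamaConfig(temperature=0.4, stop=["</s>"])
```
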
139
deep-swarm/camel/configs/openai_config.py
Normal file
@ -0,0 +1,139 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any, Optional, Sequence, Type, Union

from pydantic import BaseModel, Field

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class ChatGPTConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a JSON object that maps
            tokens (specified by their token ID in the tokenizer) to an
            associated bias value from :obj:`-100` to :obj:`100`.
            Mathematically, the bias is added to the logits generated by the
            model prior to sampling. The exact effect will vary per model,
            but values between :obj:`-1` and :obj:`1` should decrease or
            increase likelihood of selection; values like :obj:`-100` or
            :obj:`100` should result in a ban or exclusive selection of the
            relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        # Drop tools from the serialized config after validating their types;
        # the FunctionTool objects themselves are not part of the payload.
        config_dict["tools"] = NOT_GIVEN
        return config_dict


OPENAI_API_PARAMS = {param for param in ChatGPTConfig.model_fields.keys()}
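The JSON-mode caveat in the docstring above is easy to trip over, so here is a hedged sketch of the safe pattern: enable `{"type": "json_object"}` and also mention JSON in a message. The model name and client wiring are illustrative, and the sketch assumes `camel.types.NOT_GIVEN` re-exports the openai sentinel, as the field types above suggest:

```python
from openai import OpenAI  # assumes the standard openai-python client

client = OpenAI()
config = ChatGPTConfig(
    response_format={"type": "json_object"},
    max_tokens=256,
)
# Filter to supported, concretely-set parameters before splatting.
kwargs = {
    k: v
    for k, v in config.as_dict().items()
    if k in OPENAI_API_PARAMS and v is not None
}
resp = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[
        {"role": "system", "content": "Respond with a JSON object."},
        {"role": "user", "content": "Give three colors with hex codes."},
    ],
    **kwargs,
)
# Per the docstring: check for truncation before trusting the JSON.
if resp.choices[0].finish_reason == "length":
    raise RuntimeError("Completion truncated; JSON may be cut off.")
```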
91
deep-swarm/camel/configs/qwen_config.py
Normal file
@ -0,0 +1,91 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import ClassVar, Optional, Union

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class QwenConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Qwen API. You can refer to the following link for more details:
    https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`0.3`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`0.9`)
        presence_penalty (float, optional): Controls the repetition of
            content in the generated results. Positive values reduce the
            repetition of content, while negative values increase it.
            (default: :obj:`0.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        max_tokens (Union[int, NotGiven], optional): The maximum number of
            tokens the model is allowed to generate.
            (default: :obj:`NOT_GIVEN`)
        seed (int, optional): Sets the seed parameter to make the text
            generation process more deterministic, typically used to ensure
            that the results are consistent across model runs. By passing the
            same seed value (specified by you) in each model call while
            keeping other parameters unchanged, the model is likely to return
            the same result.
            (default: :obj:`None`)
        stop (str or list, optional): The model automatically stops
            generating text just before it would emit the specified string or
            token_id. The stop parameter can be used to keep sensitive words
            out of the output.
            (default: :obj:`None`)
        tools (list, optional): Specifies an array of tools that the model can
            call. It can contain one or more tool objects. During a function
            call process, the model will select one tool from the array.
            (default: :obj:`None`)
        extra_body (dict, optional): Additional parameters to be sent to the
            Qwen API. If you want to enable internet search, you can set this
            parameter to `{"enable_search": True}`.
            (default: :obj:`{"enable_search": False}`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`. (default:
            :obj:`True`)
    """

    stream: bool = False
    temperature: float = 0.3
    top_p: float = 0.9
    presence_penalty: float = 0.0
    response_format: ClassVar[dict] = {"type": "text"}
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    seed: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    extra_body: ClassVar[dict] = {"enable_search": False}

    def __init__(self, include_usage: bool = True, **kwargs):
        super().__init__(**kwargs)
        # Only set stream_options when stream is True;
        # otherwise the API call will raise an error.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}


QWEN_API_PARAMS = {param for param in QwenConfig.model_fields.keys()}
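A quick check of the `__init__` logic above: `stream_options` only materializes when `stream=True`. This sketch assumes `BaseConfig` permits extra attributes, which the assignment in `__init__` already relies on:

```python
cfg = QwenConfig(stream=True)  # include_usage defaults to True
assert cfg.stream_options == {"include_usage": True}

cfg = QwenConfig()  # not streaming: no stream_options attribute is set
assert not hasattr(cfg, "stream_options")
```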
74
deep-swarm/camel/configs/reka_config.py
Normal file
@ -0,0 +1,74 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any, Optional, Union

from camel.configs.base_config import BaseConfig


class RekaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Reka API.

    Reference: https://docs.reka.ai/api-reference/chat/create

    Args:
        temperature (Optional[float], optional): The temperature to use for
            sampling, e.g. 0.5. Defaults to None.
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. Defaults to None.
        top_k (Optional[int], optional): Parameter which forces the model to
            only consider the tokens with the `top_k` highest probabilities at
            the next step. Defaults to 1024.
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. Defaults to None.
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list.
        seed (Optional[int], optional): the random seed to use for sampling,
            e.g. 42. Defaults to None.
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        use_search_engine (Optional[bool]): Whether to consider using search
            engine to complete the request. Note that even if this is set to
            `True`, the model might decide to not use search.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    seed: Optional[int] = None
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    use_search_engine: Optional[bool] = False

    def as_dict(self) -> dict[str, Any]:
        config_dict = super().as_dict()
        if "tools" in config_dict:
            del config_dict["tools"]  # Reka does not support tool calling
        return config_dict


REKA_API_PARAMS = {param for param in RekaConfig.model_fields.keys()}
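A small demonstration of the `as_dict` override above, under the assumption that `BaseConfig.as_dict()` otherwise includes a `tools` entry:

```python
cfg = RekaConfig(temperature=0.5, use_search_engine=True)
payload = cfg.as_dict()
assert "tools" not in payload  # stripped: Reka has no tool calling
assert payload["use_search_engine"] is True
```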
170
deep-swarm/camel/configs/samba_config.py
Normal file
@ -0,0 +1,170 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any, Optional, Sequence, Union

from pydantic import Field

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class SambaVerseAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    SambaVerse API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.7`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`0.95`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low probability
            responses.
            (default: :obj:`50`)
        max_tokens (Optional[int], optional): The maximum number of tokens to
            generate, e.g. 100.
            (default: :obj:`2048`)
        repetition_penalty (Optional[float], optional): The parameter for
            repetition penalty. 1.0 means no penalty.
            (default: :obj:`1.0`)
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list.
            (default: :obj:`""`)
        stream (Optional[bool]): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            Currently the SambaVerse API doesn't support stream mode.
            (default: :obj:`False`)
    """

    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 50
    max_tokens: Optional[int] = 2048
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[Union[str, list[str]]] = ""
    stream: Optional[bool] = False

    def as_dict(self) -> dict[str, Any]:
        config_dict = super().as_dict()
        if "tools" in config_dict:
            del config_dict["tools"]  # SambaNova does not support tool calling
        return config_dict


SAMBA_VERSE_API_PARAMS = {
    param for param in SambaVerseAPIConfig.model_fields.keys()
}


class SambaCloudAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI-compatible SambaNova Cloud API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a JSON object that maps
            tokens (specified by their token ID in the tokenizer) to an
            associated bias value from :obj:`-100` to :obj:`100`.
            Mathematically, the bias is added to the logits generated by the
            model prior to sampling. The exact effect will vary per model,
            but values between :obj:`-1` and :obj:`1` should decrease or
            increase likelihood of selection; values like :obj:`-100` or
            :obj:`100` should result in a ban or exclusive selection of the
            relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None


SAMBA_CLOUD_API_PARAMS = {
    param for param in SambaCloudAPIConfig.model_fields.keys()
}
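The `*_API_PARAMS` sets lend themselves to early validation of user-supplied kwargs. A hypothetical helper (the name and error message are illustrative, not part of this file):

```python
def check_samba_cloud_kwargs(kwargs: dict) -> None:
    """Reject parameters the SambaNova Cloud config does not define."""
    unknown = set(kwargs) - SAMBA_CLOUD_API_PARAMS
    if unknown:
        raise ValueError(f"Unsupported SambaNova Cloud parameters: {unknown}")


check_samba_cloud_kwargs({"temperature": 0.1, "n": 2})  # passes silently
```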
107
deep-swarm/camel/configs/togetherai_config.py
Normal file
@ -0,0 +1,107 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any, Sequence, Union

from pydantic import Field

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class TogetherAIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI-compatible TogetherAI API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a JSON object that maps
            tokens (specified by their token ID in the tokenizer) to an
            associated bias value from :obj:`-100` to :obj:`100`.
            Mathematically, the bias is added to the logits generated by the
            model prior to sampling. The exact effect will vary per model,
            but values between :obj:`-1` and :obj:`1` should decrease or
            increase likelihood of selection; values like :obj:`-100` or
            :obj:`100` should result in a ban or exclusive selection of the
            relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""

    def as_dict(self) -> dict[str, Any]:
        config_dict = super().as_dict()
        if "tools" in config_dict:
            del config_dict["tools"]  # Currently does not support tool calling
        return config_dict


TOGETHERAI_API_PARAMS = {
    param for param in TogetherAIConfig.model_fields.keys()
}
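The `logit_bias` semantics documented above can be exercised directly; the token IDs below are hypothetical and tokenizer-specific:

```python
cfg = TogetherAIConfig(
    logit_bias={1234: -100, 5678: 1},  # ban token 1234, nudge token 5678 up
    stop=["</s>"],
)
assert "tools" not in cfg.as_dict()  # stripped by the override above
```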
111
deep-swarm/camel/configs/vllm_config.py
Normal file
@ -0,0 +1,111 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Optional, Sequence, Union

from pydantic import Field

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


# flake8: noqa: E501
class VLLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI-compatible vLLM server.

    Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a JSON object that maps
            tokens (specified by their token ID in the tokenizer) to an
            associated bias value from :obj:`-100` to :obj:`100`.
            Mathematically, the bias is added to the logits generated by the
            model prior to sampling. The exact effect will vary per model,
            but values between :obj:`-1` and :obj:`1` should decrease or
            increase likelihood of selection; values like :obj:`-100` or
            :obj:`100` should result in a ban or exclusive selection of the
            relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        logprobs: Whether to return log probabilities of the output tokens or
            not. If true, returns the log probabilities of each output token
            returned in the `content` of `message`. (default: :obj:`None`)
        top_logprobs: An integer between 0 and 20 specifying the number of
            most likely tokens to return at each token position, each with an
            associated log probability. `logprobs` must be set to `true` if
            this parameter is used. (default: :obj:`None`)
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None


VLLM_API_PARAMS = {param for param in VLLMConfig.model_fields.keys()}
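Per the docstring, `top_logprobs` only makes sense with `logprobs=True`. A hedged sketch of a config requesting the five most likely alternatives per position:

```python
cfg = VLLMConfig(logprobs=True, top_logprobs=5)
kwargs = {k: v for k, v in cfg.as_dict().items() if k in VLLM_API_PARAMS}
# kwargs can be splatted into an OpenAI-compatible client pointed at a vLLM
# server; the response then carries per-token log probabilities.
```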
58
deep-swarm/camel/configs/yi_config.py
Normal file
@ -0,0 +1,58 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Optional, Union

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class YiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Yi API. You can refer to the following link for more details:
    https://platform.lingyiwanwu.com/docs/api-reference

    Args:
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` or
            specifying a particular tool via
            {"type": "function", "function": {"name": "some_function"}}
            can be used to guide the model to use tools more strongly.
            (default: :obj:`None`)
        max_tokens (int, optional): Specifies the maximum number of tokens
            the model can generate. This sets an upper limit, but does not
            guarantee that this number will always be reached.
            (default: :obj:`5000`)
        top_p (float, optional): Controls the randomness of the generated
            results. Lower values lead to less randomness, while higher
            values increase randomness. (default: :obj:`0.9`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`0.3`)
        stream (bool, optional): If True, enables streaming output.
            (default: :obj:`False`)
    """

    tool_choice: Optional[Union[dict[str, str], str]] = None
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    top_p: float = 0.9
    temperature: float = 0.3
    stream: bool = False


YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
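Forcing a specific tool with `tool_choice`, as described in the docstring above (the function name is hypothetical):

```python
cfg = YiConfig(
    tool_choice={"type": "function", "function": {"name": "search_web"}},
    temperature=0.0,  # keep tool arguments as deterministic as possible
)
```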
71
deep-swarm/camel/configs/zhipuai_config.py
Normal file
@ -0,0 +1,71 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Optional, Sequence, Union

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class ZhipuAIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI-compatible ZhipuAI API.

    Reference: https://open.bigmodel.cn/dev/api#glm-4v

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`0.6`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2
    top_p: float = 0.6
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str]] = None


ZHIPUAI_API_PARAMS = {param for param in ZhipuAIConfig.model_fields.keys()}
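Because every config exposes its field names as a set, comparing two of them shows at a glance which OpenAI-style parameters a backend drops (the import path is assumed from the file layout above):

```python
from camel.configs.openai_config import OPENAI_API_PARAMS  # assumed path

# Parameters ChatGPTConfig defines that ZhipuAIConfig does not, e.g. n,
# logit_bias, user; useful when porting a config between backends.
print(sorted(OPENAI_API_PARAMS - ZHIPUAI_API_PARAMS))
```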
23
deep-swarm/camel/datahubs/__init__.py
Normal file
@ -0,0 +1,23 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from .base import BaseDatasetManager
from .huggingface import HuggingFaceDatasetManager
from .models import Record

__all__ = [
    "BaseDatasetManager",
    "Record",
    "HuggingFaceDatasetManager",
]
136
deep-swarm/camel/datahubs/base.py
Normal file
@ -0,0 +1,136 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any, List

from camel.datahubs.models import Record


class BaseDatasetManager(ABC):
    r"""Abstract base class for dataset managers."""

    @abstractmethod
    def create_dataset(self, name: str, **kwargs: Any) -> str:
        r"""Creates a new dataset.

        Args:
            name (str): The name of the dataset.
            kwargs (Any): Additional keyword arguments.

        Returns:
            str: The URL of the created dataset.
        """
        pass

    @abstractmethod
    def list_datasets(
        self, username: str, limit: int = 100, **kwargs: Any
    ) -> List[str]:
        r"""Lists all datasets for the current user.

        Args:
            username (str): The username of the user whose datasets to list.
            limit (int): The maximum number of datasets to list.
                (default: :obj:`100`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[str]: A list of dataset ids.
        """
        pass

    @abstractmethod
    def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
        r"""Deletes a dataset.

        Args:
            dataset_name (str): The name of the dataset to delete.
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def add_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Adds records to a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to add to the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def update_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Updates records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to update in the
                dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def list_records(
        self,
        dataset_name: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> List[Record]:
        r"""Lists records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

    # New method for record deletion
    @abstractmethod
    def delete_record(
        self,
        dataset_name: str,
        record_id: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Deletes a record from the dataset.

        Args:
            dataset_name (str): The name of the dataset.
            record_id (str): The ID of the record to delete.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass
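To make the contract concrete, here is a minimal sketch of a conforming subclass that keeps everything in memory (purely illustrative; it assumes `Record` exposes an `id` attribute, as the Hugging Face implementation below relies on):

```python
from typing import Any, Dict, List

from camel.datahubs.base import BaseDatasetManager
from camel.datahubs.models import Record


class InMemoryDatasetManager(BaseDatasetManager):
    """Toy backend mapping dataset name -> filepath -> records; for tests."""

    def __init__(self) -> None:
        self._data: Dict[str, Dict[str, List[Record]]] = {}

    def create_dataset(self, name: str, **kwargs: Any) -> str:
        self._data.setdefault(name, {})
        return f"memory://{name}"

    def list_datasets(
        self, username: str, limit: int = 100, **kwargs: Any
    ) -> List[str]:
        return list(self._data)[:limit]

    def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
        self._data.pop(dataset_name, None)

    def add_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        files = self._data.setdefault(dataset_name, {})
        if filepath in files:
            raise ValueError(f"'{filepath}' exists; use update_records.")
        files[filepath] = list(records)

    def update_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        files = self._data.setdefault(dataset_name, {})
        merged = {r.id: r for r in files.get(filepath, [])}
        merged.update({r.id: r for r in records})  # new records win
        files[filepath] = list(merged.values())

    def list_records(
        self,
        dataset_name: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> List[Record]:
        return list(self._data.get(dataset_name, {}).get(filepath, []))

    def delete_record(
        self,
        dataset_name: str,
        record_id: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        files = self._data.get(dataset_name, {})
        files[filepath] = [
            r for r in files.get(filepath, []) if r.id != record_id
        ]
```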
433
deep-swarm/camel/datahubs/huggingface.py
Normal file
@ -0,0 +1,433 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import os
import tempfile
from typing import Any, List, Optional

from camel.datahubs.base import BaseDatasetManager
from camel.datahubs.models import Record
from camel.logger import get_logger
from camel.types import HuggingFaceRepoType
from camel.utils import api_keys_required, dependencies_required

logger = get_logger(__name__)


class HuggingFaceDatasetManager(BaseDatasetManager):
    r"""A dataset manager for Hugging Face datasets. This class provides
    methods to create, add, update, delete, and list records in a dataset on
    the Hugging Face Hub.

    Args:
        token (str): The Hugging Face API token. If not provided, the token
            will be read from the environment variable `HUGGING_FACE_TOKEN`.
    """

    @api_keys_required("HUGGING_FACE_TOKEN")
    @dependencies_required('huggingface_hub')
    def __init__(self, token: Optional[str] = None):
        from huggingface_hub import HfApi

        self._api_key = token or os.getenv("HUGGING_FACE_TOKEN")
        self.api = HfApi(token=self._api_key)

    def create_dataset_card(
        self,
        dataset_name: str,
        description: str,
        license: Optional[str] = None,
        version: Optional[str] = None,
        tags: Optional[List[str]] = None,
        authors: Optional[List[str]] = None,
        size_category: Optional[List[str]] = None,
        language: Optional[List[str]] = None,
        task_categories: Optional[List[str]] = None,
        content: Optional[str] = None,
    ) -> None:
        r"""Creates and uploads a dataset card to the Hugging Face Hub in YAML
        format.

        Args:
            dataset_name (str): The name of the dataset.
            description (str): A description of the dataset.
            license (str): The license of the dataset. (default: :obj:`None`)
            version (str): The version of the dataset. (default: :obj:`None`)
            tags (list): A list of tags for the dataset. (default: :obj:`None`)
            authors (list): A list of authors of the dataset. (default:
                :obj:`None`)
            size_category (list): A size category for the dataset. (default:
                :obj:`None`)
            language (list): A list of languages the dataset is in. (default:
                :obj:`None`)
            task_categories (list): A list of task categories. (default:
                :obj:`None`)
            content (str): Custom markdown content that the user wants to add
                to the dataset card. (default: :obj:`None`)
        """
        import yaml

        metadata = {
            "license": license,
            "authors": authors,
            "task_categories": task_categories,
            "language": language,
            "tags": tags,
            "pretty_name": dataset_name,
            "size_categories": size_category,
            "version": version,
            "description": description,
        }

        # Remove keys with None values
        metadata = {k: v for k, v in metadata.items() if v}

        card_content = (
            "---\n"
            + yaml.dump(metadata, default_flow_style=False, allow_unicode=True)
            + "\n---"
        )

        if content:
            card_content += f"\n\n# Additional Information\n{content}\n"

        self._upload_file(
            file_content=card_content,
            dataset_name=dataset_name,
            filepath="README.md",
            file_type="md",
        )

    def create_dataset(
        self, name: str, private: bool = False, **kwargs: Any
    ) -> str:
        r"""Creates a new dataset on the Hugging Face Hub.

        Args:
            name (str): The name of the dataset.
            private (bool): Whether the dataset should be private.
                (default: :obj:`False`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            str: The URL of the created dataset.
        """
        from huggingface_hub.errors import RepositoryNotFoundError

        try:
            self.api.repo_info(
                repo_id=name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                **kwargs,
            )
        except RepositoryNotFoundError:
            self.api.create_repo(
                repo_id=name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                private=private,
            )

        return f"https://huggingface.co/datasets/{name}"

    def list_datasets(
        self, username: str, limit: int = 100, **kwargs: Any
    ) -> List[str]:
        r"""Lists all datasets for the current user.

        Args:
            username (str): The username of the user whose datasets to list.
            limit (int): The maximum number of datasets to list.
                (default: :obj:`100`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[str]: A list of dataset ids.
        """
        try:
            return [
                dataset.id
                for dataset in self.api.list_datasets(
                    author=username, limit=limit, **kwargs
                )
            ]
        except Exception as e:
            logger.error(f"Error listing datasets: {e}")
            return []

    def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
        r"""Deletes a dataset from the Hugging Face Hub.

        Args:
            dataset_name (str): The name of the dataset to delete.
            kwargs (Any): Additional keyword arguments.
        """
        try:
            self.api.delete_repo(
                repo_id=dataset_name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                **kwargs,
            )
            logger.info(f"Dataset '{dataset_name}' deleted successfully.")
        except Exception as e:
            logger.error(f"Error deleting dataset '{dataset_name}': {e}")
            raise

    def add_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Adds records to a dataset on the Hugging Face Hub.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to add to the dataset.
            filepath (str): The path to the file containing the records.
            kwargs (Any): Additional keyword arguments.

        Raises:
            ValueError: If the dataset already has a records file.
        """
        existing_records = self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

        if existing_records:
            raise ValueError(
                f"Dataset '{filepath}' already exists. "
                f"Use `update_records` to modify."
            )

        self._upload_records(
            records=records,
            dataset_name=dataset_name,
            filepath=filepath,
            **kwargs,
        )

    def update_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Updates records in a dataset on the Hugging Face Hub.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to update in the
                dataset.
            filepath (str): The path to the file containing the records.
            kwargs (Any): Additional keyword arguments.

        Raises:
            ValueError: If the dataset does not have an existing file to
                update records in.
        """
        existing_records = self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

        if not existing_records:
            logger.warning(
                f"Dataset '{dataset_name}' does not have existing "
                "records. Adding new records."
            )
            self._upload_records(
                records=records,
                dataset_name=dataset_name,
                filepath=filepath,
                **kwargs,
            )
            return

        old_dict = {record.id: record for record in existing_records}
        new_dict = {record.id: record for record in records}
        merged_dict = old_dict.copy()
        merged_dict.update(new_dict)

        self._upload_records(
            records=list(merged_dict.values()),
            dataset_name=dataset_name,
            filepath=filepath,
            **kwargs,
        )

    def delete_record(
        self,
        dataset_name: str,
        record_id: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Deletes a record from the dataset.

        Args:
            dataset_name (str): The name of the dataset.
            record_id (str): The ID of the record to delete.
            filepath (str): The path to the file containing the records.
            kwargs (Any): Additional keyword arguments.

        Raises:
            ValueError: If the dataset does not have an existing file to
                delete records from.
        """
        existing_records = self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

        if not existing_records:
            raise ValueError(
                f"Dataset '{dataset_name}' does not have an existing file to "
                f"delete records from."
            )

        filtered_records = [
            record for record in existing_records if record.id != record_id
        ]

        self._upload_records(
            records=filtered_records,
            dataset_name=dataset_name,
            filepath=filepath,
            **kwargs,
        )

    def list_records(
        self,
        dataset_name: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> List[Record]:
        r"""Lists all records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            filepath (str): The path to the file containing the records.
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[Record]: A list of records in the dataset.
        """
        return self._download_records(
            dataset_name=dataset_name, filepath=filepath, **kwargs
        )

    def _download_records(
        self, dataset_name: str, filepath: str, **kwargs: Any
    ) -> List[Record]:
        from huggingface_hub import hf_hub_download
        from huggingface_hub.errors import EntryNotFoundError

        try:
            downloaded_file_path = hf_hub_download(
                repo_id=dataset_name,
                filename=filepath,
                repo_type=HuggingFaceRepoType.DATASET.value,
                token=self._api_key,
                **kwargs,
            )

            with open(downloaded_file_path, "r") as f:
                records_data = json.load(f)

            return [Record(**record) for record in records_data]
        except EntryNotFoundError:
            logger.info(f"No records found for dataset '{dataset_name}'.")
            return []
        except Exception as e:
            logger.error(f"Error downloading or processing records: {e}")
            raise e

    def _upload_records(
        self,
        records: List[Record],
        dataset_name: str,
        filepath: str,
        **kwargs: Any,
    ):
        with tempfile.NamedTemporaryFile(
            delete=False, mode="w", newline="", encoding="utf-8"
        ) as f:
            json.dump([record.model_dump() for record in records], f)
            temp_file_path = f.name

        try:
            self.api.upload_file(
                path_or_fileobj=temp_file_path,
                path_in_repo=filepath,
                repo_id=dataset_name,
                repo_type=HuggingFaceRepoType.DATASET.value,
                **kwargs,
            )
        except Exception as e:
            logger.error(f"Error uploading records file: {e}")
            raise
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
def _upload_file(
|
||||
self,
|
||||
file_content: str,
|
||||
dataset_name: str,
|
||||
filepath: str,
|
||||
file_type: str = "json",
|
||||
**kwargs: Any,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=f".{file_type}"
|
||||
) as f:
|
||||
if file_type == "json":
|
||||
if isinstance(file_content, str):
|
||||
try:
|
||||
json_content = json.loads(file_content)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
"Invalid JSON string provided for file_content."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
json.dumps(file_content)
|
||||
json_content = file_content
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"file_content is not JSON serializable."
|
||||
)
|
||||
|
||||
json.dump(json_content, f)
|
||||
elif file_type == "md" or file_type == "txt":
|
||||
f.write(file_content)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {file_type}")
|
||||
|
||||
temp_file_path = f.name
|
||||
|
||||
try:
|
||||
self.api.upload_file(
|
||||
path_or_fileobj=temp_file_path,
|
||||
path_in_repo=filepath,
|
||||
repo_id=dataset_name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info(f"File uploaded successfully: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading file: {e}")
|
||||
raise
|
||||
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
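A minimal usage sketch of the record API above. It assumes the enclosing manager class is the `HuggingFaceDatasetManager` exported from `camel.datahubs.huggingface`, that the dataset repo already exists, that `HF_TOKEN` grants write access, and the dataset name is a placeholder:

```python
from camel.datahubs.huggingface import HuggingFaceDatasetManager
from camel.datahubs.models import Record

manager = HuggingFaceDatasetManager()  # reads HF_TOKEN from the environment

# Create two records, then overwrite one and delete the other.
records = [
    Record(id="r1", content={"question": "2+2?", "answer": "4"}),
    Record(id="r2", content={"question": "Capital of France?", "answer": "Paris"}),
]
manager.add_records("my-user/my-dataset", records)
manager.update_records(
    "my-user/my-dataset",
    [Record(id="r1", content={"question": "2+2?", "answer": "four"})],
)
manager.delete_record("my-user/my-dataset", record_id="r2")
print([r.id for r in manager.list_records("my-user/my-dataset")])  # ['r1']
```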
22
deep-swarm/camel/datahubs/models.py
Normal file
@ -0,0 +1,22 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, Optional

from pydantic import BaseModel


class Record(BaseModel):
    id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    content: Dict[str, Any]
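`Record` is a plain pydantic model, so validation and serialization come for free; `content` is required while `id` and `metadata` are optional. A quick illustration (values are arbitrary):

```python
from camel.datahubs.models import Record

rec = Record(id="r1", content={"text": "hello"}, metadata={"source": "demo"})
print(rec.model_dump())
# {'id': 'r1', 'metadata': {'source': 'demo'}, 'content': {'text': 'hello'}}

Record(content={"text": "ok"})   # fine: id and metadata default to None
# Record(id="r2")                # would raise a ValidationError: content is required
```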
28
deep-swarm/camel/embeddings/__init__.py
Normal file
@ -0,0 +1,28 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseEmbedding
from .mistral_embedding import MistralEmbedding
from .openai_compatible_embedding import OpenAICompatibleEmbedding
from .openai_embedding import OpenAIEmbedding
from .sentence_transformers_embeddings import SentenceTransformerEncoder
from .vlm_embedding import VisionLanguageEmbedding

__all__ = [
    "BaseEmbedding",
    "OpenAIEmbedding",
    "SentenceTransformerEncoder",
    "VisionLanguageEmbedding",
    "MistralEmbedding",
    "OpenAICompatibleEmbedding",
]
BIN
deep-swarm/camel/embeddings/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
deep-swarm/camel/embeddings/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
67
deep-swarm/camel/embeddings/base.py
Normal file
@ -0,0 +1,67 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

T = TypeVar('T')


class BaseEmbedding(ABC, Generic[T]):
    r"""Abstract base class for text embedding functionalities."""

    @abstractmethod
    def embed_list(
        self,
        objs: list[T],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[T]): The objects for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list of embeddings, each represented as a
                list of floating-point numbers.
        """
        pass

    def embed(
        self,
        obj: T,
        **kwargs: Any,
    ) -> list[float]:
        r"""Generates an embedding for the given text.

        Args:
            obj (T): The object for which to generate the embedding.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[float]: A list of floating-point numbers representing the
                generated embedding.
        """
        return self.embed_list([obj], **kwargs)[0]

    @abstractmethod
    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        pass
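To implement a new backend against this interface, only `embed_list` and `get_output_dim` are required; `embed` is derived from `embed_list`. A toy sketch (the `HashEmbedding` class below is hypothetical, for illustration only):

```python
from typing import Any

from camel.embeddings.base import BaseEmbedding


class HashEmbedding(BaseEmbedding[str]):
    """Toy deterministic embedding used only to illustrate the interface."""

    def embed_list(self, objs: list[str], **kwargs: Any) -> list[list[float]]:
        # Map each text to a fixed two-dimensional vector:
        # mean character code and text length.
        return [[sum(map(ord, s)) / max(len(s), 1), float(len(s))] for s in objs]

    def get_output_dim(self) -> int:
        return 2


enc = HashEmbedding()
print(enc.embed("hello"))           # single-object convenience wrapper
print(enc.embed_list(["a", "bc"]))  # batch form
```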
89
deep-swarm/camel/embeddings/mistral_embedding.py
Normal file
@ -0,0 +1,89 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

import os
from typing import Any

from camel.embeddings.base import BaseEmbedding
from camel.types import EmbeddingModelType
from camel.utils import api_keys_required


class MistralEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities using Mistral's models.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be
            used for text embeddings.
            (default: :obj:`MISTRAL_EMBED`)
        api_key (str, optional): The API key for authenticating with the
            Mistral service. (default: :obj:`None`)
        dimensions (int, optional): The text embedding output dimensions.
            (default: :obj:`None`)

    Raises:
        ValueError: If an unsupported model type is specified.
    """

    def __init__(
        self,
        model_type: EmbeddingModelType = EmbeddingModelType.MISTRAL_EMBED,
        api_key: str | None = None,
        dimensions: int | None = None,
    ) -> None:
        from mistralai import Mistral

        if not model_type.is_mistral:
            raise ValueError("Invalid Mistral embedding model type.")
        self.model_type = model_type
        if dimensions is None:
            self.output_dim = model_type.output_dim
        else:
            assert isinstance(dimensions, int)
            self.output_dim = dimensions
        self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
        self._client = Mistral(api_key=self._api_key)

    @api_keys_required("MISTRAL_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list of embeddings, each represented as a
                list of floating-point numbers.
        """
        # TODO: count tokens
        response = self._client.embeddings.create(
            inputs=objs,
            model=self.model_type.value,
            **kwargs,
        )
        return [data.embedding for data in response.data]  # type: ignore[misc,union-attr]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim
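A usage sketch, assuming `MISTRAL_API_KEY` is set and the `mistralai` package is installed:

```python
from camel.embeddings import MistralEmbedding

embedder = MistralEmbedding()  # defaults to EmbeddingModelType.MISTRAL_EMBED
vectors = embedder.embed_list(["hello", "world"])
print(len(vectors), embedder.get_output_dim())  # 2 and the model's dimension
```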
91
deep-swarm/camel/embeddings/openai_compatible_embedding.py
Normal file
@ -0,0 +1,91 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

import os
from typing import Any, Optional

from openai import OpenAI

from camel.embeddings.base import BaseEmbedding
from camel.utils import api_keys_required


class OpenAICompatibleEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities for any model service that
    is compatible with the OpenAI API.

    Args:
        model_type (str): The model type to be used for text embeddings.
        api_key (str, optional): The API key for authenticating with the
            model service. (default: :obj:`None`)
        url (str, optional): The URL of the model service.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        model_type: str,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        self.model_type = model_type
        # The output dimension is unknown until the first API response.
        self.output_dim: Optional[int] = None

        self._api_key = api_key or os.environ.get(
            "OPENAI_COMPATIBILITY_API_KEY"
        )
        self._url = url or os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL")
        self._client = OpenAI(
            timeout=60,
            max_retries=3,
            api_key=self._api_key,
            base_url=self._url,
        )

    @api_keys_required("OPENAI_COMPATIBILITY_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list of embeddings, each represented as a
                list of floating-point numbers.
        """

        response = self._client.embeddings.create(
            input=objs,
            model=self.model_type,
            **kwargs,
        )
        self.output_dim = len(response.data[0].embedding)
        return [data.embedding for data in response.data]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        if self.output_dim is None:
            raise ValueError(
                "Output dimension is not yet determined. Call "
                "'embed_list' first."
            )
        return self.output_dim
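A usage sketch; the endpoint URL and model name below are placeholders for any OpenAI-compatible service:

```python
from camel.embeddings import OpenAICompatibleEmbedding

embedder = OpenAICompatibleEmbedding(
    model_type="text-embedding-3-small",
    api_key="sk-...",                 # or set OPENAI_COMPATIBILITY_API_KEY
    url="https://api.openai.com/v1",  # or set OPENAI_COMPATIBILITY_API_BASE_URL
)
vecs = embedder.embed_list(["hello"])
print(embedder.get_output_dim())  # known only after the first embed_list call
```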
99
deep-swarm/camel/embeddings/openai_embedding.py
Normal file
@ -0,0 +1,99 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

import os
from typing import Any

from openai import OpenAI

from camel.embeddings.base import BaseEmbedding
from camel.types import NOT_GIVEN, EmbeddingModelType, NotGiven
from camel.utils import api_keys_required


class OpenAIEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities using OpenAI's models.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be
            used for text embeddings.
            (default: :obj:`TEXT_EMBEDDING_3_SMALL`)
        api_key (str, optional): The API key for authenticating with the
            OpenAI service. (default: :obj:`None`)
        dimensions (int, optional): The text embedding output dimensions.
            (default: :obj:`NOT_GIVEN`)

    Raises:
        ValueError: If an unsupported model type is specified.
    """

    def __init__(
        self,
        model_type: EmbeddingModelType = (
            EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
        ),
        api_key: str | None = None,
        dimensions: int | NotGiven = NOT_GIVEN,
    ) -> None:
        if not model_type.is_openai:
            raise ValueError("Invalid OpenAI embedding model type.")
        self.model_type = model_type
        if dimensions == NOT_GIVEN:
            self.output_dim = model_type.output_dim
        else:
            assert isinstance(dimensions, int)
            self.output_dim = dimensions
        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.client = OpenAI(timeout=60, max_retries=3, api_key=self._api_key)

    @api_keys_required("OPENAI_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list of embeddings, each represented as a
                list of floating-point numbers.
        """
        # TODO: count tokens
        if self.model_type == EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
            # text-embedding-ada-002 does not accept a `dimensions` argument.
            response = self.client.embeddings.create(
                input=objs,
                model=self.model_type.value,
                **kwargs,
            )
        else:
            response = self.client.embeddings.create(
                input=objs,
                model=self.model_type.value,
                dimensions=self.output_dim,
                **kwargs,
            )
        return [data.embedding for data in response.data]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim
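A usage sketch, assuming `OPENAI_API_KEY` is set; `dimensions` is only honored by the text-embedding-3 family:

```python
from camel.embeddings import OpenAIEmbedding
from camel.types import EmbeddingModelType

embedder = OpenAIEmbedding(
    model_type=EmbeddingModelType.TEXT_EMBEDDING_3_SMALL,
    dimensions=256,  # request truncated embeddings
)
vec = embedder.embed("hello world")
assert len(vec) == embedder.get_output_dim() == 256
```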
80
deep-swarm/camel/embeddings/sentence_transformers_embeddings.py
Normal file
@ -0,0 +1,80 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any

from numpy import ndarray

from camel.embeddings.base import BaseEmbedding


class SentenceTransformerEncoder(BaseEmbedding[str]):
    r"""This class provides functionalities to generate text
    embeddings using `Sentence Transformers`.

    References:
        https://www.sbert.net/
    """

    def __init__(
        self,
        model_name: str = "intfloat/e5-large-v2",
        **kwargs,
    ):
        r"""Initializes the :obj:`SentenceTransformerEncoder` class
        with the specified transformer model.

        Args:
            model_name (str, optional): The name of the model to use.
                (default: :obj:`intfloat/e5-large-v2`)
            **kwargs (optional): Additional arguments of
                :class:`SentenceTransformer`, such as :obj:`prompts` etc.
        """
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(model_name, **kwargs)

    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts using the model.

        Args:
            objs (list[str]): The texts for which to generate the
                embeddings.
            **kwargs (Any): Extra kwargs passed to the encoder.

        Returns:
            list[list[float]]: A list of embeddings, each represented as a
                list of floating-point numbers.
        """
        if not objs:
            raise ValueError("Input text list is empty")
        # Normalized embeddings make cosine similarity a plain dot product.
        embeddings = self.model.encode(
            objs, normalize_embeddings=True, **kwargs
        )
        assert isinstance(embeddings, ndarray)
        return embeddings.tolist()

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embeddings.
        """
        output_dim = self.model.get_sentence_embedding_dimension()
        assert isinstance(output_dim, int)
        return output_dim
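A usage sketch; the model downloads on first use, and the `query:`/`passage:` prefixes follow the e5 convention:

```python
from camel.embeddings import SentenceTransformerEncoder

encoder = SentenceTransformerEncoder()  # intfloat/e5-large-v2 by default
vecs = encoder.embed_list(
    ["query: what is CAMEL?", "passage: CAMEL is a multi-agent framework."]
)
print(len(vecs), encoder.get_output_dim())
```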
149
deep-swarm/camel/embeddings/vlm_embedding.py
Normal file
@ -0,0 +1,149 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional, Union

from PIL import Image

from camel.embeddings import BaseEmbedding
from camel.logger import get_logger

logger = get_logger(__name__)


class VisionLanguageEmbedding(BaseEmbedding[Union[str, Image.Image]]):
    r"""Provides image and text embedding functionalities using a
    multimodal model.

    Args:
        model_name (str, optional): The model to be used for generating
            embeddings. (default: :obj:`openai/clip-vit-base-patch32`)

    Raises:
        RuntimeError: If the model or processor fails to load.
    """

    def __init__(
        self, model_name: str = "openai/clip-vit-base-patch32"
    ) -> None:
        r"""Initializes the :obj:`VisionLanguageEmbedding` class with a
        specified model.

        Args:
            model_name (str, optional): The version name of the model to use.
                (default: :obj:`openai/clip-vit-base-patch32`)
        """
        from transformers import AutoModel, AutoProcessor

        try:
            self.model = AutoModel.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
        except Exception as e:
            raise RuntimeError(f"Failed to load model '{model_name}': {e}")

        self.valid_processor_kwargs = []
        self.valid_model_kwargs = []

        try:
            self.valid_processor_kwargs = (
                self.processor.image_processor._valid_processor_keys
            )
            self.valid_model_kwargs = [
                "pixel_values",
                "return_dict",
                "interpolate_pos_encoding",
            ]
        except Exception:
            logger.warning(
                "Processor or model does not have the expected structure; "
                "falling back to empty kwarg whitelists."
            )
        self.dim: Optional[int] = None

    def embed_list(
        self, objs: List[Union[Image.Image, str]], **kwargs: Any
    ) -> List[List[float]]:
        r"""Generates embeddings for the given images or texts.

        Args:
            objs (List[Image.Image|str]): The list of images or texts for
                which to generate the embeddings.
            **kwargs (Any): Extra kwargs. Supported keys:
                :obj:`image_processor_kwargs` (passed to the image
                processor), :obj:`tokenizer_kwargs` (passed to the text
                tokenizer of the processor), and :obj:`model_kwargs`
                (passed to the main model).

        Returns:
            List[List[float]]: A list of embeddings, each represented as a
                list of floating-point numbers.

        Raises:
            ValueError: If the input type is not `Image.Image` or `str`.
        """
        if not objs:
            raise ValueError("Input objs list is empty.")

        image_processor_kwargs: Optional[dict] = kwargs.get(
            'image_processor_kwargs', {}
        )
        tokenizer_kwargs: Optional[dict] = kwargs.get('tokenizer_kwargs', {})
        model_kwargs: Optional[dict] = kwargs.get('model_kwargs', {})

        result_list = []
        for obj in objs:
            if isinstance(obj, Image.Image):
                image_input = self.processor(
                    images=obj,
                    return_tensors="pt",
                    padding=True,
                    **image_processor_kwargs,
                )
                image_feature = (
                    self.model.get_image_features(
                        **image_input, **model_kwargs
                    )
                    .squeeze(dim=0)
                    .tolist()
                )
                result_list.append(image_feature)
            elif isinstance(obj, str):
                text_input = self.processor(
                    text=obj,
                    return_tensors="pt",
                    padding=True,
                    **tokenizer_kwargs,
                )
                text_feature = (
                    self.model.get_text_features(**text_input, **model_kwargs)
                    .squeeze(dim=0)
                    .tolist()
                )
                result_list.append(text_feature)
            else:
                raise ValueError("Input type is neither image nor text.")

        self.dim = len(result_list[0])

        if any(len(result) != self.dim for result in result_list):
            raise ValueError("Dimensionality is not consistent.")

        return result_list

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        if self.dim is None:
            # Lazily infer the dimension by embedding a dummy text.
            text = 'dimension'
            inputs = self.processor(text=[text], return_tensors="pt")
            self.dim = self.model.get_text_features(**inputs).shape[1]
        return self.dim
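A usage sketch mixing text and image inputs; the blank image is a stand-in for real data:

```python
from PIL import Image

from camel.embeddings import VisionLanguageEmbedding

clip = VisionLanguageEmbedding()  # openai/clip-vit-base-patch32 by default
image = Image.new("RGB", (224, 224))  # placeholder for an actual image
text_vec, image_vec = clip.embed_list(["a photo of a cat", image])
assert len(text_vec) == len(image_vec) == clip.get_output_dim()
```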
375
deep-swarm/camel/generators.py
Normal file
@ -0,0 +1,375 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Dict, Generator, List, Optional, Set, Tuple

from camel.messages import BaseMessage
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType


class SystemMessageGenerator:
    r"""System message generator for agents.

    Args:
        task_type (TaskType, optional): The task type.
            (default: :obj:`TaskType.AI_SOCIETY`)
        sys_prompts (Optional[Dict[RoleType, str]], optional): The prompts of
            the system messages for each role type. (default: :obj:`None`)
        sys_msg_meta_dict_keys (Optional[Set[str]], optional): The set of keys
            of the meta dictionary used to fill the prompts.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        task_type: TaskType = TaskType.AI_SOCIETY,
        sys_prompts: Optional[Dict[RoleType, str]] = None,
        sys_msg_meta_dict_keys: Optional[Set[str]] = None,
    ) -> None:
        self.sys_prompts: Dict[RoleType, str]

        if sys_prompts is not None:
            self.sys_prompts = sys_prompts
            self.sys_msg_meta_dict_keys = sys_msg_meta_dict_keys or set()
        else:
            assistant_prompt_template = (
                PromptTemplateGenerator().get_system_prompt(
                    task_type,
                    RoleType.ASSISTANT,
                )
            )
            user_prompt_template = PromptTemplateGenerator().get_system_prompt(
                task_type,
                RoleType.USER,
            )
            critic_prompt_template = (
                PromptTemplateGenerator().get_system_prompt(
                    task_type,
                    RoleType.CRITIC,
                )
            )
            embodiment_prompt_template = (
                PromptTemplateGenerator().get_system_prompt(
                    task_type,
                    RoleType.EMBODIMENT,
                )
            )

            self.sys_prompts = dict()
            self.sys_prompts[RoleType.ASSISTANT] = assistant_prompt_template
            self.sys_prompts[RoleType.USER] = user_prompt_template
            self.sys_prompts[RoleType.CRITIC] = critic_prompt_template
            self.sys_prompts[RoleType.EMBODIMENT] = embodiment_prompt_template

            # The allowed meta keys are the union of the placeholders used
            # by all four prompt templates.
            self.sys_msg_meta_dict_keys = (
                assistant_prompt_template.key_words
                | user_prompt_template.key_words
                | critic_prompt_template.key_words
                | embodiment_prompt_template.key_words
            )

        if RoleType.DEFAULT not in self.sys_prompts:
            self.sys_prompts[RoleType.DEFAULT] = "You are a helpful assistant."

    def validate_meta_dict_keys(self, meta_dict: Dict[str, str]) -> None:
        r"""Validates the keys of the meta_dict.

        Args:
            meta_dict (Dict[str, str]): The dictionary to validate.
        """
        if not set(meta_dict.keys()).issubset(self.sys_msg_meta_dict_keys):
            raise ValueError(
                "The keys of the meta_dict should be in "
                f"{self.sys_msg_meta_dict_keys}. "
                f"Got {set(meta_dict.keys())} instead."
            )

    def from_dict(
        self,
        meta_dict: Dict[str, str],
        role_tuple: Tuple[str, RoleType] = ("", RoleType.DEFAULT),
    ) -> BaseMessage:
        r"""Generates a system message from a dictionary.

        Args:
            meta_dict (Dict[str, str]): The dictionary containing the
                information to generate the system message.
            role_tuple (Tuple[str, RoleType], optional): The tuple containing
                the role name and role type. (default: ("", RoleType.DEFAULT))

        Returns:
            BaseMessage: The generated system message.
        """
        self.validate_meta_dict_keys(meta_dict)
        role_name, role_type = role_tuple
        sys_prompt = self.sys_prompts[role_type]
        sys_prompt = sys_prompt.format(**meta_dict)
        return BaseMessage(
            role_name=role_name,
            role_type=role_type,
            meta_dict=meta_dict,
            content=sys_prompt,
        )

    def from_dicts(
        self,
        meta_dicts: List[Dict[str, str]],
        role_tuples: List[Tuple[str, RoleType]],
    ) -> List[BaseMessage]:
        r"""Generates a list of system messages from a list of dictionaries.

        Args:
            meta_dicts (List[Dict[str, str]]): A list of dictionaries
                containing the information to generate the system messages.
            role_tuples (List[Tuple[str, RoleType]]): A list of tuples
                containing the role name and role type for each system
                message.

        Returns:
            List[BaseMessage]: A list of generated system messages.

        Raises:
            ValueError: If the numbers of meta_dicts and role_tuples differ.
        """
        if len(meta_dicts) != len(role_tuples):
            raise ValueError(
                "The number of meta_dicts and role_tuples should be the same."
            )

        return [
            self.from_dict(meta_dict, role_tuple)
            for meta_dict, role_tuple in zip(meta_dicts, role_tuples)
        ]


class RoleNameGenerator:
    r"""Role name generator for role-playing workers.

    Args:
        assistant_role_names_path (str, optional): The path to the file
            containing the assistant role names.
            (default: :obj:`"data/ai_society/assistant_roles.txt"`)
        user_role_names_path (str, optional): The path to the file
            containing the user role names.
            (default: :obj:`"data/ai_society/user_roles.txt"`)
        assistant_role_names (Optional[List[str]], optional): The list of
            assistant role names. (default: :obj:`None`)
        user_role_names (Optional[List[str]], optional): The list of user role
            names. (default: :obj:`None`)
    """

    def __init__(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
        assistant_role_names: Optional[List[str]] = None,
        user_role_names: Optional[List[str]] = None,
    ) -> None:
        if assistant_role_names is None:
            with open(assistant_role_names_path, "r") as f:
                assistant_role_names_: List[str] = f.read().splitlines()
                # Drop the leading index token (e.g., "1.") from each line.
                self.assistant_role_names = [
                    " ".join(name.split(" ")[1:])
                    for name in assistant_role_names_
                ]
        else:
            self.assistant_role_names = assistant_role_names

        if user_role_names is None:
            with open(user_role_names_path, "r") as f:
                user_role_names_: List[str] = f.read().splitlines()
                self.user_role_names = [
                    " ".join(name.split(" ")[1:]) for name in user_role_names_
                ]
        else:
            self.user_role_names = user_role_names

    def from_role_files(self) -> Generator[Tuple, None, None]:
        r"""Generate role names from the file.

        Returns:
            Generator[Tuple, None, None]: A generator that yields tuples of
                assistant role names and user role names.
        """
        for assistant_role_name in self.assistant_role_names:
            for user_role_name in self.user_role_names:
                yield (assistant_role_name, user_role_name)


class AISocietyTaskPromptGenerator:
    r"""Task prompt generator for AI society tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`10`)
    """

    def __init__(
        self,
        num_tasks: int = 10,
    ) -> None:
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(
                TaskType.AI_SOCIETY
            )
        )

        self.num_tasks = num_tasks

    # TODO: Return role names for user and assistant with the generator.
    def from_role_files(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks from role files.

        Args:
            assistant_role_names_path (str, optional): The path to the file
                containing the assistant role names.
                (default: :obj:`"data/ai_society/assistant_roles.txt"`)
            user_role_names_path (str, optional): The path to the file
                containing the user role names.
                (default: :obj:`"data/ai_society/user_roles.txt"`)

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        roles_generator = RoleNameGenerator(
            assistant_role_names_path, user_role_names_path
        ).from_role_files()
        for role_1, role_2 in roles_generator:
            generate_tasks_prompt = self.generate_tasks_prompt.format(
                assistant_role=role_1,
                user_role=role_2,
                num_tasks=self.num_tasks,
            )

            yield (generate_tasks_prompt, (role_1, role_2))

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks from a role generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        for role_1, role_2 in role_generator:
            generate_tasks_prompt = self.generate_tasks_prompt.format(
                assistant_role=role_1,
                user_role=role_2,
                num_tasks=self.num_tasks,
            )

            yield (generate_tasks_prompt, (role_1, role_2))


class SingleTxtGenerator:
    r"""Single text generator for role-playing workers.

    Args:
        text_file_path (str): The path to the file containing the text data.
    """

    def __init__(
        self,
        text_file_path: str,
    ) -> None:
        with open(text_file_path, "r") as f:
            data_list: List[str] = f.read().splitlines()
            # Drop the leading index token (e.g., "1.") from each line.
            self.data_list = [
                " ".join(name.split(" ")[1:]) for name in data_list
            ]

    def from_role_files(self) -> Generator[str, None, None]:
        r"""Generate text from the file.

        Returns:
            Generator[str, None, None]: A generator that yields the text data.
        """
        for data in self.data_list:
            yield data


class CodeTaskPromptGenerator:
    r"""Task prompt generator for code tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`50`)
    """

    def __init__(
        self,
        num_tasks: int = 50,
    ) -> None:
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(TaskType.CODE)
        )

        self.num_tasks = num_tasks

    def from_role_files(
        self,
        languages_path: str = "data/code/languages.txt",
        domains_path: str = "data/code/domains.txt",
    ) -> Generator[Tuple[TextPrompt, str, str], None, None]:
        r"""Generate tasks from role files.

        Args:
            languages_path (str, optional): The path to the file containing
                the language names.
                (default: :obj:`"data/code/languages.txt"`)
            domains_path (str, optional): The path to the file containing
                the domain names. (default: :obj:`"data/code/domains.txt"`)

        Returns:
            Generator[Tuple[TextPrompt, str, str], None, None]: A generator
                that yields tuples of task prompts, language names, and domain
                names.
        """
        language_generator = SingleTxtGenerator(
            languages_path
        ).from_role_files()

        for language in language_generator:
            domains_generator = SingleTxtGenerator(
                domains_path
            ).from_role_files()
            for domain in domains_generator:
                generated_tasks_prompt = self.generate_tasks_prompt.format(
                    language=language, domain=domain, num_tasks=self.num_tasks
                )
                yield generated_tasks_prompt, language, domain

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[str, None, None]:
        r"""Generate tasks from a role generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[str, None, None]: A generator that yields the task
                prompts.
        """
        raise NotImplementedError
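A usage sketch tying the generators together, assuming the bundled `data/ai_society` role files and prompt templates are present; the task string is a placeholder:

```python
from camel.generators import RoleNameGenerator, SystemMessageGenerator
from camel.types import RoleType, TaskType

# Pair up role names from the bundled AI-society role files.
roles = RoleNameGenerator().from_role_files()
assistant_role, user_role = next(roles)

sys_msg_gen = SystemMessageGenerator(task_type=TaskType.AI_SOCIETY)
sys_msg = sys_msg_gen.from_dict(
    meta_dict={
        "assistant_role": assistant_role,
        "user_role": user_role,
        "task": "Design a REST API",
    },
    role_tuple=(assistant_role, RoleType.ASSISTANT),
)
print(sys_msg.content)
```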
Some files were not shown because too many files have changed in this diff.