mirror of
https://github.com/camel-ai/owl.git
synced 2026-03-22 14:07:17 +08:00
fix token count error in use case
This commit is contained in:
@@ -13,11 +13,9 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import asyncio
|
||||
import sys
|
||||
import contextlib
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
import os
|
||||
|
||||
from colorama import Fore, init
|
||||
from dotenv import load_dotenv
|
||||
@@ -26,7 +24,6 @@ from camel.agents.chat_agent import ToolCallingRecord
|
||||
from camel.models import ModelFactory
|
||||
from camel.toolkits import FunctionTool, MCPToolkit
|
||||
from camel.types import ModelPlatformType, ModelType
|
||||
from camel.logger import set_log_level
|
||||
|
||||
from owl.utils.enhanced_role_playing import OwlRolePlaying
|
||||
|
||||
@@ -39,8 +36,6 @@ base_dir = pathlib.Path(__file__).parent.parent
|
||||
env_path = base_dir / "owl" / ".env"
|
||||
load_dotenv(dotenv_path=str(env_path))
|
||||
|
||||
set_log_level(level="INFO")
|
||||
|
||||
|
||||
async def construct_society(
|
||||
question: str,
|
||||
@@ -146,9 +141,6 @@ def write_to_md(filename: str, content: Dict[str, Any]) -> None:
|
||||
if "summary" in content:
|
||||
f.write(f"## Summary\n\n")
|
||||
f.write(f"{content['summary']}\n\n")
|
||||
|
||||
if "token_count" in content:
|
||||
f.write(f"**Total tokens used:** {content['token_count']}\n\n")
|
||||
|
||||
|
||||
async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename: str, round_limit: int = 15):
|
||||
@@ -182,12 +174,25 @@ async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename
|
||||
|
||||
input_msg = society.init_chat()
|
||||
chat_history = []
|
||||
token_count = {"total": 0}
|
||||
overall_completion_token_count = 0
|
||||
overall_prompt_token_count = 0
|
||||
n = 0
|
||||
|
||||
while n < round_limit:
|
||||
n += 1
|
||||
assistant_response, user_response = await society.astep(input_msg)
|
||||
|
||||
overall_completion_token_count += assistant_response.info["usage"].get(
|
||||
"completion_tokens", 0
|
||||
) + user_response.info["usage"].get("completion_tokens", 0)
|
||||
overall_prompt_token_count += assistant_response.info["usage"].get(
|
||||
"prompt_tokens", 0
|
||||
) + user_response.info["usage"].get("prompt_tokens", 0)
|
||||
|
||||
token_info = {
|
||||
"completion_token_count": overall_completion_token_count,
|
||||
"prompt_token_count": overall_prompt_token_count,
|
||||
}
|
||||
|
||||
md_content = {}
|
||||
|
||||
@@ -239,10 +244,6 @@ async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename
|
||||
"user": user_response.msg.content,
|
||||
})
|
||||
|
||||
# Update token count
|
||||
if "token_count" in assistant_response.info:
|
||||
token_count["total"] += assistant_response.info["token_count"]
|
||||
|
||||
if "TASK_DONE" in user_response.msg.content:
|
||||
task_done_msg = "Task completed successfully!"
|
||||
print(Fore.YELLOW + task_done_msg + "\n")
|
||||
@@ -252,12 +253,12 @@ async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename
|
||||
input_msg = assistant_response.msg
|
||||
|
||||
# Write token count information
|
||||
write_to_md(md_filename, {"token_count": token_count["total"]})
|
||||
write_to_md(md_filename, token_info)
|
||||
|
||||
# Extract final answer
|
||||
answer = assistant_response.msg.content if assistant_response and assistant_response.msg else ""
|
||||
|
||||
return answer, chat_history, token_count
|
||||
return answer, chat_history, token_info
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -278,7 +279,7 @@ async def main():
|
||||
# Use command line argument if provided, otherwise use default task
|
||||
task = sys.argv[1] if len(sys.argv) > 1 else default_task
|
||||
|
||||
mcp_toolkit = MCPToolkit(config_path=str(config_path))
|
||||
mcp_toolkit = MCPToolkit(config_path=str(config_path), strict=True)
|
||||
|
||||
try:
|
||||
# Create markdown file for conversation export
|
||||
@@ -294,10 +295,10 @@ async def main():
|
||||
# Build and run society
|
||||
print(Fore.YELLOW + f"Starting task: {task}\n")
|
||||
society = await construct_society(task, tools)
|
||||
answer, chat_history, token_count = await run_society_with_formatted_output(society, md_filename)
|
||||
answer, chat_history, token_info = await run_society_with_formatted_output(society, md_filename)
|
||||
|
||||
print(Fore.GREEN + f"\nFinal Result: {answer}")
|
||||
print(Fore.CYAN + f"Total tokens used: {token_count['total']}")
|
||||
print(Fore.CYAN + f"Total tokens used: {token_info}")
|
||||
print(Fore.CYAN + f"Full conversation log saved to: {md_filename}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
|
||||
Reference in New Issue
Block a user