mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add GPTSwarm (Graph-based Workflow) (#2460)
* update * revert the main branch lock file * regenerate the poetry.lock * move poetry into another dependency group * fix infer.sh and gpt code --------- Co-authored-by: yufansong <yufan@risingwave-labs.com>
This commit is contained in:
parent
9b1f59a56e
commit
0277311fd6
@ -14,6 +14,7 @@ from . import ( # noqa: E402
|
||||
codeact_swe_agent,
|
||||
delegator_agent,
|
||||
dummy_agent,
|
||||
gptswarm_agent,
|
||||
monologue_agent,
|
||||
planner_agent,
|
||||
)
|
||||
@ -21,6 +22,7 @@ from . import ( # noqa: E402
|
||||
__all__ = [
|
||||
'monologue_agent',
|
||||
'codeact_agent',
|
||||
'gptswarm_agent',
|
||||
'codeact_swe_agent',
|
||||
'planner_agent',
|
||||
'delegator_agent',
|
||||
|
||||
16
agenthub/gptswarm_agent/README.md
Normal file
16
agenthub/gptswarm_agent/README.md
Normal file
@ -0,0 +1,16 @@
|
||||
# GPTSwarm Framework
|
||||
|
||||
## Introduction
|
||||
|
||||
This folder implements the GPTSwarm ([paper](https://arxiv.org/abs/2402.01030), [Original Repo](https://github.com/metauto-ai/GPTSwarm)). For more details, please see paper.
|
||||
|
||||
|
||||
## Reference
|
||||
```
|
||||
@article{zhuge2024language,
|
||||
title={Language Agents as Optimizable Graphs},
|
||||
author={Zhuge, Mingchen and Wang, Wenyi and Kirsch, Louis and Faccio, Francesco and Khizbullin, Dmitrii and Schmidhuber, Jurgen},
|
||||
journal={arXiv preprint arXiv:2402.16823},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
5
agenthub/gptswarm_agent/__init__.py
Normal file
5
agenthub/gptswarm_agent/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from opendevin.controller.agent import Agent
|
||||
|
||||
from .gptswarm_agent import GPTSwarm
|
||||
|
||||
Agent.register('GPTSwarmAgent', GPTSwarm)
|
||||
196
agenthub/gptswarm_agent/gptswarm_agent.py
Normal file
196
agenthub/gptswarm_agent/gptswarm_agent.py
Normal file
@ -0,0 +1,196 @@
|
||||
import asyncio
|
||||
import dataclasses
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from agenthub.gptswarm_agent.gptswarm_graph import AssistantGraph
|
||||
from agenthub.gptswarm_agent.prompt import GPTSwarmPromptSet
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.events.action import Action
|
||||
from opendevin.llm.llm import LLM
|
||||
|
||||
ENABLE_GITHUB = True
|
||||
OPENAI_API_KEY = 'sk-proj-****' # TODO: get from environment or config
|
||||
|
||||
|
||||
MessageRole = Literal['system', 'user', 'assistant']
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class Message:
|
||||
role: MessageRole
|
||||
content: str
|
||||
|
||||
|
||||
class GPTSwarm(Agent):
|
||||
VERSION = '1.0'
|
||||
"""
|
||||
This is simple revision of GPTSwarm which serve as an assistant agent.
|
||||
|
||||
GPTSwarm Paper: https://arxiv.org/abs/2402.16823 (ICML 2024, Oral Presentation)
|
||||
GPTSwarm Code: https://github.com/metauto-ai/GPTSwarm
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLM,
|
||||
model_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes a new instance of the GPTSwarm class.
|
||||
|
||||
Parameters:
|
||||
- llm (LLM): The llm to be used by this agent
|
||||
"""
|
||||
super().__init__(llm)
|
||||
self.api_key = OPENAI_API_KEY
|
||||
self.llm = LLM(model=model_name, api_key=self.api_key)
|
||||
self.graph = AssistantGraph(domain='gaia', model_name=model_name)
|
||||
self.prompt_set = GPTSwarmPromptSet()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Resets the GPTSwarm Agent.
|
||||
"""
|
||||
super().reset()
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
"""
|
||||
# TODO: It is stateless now. Find a way to make it stateful.
|
||||
# NOTE: For the AI assistant, state-based design may introduce more uncertainties.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def swarm_run(self, inputs: List[Dict[str, Any]], num_agents=3) -> List[str]:
|
||||
"""
|
||||
Run the `run` method of this agent concurrently for `num_agents` times.
|
||||
# NOTE: This is just a simple self-consistency.
|
||||
# TODO: should follow original GPTSwarm's graph design to revise.
|
||||
"""
|
||||
|
||||
async def run_single_agent(index):
|
||||
try:
|
||||
result = await asyncio.wait_for(self.run(inputs=inputs), timeout=200)
|
||||
print('-----------------------------------')
|
||||
print(f'No. {index} Agent complete task..')
|
||||
logger.info(result[0])
|
||||
print('-----------------------------------')
|
||||
return result[0]
|
||||
except asyncio.TimeoutError:
|
||||
print(f'No. {index} Agent timed out.')
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f'No. {index} Agent resulted in an error: {e}')
|
||||
return None
|
||||
|
||||
# Create a list of tasks to run concurrently
|
||||
tasks = [run_single_agent(i) for i in range(num_agents)]
|
||||
|
||||
# Run all tasks concurrently and gather the results
|
||||
agent_answers = await asyncio.gather(*tasks)
|
||||
|
||||
# Filter out None results (from timeouts or errors)
|
||||
agent_answers = [answer for answer in agent_answers if answer is not None]
|
||||
|
||||
task = inputs[0]['task']
|
||||
prompt = self.prompt_set.get_self_consistency(
|
||||
question=task,
|
||||
answers=agent_answers,
|
||||
constraint=self.prompt_set.get_constraint(),
|
||||
)
|
||||
messages = [
|
||||
Message(role='system', content=f'You are a {self.prompt_set.get_role()}.'),
|
||||
Message(role='user', content=prompt),
|
||||
]
|
||||
|
||||
swarm_ans = self.llm.completion(
|
||||
messages=[{'role': msg.role, 'content': msg.content} for msg in messages]
|
||||
)
|
||||
swarm_ans = swarm_ans.choices[0].message.content
|
||||
return [swarm_ans]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
inputs: List[Dict[str, Any]],
|
||||
max_tries: int = 3,
|
||||
max_time: int = 600,
|
||||
return_all_outputs: bool = False,
|
||||
) -> List[Any]:
|
||||
def is_node_useful(node):
|
||||
if node in self.graph.output_nodes:
|
||||
return True
|
||||
|
||||
for successor in node.successors:
|
||||
if is_node_useful(successor):
|
||||
return True
|
||||
return False
|
||||
|
||||
useful_node_ids = [
|
||||
node_id
|
||||
for node_id, node in self.graph.nodes.items()
|
||||
if is_node_useful(node)
|
||||
]
|
||||
in_degree = {
|
||||
node_id: len(self.graph.nodes[node_id].predecessors)
|
||||
for node_id in useful_node_ids
|
||||
}
|
||||
zero_in_degree_queue = [
|
||||
node_id
|
||||
for node_id, deg in in_degree.items()
|
||||
if deg == 0 and node_id in useful_node_ids
|
||||
]
|
||||
|
||||
for i, input_node in enumerate(self.graph.input_nodes):
|
||||
node_input = deepcopy(inputs)
|
||||
input_node.inputs = [node_input]
|
||||
|
||||
while zero_in_degree_queue:
|
||||
current_node_id = zero_in_degree_queue.pop(0)
|
||||
current_node = self.graph.nodes[current_node_id]
|
||||
tries = 0
|
||||
while tries < max_tries:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.graph.nodes[current_node_id].execute(), timeout=max_time
|
||||
)
|
||||
# TODO: make GPTSwarm stateful in OpenDevin.
|
||||
# State.inputs = self.graph.nodes[current_node_id].inputs
|
||||
# State.outputs = self.graph.nodes[current_node_id].outputs
|
||||
# self.step(State)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(
|
||||
f'Node {current_node_id} execution timed out, retrying {tries + 1} out of {max_tries}...'
|
||||
)
|
||||
except Exception as e:
|
||||
print(f'Error during execution of node {current_node_id}: {e}')
|
||||
break
|
||||
tries += 1
|
||||
|
||||
for successor in current_node.successors:
|
||||
if successor.id in useful_node_ids:
|
||||
in_degree[successor.id] -= 1
|
||||
if in_degree[successor.id] == 0:
|
||||
zero_in_degree_queue.append(successor.id)
|
||||
|
||||
final_answers = []
|
||||
|
||||
for output_node in self.graph.output_nodes:
|
||||
output_messages = output_node.outputs
|
||||
|
||||
if len(output_messages) > 0 and not return_all_outputs:
|
||||
final_answer = output_messages[-1].get('output', output_messages[-1])
|
||||
final_answers.append(final_answer)
|
||||
else:
|
||||
for output_message in output_messages:
|
||||
final_answer = output_message.get('output', output_message)
|
||||
final_answers.append(final_answer)
|
||||
|
||||
if len(final_answers) == 0:
|
||||
final_answers.append('No answer since there are no inputs provided')
|
||||
return final_answers
|
||||
|
||||
def search_memory(self, query: str) -> list[str]:
|
||||
raise NotImplementedError('Implement this abstract method')
|
||||
520
agenthub/gptswarm_agent/gptswarm_graph.py
Normal file
520
agenthub/gptswarm_agent/gptswarm_graph.py
Normal file
@ -0,0 +1,520 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
import requests
|
||||
from pytube import YouTube
|
||||
from swarm.graph import Graph, Node
|
||||
|
||||
from agenthub.gptswarm_agent.prompt import GPTSwarmPromptSet
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.runtime.plugins.agent_skills.agentskills import (
|
||||
parse_audio,
|
||||
parse_docx,
|
||||
parse_image,
|
||||
parse_latex,
|
||||
parse_pdf,
|
||||
parse_pptx,
|
||||
parse_txt,
|
||||
parse_video,
|
||||
)
|
||||
|
||||
OPENAI_API_KEY = 'sk-proj-****' # TODO: get from environment or config
|
||||
SEARCHAPI_API_KEY = '****' # TODO: get from environment or config
|
||||
|
||||
MessageRole = Literal['system', 'user', 'assistant']
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class Message:
|
||||
role: MessageRole
|
||||
content: str
|
||||
|
||||
|
||||
READER_MAP = {
|
||||
'.png': parse_image,
|
||||
'.jpg': parse_image,
|
||||
'.jpeg': parse_image,
|
||||
'.gif': parse_image,
|
||||
'.bmp': parse_image,
|
||||
'.tiff': parse_image,
|
||||
'.tif': parse_image,
|
||||
'.webp': parse_image,
|
||||
'.mp3': parse_audio,
|
||||
'.m4a': parse_audio,
|
||||
'.wav': parse_audio,
|
||||
'.MOV': parse_video,
|
||||
'.mp4': parse_video,
|
||||
'.mov': parse_video,
|
||||
'.avi': parse_video,
|
||||
'.mpg': parse_video,
|
||||
'.mpeg': parse_video,
|
||||
'.wmv': parse_video,
|
||||
'.flv': parse_video,
|
||||
'.webm': parse_video,
|
||||
'.pptx': parse_pptx,
|
||||
'.pdf': parse_pdf,
|
||||
'.docx': parse_docx,
|
||||
'.tex': parse_latex,
|
||||
'.txt': parse_txt,
|
||||
}
|
||||
|
||||
|
||||
class FileReader:
|
||||
def __init__(self):
|
||||
self.reader = None # Initial type is None
|
||||
|
||||
def set_reader(self, suffix: str):
|
||||
reader = READER_MAP.get(suffix)
|
||||
if reader is not None:
|
||||
self.reader = reader
|
||||
logger.info(f'Setting Reader to {self.reader.__name__}')
|
||||
else:
|
||||
logger.error(f'No reader found for suffix {suffix}')
|
||||
self.reader = None
|
||||
|
||||
def read_file(self, file_path: Path, task: str = 'describe the file') -> str:
|
||||
suffix = file_path.suffix
|
||||
self.set_reader(suffix)
|
||||
if not self.reader:
|
||||
raise ValueError(f'No reader set for suffix {suffix}')
|
||||
if self.reader in [parse_image, parse_video]:
|
||||
file_content = self.reader(file_path, task)
|
||||
else:
|
||||
file_content = self.reader(file_path)
|
||||
logger.info(f'Reading file {file_path} using {self.reader.__name__}')
|
||||
return file_content
|
||||
|
||||
|
||||
class GenerateQuery(Node):
|
||||
def __init__(
|
||||
self,
|
||||
domain: str = 'gaia',
|
||||
model_name: Optional[str] = 'gpt-4o-2024-05-13',
|
||||
operation_description: str = 'Given a question, return what information is needed to answer the question.',
|
||||
id=None,
|
||||
):
|
||||
super().__init__(operation_description, id, True)
|
||||
self.domain = domain
|
||||
self.api_key = OPENAI_API_KEY
|
||||
self.llm = LLM(model=model_name, api_key=self.api_key)
|
||||
self.prompt_set = GPTSwarmPromptSet()
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def extract_urls(self, text: str) -> List[str]:
|
||||
url_pattern = r'https?://[^\s]+'
|
||||
urls = re.findall(url_pattern, text)
|
||||
return urls
|
||||
|
||||
def is_youtube_url(self, url: str) -> bool:
|
||||
youtube_regex = (
|
||||
r'(https?://)?(www\.)?'
|
||||
r'(youtube|youtu|youtube-nocookie)\.(com|be)/'
|
||||
r'(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})'
|
||||
)
|
||||
return bool(re.match(youtube_regex, url))
|
||||
|
||||
def _youtube_download(self, url: str) -> str:
|
||||
try:
|
||||
video_id = url.split('v=')[-1].split('&')[0]
|
||||
video_id = video_id.strip()
|
||||
youtube = YouTube(url)
|
||||
video_stream = (
|
||||
youtube.streams.filter(progressive=True, file_extension='mp4')
|
||||
.order_by('resolution')
|
||||
.desc()
|
||||
.first()
|
||||
)
|
||||
if not video_stream:
|
||||
raise ValueError('No suitable video stream found.')
|
||||
|
||||
output_dir = 'workspace/tmp'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = f'{output_dir}/{video_id}.mp4'
|
||||
video_stream.download(output_path=output_dir, filename=f'{video_id}.mp4')
|
||||
return output_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error downloading video from {url}: {e}'
|
||||
) # Use logger for error messages
|
||||
return ''
|
||||
|
||||
async def _execute(
|
||||
self, inputs: Optional[List[dict]] = None, **kwargs
|
||||
) -> List[dict]:
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
node_inputs = inputs
|
||||
outputs = []
|
||||
|
||||
for input in node_inputs:
|
||||
urls = self.extract_urls(input['task'])
|
||||
|
||||
download_paths = []
|
||||
|
||||
for url in urls:
|
||||
if self.is_youtube_url(url):
|
||||
download_path = self._youtube_download(url)
|
||||
if download_path:
|
||||
download_paths.append(download_path)
|
||||
|
||||
if urls:
|
||||
logger.info(urls)
|
||||
if download_paths:
|
||||
logger.info(download_paths)
|
||||
|
||||
files = input.get('files', [])
|
||||
if not isinstance(files, list):
|
||||
files = []
|
||||
files.extend(download_paths)
|
||||
|
||||
role = self.prompt_set.get_role()
|
||||
# constraint = self.prompt_set.get_constraint()
|
||||
prompt = self.prompt_set.get_query_prompt(question=input['task'])
|
||||
|
||||
messages = [
|
||||
Message(role='system', content=f'You are a {role}.'),
|
||||
Message(role='user', content=prompt),
|
||||
]
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=[
|
||||
{'role': msg.role, 'content': msg.content} for msg in messages
|
||||
]
|
||||
)
|
||||
response = response.choices[0].message.content
|
||||
|
||||
executions = {
|
||||
'operation': self.node_name,
|
||||
'task': input['task'],
|
||||
'files': files,
|
||||
'input': input.get('task', None),
|
||||
'subtask': prompt,
|
||||
'output': response,
|
||||
'format': 'natural language',
|
||||
}
|
||||
outputs.append(executions)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FileAnalyse(Node):
|
||||
def __init__(
|
||||
self,
|
||||
domain: str = 'gaia',
|
||||
model_name: Optional[str] = 'gpt-4o-2024-05-13',
|
||||
operation_description: str = 'Given a question, extract information from a file.',
|
||||
id=None,
|
||||
):
|
||||
super().__init__(operation_description, id, True)
|
||||
self.domain = domain
|
||||
self.api_key = OPENAI_API_KEY
|
||||
self.llm = LLM(model=model_name, api_key=self.api_key)
|
||||
self.prompt_set = GPTSwarmPromptSet()
|
||||
self.reader = FileReader()
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
async def _execute(
|
||||
self, inputs: Optional[List[dict]] = None, **kwargs
|
||||
) -> List[dict]:
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
node_inputs = inputs
|
||||
outputs = []
|
||||
for input in node_inputs:
|
||||
query = input.get('output', 'Please organize the information of this file.')
|
||||
files = input.get('files', [])
|
||||
response = await self.file_analyse(query, files, self.llm)
|
||||
|
||||
executions = {
|
||||
'operation': self.node_name,
|
||||
'task': input['task'],
|
||||
'files': files,
|
||||
'input': query,
|
||||
'subtask': f'Read the content of ###{files}, use query ###{query}',
|
||||
'output': response,
|
||||
'format': 'natural language',
|
||||
}
|
||||
|
||||
outputs.append(executions)
|
||||
|
||||
return outputs
|
||||
|
||||
async def file_analyse(self, query: str, files: List[str], llm: LLM) -> str:
|
||||
answer = ''
|
||||
for file in files:
|
||||
file_path = Path(file)
|
||||
if self.reader not in [parse_image, parse_video]:
|
||||
file_content = self.reader.read_file(file_path)
|
||||
prompt = self.prompt_set.get_file_analysis_prompt(
|
||||
query=query, file=file_content
|
||||
)
|
||||
messages = [
|
||||
Message(
|
||||
role='system',
|
||||
content=f'You are a {self.prompt_set.get_role()}.',
|
||||
),
|
||||
Message(role='user', content=prompt),
|
||||
]
|
||||
response = llm.completion(
|
||||
messages=[
|
||||
{'role': msg.role, 'content': msg.content} for msg in messages
|
||||
]
|
||||
)
|
||||
answer += response.choices[0].message.content + '\n'
|
||||
return answer
|
||||
|
||||
|
||||
class WebSearch(Node):
|
||||
def __init__(
|
||||
self,
|
||||
domain: str = 'gaia',
|
||||
model_name: Optional[str] = 'gpt-4o-2024-05-13',
|
||||
operation_description: str = 'Given a question, search the web for infomation.',
|
||||
id=None,
|
||||
):
|
||||
super().__init__(operation_description, id, True)
|
||||
self.domain = domain
|
||||
self.api_key = OPENAI_API_KEY
|
||||
self.llm = LLM(model=model_name, api_key=self.api_key)
|
||||
self.prompt_set = GPTSwarmPromptSet()
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
async def _execute(
|
||||
self, inputs: Optional[List[dict]] = None, max_keywords: int = 4, **kwargs
|
||||
) -> List[dict]:
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
node_inputs = inputs
|
||||
outputs = []
|
||||
for input in node_inputs:
|
||||
task = input['task']
|
||||
query = input['output']
|
||||
prompt = self.prompt_set.get_websearch_prompt(question=task, query=query)
|
||||
messages = [
|
||||
Message(
|
||||
role='system', content=f'You are a {self.prompt_set.get_role()}.'
|
||||
),
|
||||
Message(role='user', content=prompt),
|
||||
]
|
||||
generated_quires = self.llm.completion(
|
||||
messages=[
|
||||
{'role': msg.role, 'content': msg.content} for msg in messages
|
||||
]
|
||||
)
|
||||
|
||||
generated_quires = generated_quires.choices[0].message.content
|
||||
generated_quires = generated_quires.split(',')[:max_keywords]
|
||||
logger.info(f'The search keywords include: {generated_quires}')
|
||||
search_results = [self.web_search(query) for query in generated_quires]
|
||||
logger.info(f'The search results: {str(search_results)[:100]}...')
|
||||
|
||||
distill_prompt = self.prompt_set.get_distill_websearch_prompt(
|
||||
question=input['task'], query=query, results='.\n'.join(search_results)
|
||||
)
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
role='system', content=f'You are a {self.prompt_set.get_role()}.'
|
||||
),
|
||||
Message(role='user', content=distill_prompt),
|
||||
]
|
||||
response = self.llm.completion(
|
||||
messages=[
|
||||
{'role': msg.role, 'content': msg.content} for msg in messages
|
||||
]
|
||||
)
|
||||
response = response.choices[0].message.content
|
||||
|
||||
executions = {
|
||||
'operation': self.node_name,
|
||||
'task': task,
|
||||
'files': input.get('files', []),
|
||||
'input': query,
|
||||
'subtask': distill_prompt,
|
||||
'output': response,
|
||||
'format': 'natural language',
|
||||
}
|
||||
outputs.append(executions)
|
||||
|
||||
return outputs
|
||||
|
||||
def web_search(self, query: str, item_num: int = 3) -> str:
|
||||
url = 'https://www.searchapi.io/api/v1/search'
|
||||
params = {
|
||||
'engine': 'google',
|
||||
'q': query,
|
||||
'api_key': SEARCHAPI_API_KEY, # os.getenv("SEARCHAPI_API_KEY")
|
||||
}
|
||||
|
||||
response = ast.literal_eval(requests.get(url, params=params).text)
|
||||
|
||||
if (
|
||||
'knowledge_graph' in response.keys()
|
||||
and 'description' in response['knowledge_graph'].keys()
|
||||
):
|
||||
return response['knowledge_graph']['description']
|
||||
|
||||
if (
|
||||
'organic_results' in response.keys()
|
||||
and len(response['organic_results']) > 0
|
||||
):
|
||||
snippets = []
|
||||
for res in response['organic_results'][:item_num]:
|
||||
if 'snippet' in res:
|
||||
snippets.append(res['snippet'])
|
||||
return '\n'.join(snippets)
|
||||
|
||||
return ' '
|
||||
|
||||
|
||||
class CombineAnswer(Node):
|
||||
def __init__(
|
||||
self,
|
||||
domain: str = 'gaia',
|
||||
model_name: Optional[str] = 'gpt-4o-2024-05-13',
|
||||
operation_description: str = 'Combine multiple inputs into one.',
|
||||
max_token: int = 500,
|
||||
id=None,
|
||||
):
|
||||
super().__init__(operation_description, id, True)
|
||||
self.domain = domain
|
||||
self.max_token = max_token
|
||||
self.api_key = OPENAI_API_KEY
|
||||
self.llm = LLM(model=model_name, api_key=self.api_key)
|
||||
self.prompt_set = GPTSwarmPromptSet()
|
||||
self.materials: defaultdict[str, str] = defaultdict(str)
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
async def _execute(
|
||||
self, inputs: Optional[List[Any]] = None, **kwargs
|
||||
) -> List[dict]:
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
node_inputs = inputs
|
||||
|
||||
role = self.prompt_set.get_role()
|
||||
constraint = self.prompt_set.get_constraint()
|
||||
|
||||
self.materials = defaultdict(str)
|
||||
for input in node_inputs:
|
||||
operation = input.get('operation')
|
||||
if operation:
|
||||
self.materials[operation] += f'{input.get("output", "")}\n'
|
||||
self.materials['task'] = input.get('task')
|
||||
|
||||
question = self.prompt_set.get_combine_materials(self.materials)
|
||||
prompt = self.prompt_set.get_answer_prompt(question=question)
|
||||
|
||||
messages = [
|
||||
Message(role='system', content=f'You are a {role}. {constraint}'),
|
||||
Message(role='user', content=prompt),
|
||||
]
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=[{'role': msg.role, 'content': msg.content} for msg in messages]
|
||||
)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
executions = {
|
||||
'operation': self.node_name,
|
||||
'task': self.materials['task'],
|
||||
'files': self.materials['files']
|
||||
if isinstance(self.materials['files'], str)
|
||||
else ', '.join(self.materials['files']),
|
||||
'input': node_inputs,
|
||||
'subtask': prompt,
|
||||
'output': response,
|
||||
'format': 'natural language',
|
||||
}
|
||||
|
||||
return [executions]
|
||||
|
||||
|
||||
class AssistantGraph(Graph):
|
||||
def build_graph(self):
|
||||
query = GenerateQuery(self.domain, self.model_name)
|
||||
|
||||
file_analysis = FileAnalyse(self.domain, self.model_name)
|
||||
web_search = WebSearch(self.domain, self.model_name)
|
||||
|
||||
query.add_successor(file_analysis)
|
||||
query.add_successor(web_search)
|
||||
|
||||
combine = CombineAnswer(self.domain, self.model_name)
|
||||
file_analysis.add_successor(combine)
|
||||
web_search.add_successor(combine)
|
||||
|
||||
self.input_nodes = [query]
|
||||
self.output_nodes = [combine]
|
||||
|
||||
self.add_node(query)
|
||||
self.add_node(file_analysis)
|
||||
self.add_node(web_search)
|
||||
self.add_node(combine)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# # test node
|
||||
# task = 'What is the text representation of the last digit of twelve squared?'
|
||||
# inputs = [{'task': task}]
|
||||
# query_instance = GenerateQuery()
|
||||
# query = asyncio.run(query_instance._execute(inputs))
|
||||
# print(query)
|
||||
|
||||
# task = 'What is the text representation of the last digit of twelve squared?'
|
||||
# inputs = [
|
||||
# {
|
||||
# 'task': 'How can researchers ensure AGI development is both safe and ethical while avoiding societal biases and inequalities?',
|
||||
# 'files': ['agi.txt'],
|
||||
# }
|
||||
# ]
|
||||
# file_instance = FileAnalyse()
|
||||
# file_info = asyncio.run(file_instance._execute(inputs))
|
||||
# print(file_info)
|
||||
|
||||
# task = 'What is the text representation of the last digit of twelve squared?'
|
||||
# inputs = [
|
||||
# {
|
||||
# 'task': 'How can researchers ensure AGI development is both safe and ethical while avoiding societal biases and inequalities?'
|
||||
# }
|
||||
# ]
|
||||
# search_instance = WebSearch()
|
||||
# search_info = asyncio.run(search_instance._execute(inputs))
|
||||
# print(search_info)
|
||||
|
||||
assistant_graph = AssistantGraph(domain='gaia', model_name='gpt-4o-2024-05-13')
|
||||
|
||||
# test graph
|
||||
assistant_graph.build_graph()
|
||||
inputs = [
|
||||
{
|
||||
'task': 'How can researchers ensure AGI development is both safe and ethical while avoiding societal biases and inequalities?',
|
||||
'files': ['agi.txt'],
|
||||
}
|
||||
]
|
||||
outputs = asyncio.run(assistant_graph.run(inputs))
|
||||
print(outputs)
|
||||
129
agenthub/gptswarm_agent/prompt.py
Normal file
129
agenthub/gptswarm_agent/prompt.py
Normal file
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class GPTSwarmPromptSet:
|
||||
"""
|
||||
GPTSwarmPromptSet provides a collection of static methods to generate prompts
|
||||
for a general AI assistant. These prompts cover various tasks like answering questions,
|
||||
performing web searches, analyzing files, and reflecting on tasks.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_role():
|
||||
return 'a general AI assistant'
|
||||
|
||||
@staticmethod
|
||||
def get_constraint():
|
||||
return (
|
||||
'I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. '
|
||||
'YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. '
|
||||
"If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
|
||||
"If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
|
||||
'If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. '
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_format():
|
||||
return 'natural language'
|
||||
|
||||
@staticmethod
|
||||
def get_answer_prompt(question):
|
||||
return f'{question}'
|
||||
|
||||
@staticmethod
|
||||
def get_query_prompt(question):
|
||||
return (
|
||||
'# Information Gathering for Question Resolution\n\n'
|
||||
'Evaluate if additional information is needed to answer the question. '
|
||||
'If a web search or file analysis is necessary, outline specific clues or details to be searched for.\n\n'
|
||||
f'## ❓ Target Question:\n{question}\n\n'
|
||||
'## 🔍 Clues for Investigation:\n'
|
||||
'Identify critical clues and concepts within the question that are essential for finding the answer.\n'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_file_analysis_prompt(query, file):
|
||||
return (
|
||||
'# File Analysis Task\n\n'
|
||||
f'## 🔍 Information Extraction Objective:\n---\n{query}\n---\n\n'
|
||||
f'## 📄 File Under Analysis:\n---\n{file}\n---\n\n'
|
||||
'## 📝 Instructions:\n'
|
||||
'1. Identify the key sections in the file relevant to the query.\n'
|
||||
'2. Extract and summarize the necessary information from these sections.\n'
|
||||
'3. Ensure the response is focused and directly addresses the query.\n'
|
||||
"Example: 'Identify the main theme in the text.'"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_websearch_prompt(question, query):
|
||||
return (
|
||||
'# Web Search Task\n\n'
|
||||
f'## Original Question: \n---\n{question}\n---\n\n'
|
||||
f'## 🔍 Targeted Search Objective:\n---\n{query}\n---\n\n'
|
||||
'## 🌐 Simplified Search Instructions:\n'
|
||||
'Generate three specific search queries directly related to the original question. Each query should focus on key terms from the question. Format the output as a comma-separated list.\n'
|
||||
"For example, if the question is 'Who will be the next US president?', your queries could be: 'US presidential candidates, current US president, next US president'.\n"
|
||||
"Remember to format the queries as 'query1, query2, query3'."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_distill_websearch_prompt(question, query, results):
|
||||
return (
|
||||
'# Summarization of Search Results\n\n'
|
||||
f'## Original question: \n---\n{question}\n---\n\n'
|
||||
f'## 🔍 Required Information for Summary:\n---\n{query}\n---\n\n'
|
||||
f'## 🌐 Analyzed Search Results:\n---\n{results}\n---\n\n'
|
||||
'## 📝 Instructions for Summarization:\n'
|
||||
'1. Review the provided search results and identify the most relevant information related to the question and query.\n'
|
||||
'2. Extract and highlight the key findings, facts, or data points from these results.\n'
|
||||
'3. Organize the summarized information in a coherent and logical manner.\n'
|
||||
'4. Ensure the summary is concise and directly addresses the query, avoiding extraneous details.\n'
|
||||
'5. If the information from web search is useless, directly answer: "No useful information from WebSearch".\n'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_combine_materials(materials: Dict[str, Any], avoid_vague=True) -> str:
|
||||
question = materials.get('task', 'No problem provided')
|
||||
|
||||
for key, value in materials.items():
|
||||
if 'No useful information from WebSearch' in value:
|
||||
continue
|
||||
value = value.strip('\n').strip()
|
||||
if key != 'task' and value:
|
||||
question += (
|
||||
f'\n\nReference information for {key}:'
|
||||
+ '\n----------------------------------------------\n'
|
||||
+ f'{value}'
|
||||
+ '\n----------------------------------------------\n\n'
|
||||
)
|
||||
|
||||
if avoid_vague:
|
||||
question += (
|
||||
'\nProvide a specific answer. For questions with known answers, ensure to provide accurate and factual responses. '
|
||||
+ "Avoid vague responses or statements like 'unable to...' that don't contribute to a definitive answer. "
|
||||
+ "For example: if a question asks 'who will be the president of America', and the answer is currently unknown, you could suggest possibilities like 'Donald Trump', or 'Biden'. However, if the answer is known, provide the correct information."
|
||||
)
|
||||
|
||||
return question
|
||||
|
||||
@staticmethod
|
||||
def get_self_consistency(question: str, answers: list, constraint: str) -> str:
|
||||
formatted_answers = '\n'.join(
|
||||
[f'Answer {index + 1}: {answer}' for index, answer in enumerate(answers)]
|
||||
)
|
||||
return (
|
||||
'# Self-Consistency Evaluation Task\n\n'
|
||||
f'## 🤔 Question for Review:\n---\n{question}\n---\n\n'
|
||||
f'## 💡 Reviewable Answers:\n---\n{formatted_answers}\n---\n\n'
|
||||
'## 📋 Instructions for Selection:\n'
|
||||
'1. Read each answer and assess how it addresses the question.\n'
|
||||
"2. Compare the answers for their adherence to the given question's criteria and logical coherence.\n"
|
||||
"3. Identify the answer that best aligns with the question's requirements and is the most logically consistent.\n"
|
||||
"4. Ignore the candidate answers if they do not give a direct answer, for example, using 'unable to ...', 'as an AI ...'.\n"
|
||||
'5. Copy the most suitable answer as it is, without modification, to maintain its original form.\n'
|
||||
f'6. Adhere to the constraints: {constraint}.\n'
|
||||
'Note: If no answer fully meets the criteria, choose and copy the one that is closest to the requirements.'
|
||||
)
|
||||
@ -10,6 +10,7 @@ import huggingface_hub
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
|
||||
from agenthub.gptswarm_agent.gptswarm_agent import GPTSwarm
|
||||
from evaluation.gaia.scorer import question_scorer
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
@ -31,10 +32,23 @@ from opendevin.llm.llm import LLM
|
||||
DATASET_CACHE_DIR = '~/.cache/open-devin/evals/gaia'
|
||||
DATASET_CACHE_DIR = os.path.expanduser(DATASET_CACHE_DIR)
|
||||
|
||||
HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN') or (
|
||||
open(os.path.expanduser('~/.huggingface/token')).read().strip()
|
||||
if os.path.exists(os.path.expanduser('~/.huggingface/token'))
|
||||
else input('Please enter your Hugging Face token: ').strip()
|
||||
)
|
||||
|
||||
|
||||
def gptswarm_user_response(state: State) -> str:
|
||||
# NOTE: For the AI assistant, state-based design may introduce more uncertainties.
|
||||
# TODO: It is stateless now. Find a way to make it stateful.
|
||||
print('Not implemented.')
|
||||
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': partial(codeact_user_response, encapsulate_solution=True),
|
||||
'MonologueAgent': monologue_user_response,
|
||||
'GPTSwarmAgent': gptswarm_user_response,
|
||||
}
|
||||
|
||||
AGENT_CLS_TO_INST_SUFFIX = {
|
||||
@ -46,6 +60,7 @@ def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
single_agent: bool = False,
|
||||
):
|
||||
# Create the agent
|
||||
agent = Agent.get_cls(metadata.agent_class)(llm=LLM(llm_config=metadata.llm_config))
|
||||
@ -170,17 +185,102 @@ def process_instance(
|
||||
# remove when it becomes unnecessary
|
||||
histories = state.history.compatibility_for_eval_history_pairs()
|
||||
|
||||
# Prepare instruction
|
||||
instruction = f"{instance['Question']}\n"
|
||||
logger.info(f'Instruction: {instruction}')
|
||||
if dest_file:
|
||||
instruction += f"\n\nThe mentioned file is provided in the workspace at: {dest_file.split('/')[-1]}"
|
||||
|
||||
# TODO: Need further improve for new V1.1 version and drop if-else.
|
||||
if agent.__class__.__name__ == 'GPTSwarmAgent':
|
||||
if dest_file:
|
||||
inputs = [{'task': instruction, 'files': [dest_file]}]
|
||||
else:
|
||||
inputs = [{'task': instruction}]
|
||||
|
||||
model_name = metadata['model_name']
|
||||
gptswarm_agent = GPTSwarm(llm=LLM(), model_name=model_name)
|
||||
if single_agent:
|
||||
model_answer_raw = asyncio.run(gptswarm_agent.run(inputs))
|
||||
else:
|
||||
model_answer_raw = asyncio.run(gptswarm_agent.swarm_run(inputs))
|
||||
|
||||
model_answer = model_answer_raw[-1].split('FINAL ANSWER: ')[-1]
|
||||
|
||||
else:
|
||||
instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
|
||||
instruction += 'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
|
||||
instruction += 'For example: The answer to the question is <solution> 42 </solution>.\n'
|
||||
# NOTE: You can actually set slightly different instruction for different agents
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent.__class__.__name__, '')
|
||||
logger.info(
|
||||
f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'}
|
||||
)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = asyncio.run(
|
||||
run_agent_controller(
|
||||
agent,
|
||||
instruction,
|
||||
max_iterations=metadata.max_iterations,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
agent.__class__.__name__
|
||||
],
|
||||
sid=instance['text'].strip(),
|
||||
)
|
||||
)
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
# If you are working on simplier benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
# You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
|
||||
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
model_answer_raw = ''
|
||||
for act, _ in reversed(state.history):
|
||||
if isinstance(act, CmdRunAction) and act.source == 'agent':
|
||||
model_answer_raw = act.thought
|
||||
break
|
||||
elif isinstance(act, MessageAction) and act.source == 'agent':
|
||||
model_answer_raw = act.content
|
||||
break
|
||||
|
||||
# attempt to parse model_answer
|
||||
model_answer = re.findall(r'<solution>(.*?)</solution>', model_answer_raw)
|
||||
if len(model_answer) == 0:
|
||||
logger.warning(f'Failed to parse model answer: {model_answer_raw}')
|
||||
model_answer = model_answer_raw
|
||||
else:
|
||||
model_answer = model_answer[0]
|
||||
|
||||
logger.info(
|
||||
f'Final message: {model_answer} | Ground truth: {instance["Final answer"]}'
|
||||
)
|
||||
score = question_scorer(
|
||||
model_answer=model_answer, ground_truth=instance['Final answer']
|
||||
)
|
||||
test_result = {
|
||||
'score': score,
|
||||
'model_answer_raw': model_answer_raw,
|
||||
'model_answer': model_answer,
|
||||
'ground_truth': instance['Final answer'],
|
||||
}
|
||||
|
||||
# Save the output
|
||||
output = {
|
||||
'instance_id': instance['task_id'],
|
||||
'instance': instance,
|
||||
'instruction': instance['Question'],
|
||||
'metadata': metadata.model_dump(),
|
||||
'metadata': metadata,
|
||||
'history': histories,
|
||||
# [
|
||||
# (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
# ],
|
||||
'error': state.error if state and state.error else None,
|
||||
'metrics': metrics,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
|
||||
except Exception:
|
||||
logger.error('Process instance failed')
|
||||
raise
|
||||
|
||||
@ -17,8 +17,8 @@ fi
|
||||
checkout_eval_branch
|
||||
|
||||
if [ -z "$AGENT" ]; then
|
||||
echo "Agent not specified, use default CodeActAgent"
|
||||
AGENT="CodeActAgent"
|
||||
echo "Agent not specified, use default GPTSwarmAgent"
|
||||
AGENT="GPTSwarmAgent"
|
||||
fi
|
||||
|
||||
get_agent_version
|
||||
@ -38,12 +38,8 @@ echo "LEVELS: $LEVELS"
|
||||
COMMAND="poetry run python ./evaluation/gaia/run_infer.py \
|
||||
--agent-cls $AGENT \
|
||||
--llm-config $MODEL_CONFIG \
|
||||
--max-iterations 30 \
|
||||
--level $LEVELS \
|
||||
--data-split validation \
|
||||
--max-chars 10000000 \
|
||||
--eval-num-workers $NUM_WORKERS \
|
||||
--eval-note ${AGENT_VERSION}_${LEVELS}"
|
||||
--data-split validation"
|
||||
|
||||
if [ -n "$EVAL_LIMIT" ]; then
|
||||
echo "EVAL_LIMIT: $EVAL_LIMIT"
|
||||
|
||||
50
evaluation/gaia/scripts/run_infer_multiple.sh
Executable file
50
evaluation/gaia/scripts/run_infer_multiple.sh
Executable file
@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
MODEL_CONFIG=$1
|
||||
AGENT=$2
|
||||
EVAL_LIMIT=$3
|
||||
LEVELS=$4
|
||||
NUM_RUNS=5
|
||||
OUTPUT_BASE_DIR="/Users/zhugem/Desktop/OpenDevin/evaluation/evaluation_outputs/outputs/gaia"
|
||||
|
||||
if [ -z "$AGENT" ]; then
|
||||
echo "Agent not specified, use default GPTSwarmAgent"
|
||||
AGENT="GPTSwarmAgent"
|
||||
fi
|
||||
|
||||
if [ -z "$LEVELS" ]; then
|
||||
LEVELS="2023_level1"
|
||||
echo "Levels not specified, use default $LEVELS"
|
||||
fi
|
||||
|
||||
# IMPORTANT: Because Agent's prompt changes fairly often in the rapidly evolving codebase of OpenDevin
|
||||
# We need to track the version of Agent in the evaluation to make sure results are comparable
|
||||
AGENT_VERSION=v$(poetry run python -c "import agenthub; from opendevin.controller.agent import Agent; print(Agent.get_cls('$AGENT').VERSION)")
|
||||
|
||||
echo "AGENT: $AGENT"
|
||||
echo "AGENT_VERSION: $AGENT_VERSION"
|
||||
echo "MODEL_CONFIG: $MODEL_CONFIG"
|
||||
echo "LEVELS: $LEVELS"
|
||||
|
||||
for i in $(seq 1 $NUM_RUNS)
|
||||
do
|
||||
RANDOM_SUFFIX=$(date +%s%N)
|
||||
OUTPUT_DIR="${OUTPUT_BASE_DIR}/${AGENT}/${MODEL_CONFIG}-${RANDOM_SUFFIX}"
|
||||
#OUTPUT_DIR="${OUTPUT_BASE_DIR}/${AGENT}/${MODEL_CONFIG}"
|
||||
echo "Running iteration $i, output will be stored in $OUTPUT_DIR"
|
||||
|
||||
COMMAND="poetry run python ./evaluation/gaia/run_infer.py \
|
||||
--agent-cls $AGENT \
|
||||
--llm-config $MODEL_CONFIG \
|
||||
--level $LEVELS \
|
||||
--data-split validation \
|
||||
--eval-output-dir $OUTPUT_DIR"
|
||||
|
||||
if [ -n "$EVAL_LIMIT" ]; then
|
||||
echo "EVAL_LIMIT: $EVAL_LIMIT"
|
||||
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
|
||||
fi
|
||||
|
||||
# Run the command
|
||||
eval $COMMAND
|
||||
done
|
||||
@ -27,6 +27,7 @@ import tempfile
|
||||
from inspect import signature
|
||||
from typing import Optional
|
||||
|
||||
import charset_normalizer
|
||||
import docx
|
||||
import PyPDF2
|
||||
from openai import OpenAI
|
||||
@ -43,6 +44,12 @@ ENABLE_AUTO_LINT = os.getenv('ENABLE_AUTO_LINT', 'false').lower() == 'true'
|
||||
MSG_FILE_UPDATED = '[File updated (edited at line {line_number}). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]'
|
||||
|
||||
# OPENAI
|
||||
"""
|
||||
https://github.com/OpenDevin/OpenDevin/pull/2052
|
||||
NOTE: GPTSwarm need manually export...
|
||||
TODO: Fix it
|
||||
export SANDBOX_ENV_OPENAI_API_KEY="sk-***"
|
||||
"""
|
||||
OPENAI_API_KEY = os.getenv(
|
||||
'OPENAI_API_KEY', os.getenv('SANDBOX_ENV_OPENAI_API_KEY', '')
|
||||
)
|
||||
@ -1011,18 +1018,18 @@ def parse_image(
|
||||
"""
|
||||
print(f'[Reading image file from {file_path}]')
|
||||
# TODO: record the COST of the API call
|
||||
try:
|
||||
base64_image = _base64_img(file_path)
|
||||
response = client.chat.completions.create(
|
||||
model=OPENAI_MODEL,
|
||||
messages=_prepare_image_messages(task, base64_image),
|
||||
max_tokens=MAX_TOKEN,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
print(content)
|
||||
# try:
|
||||
base64_image = _base64_img(file_path)
|
||||
response = client.chat.completions.create(
|
||||
model=OPENAI_MODEL,
|
||||
messages=_prepare_image_messages(task, base64_image),
|
||||
max_tokens=MAX_TOKEN,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
print(content)
|
||||
|
||||
except Exception as error:
|
||||
print(f'Error with the request: {error}')
|
||||
# except Exception as error:
|
||||
# print(f'Error with the request: {error}')
|
||||
|
||||
|
||||
@update_pwd_decorator
|
||||
@ -1097,6 +1104,23 @@ def parse_pptx(file_path: str) -> None:
|
||||
print(f'Error reading PowerPoint file: {e}')
|
||||
|
||||
|
||||
def parse_txt(file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Parses the content of a txt file and prints it.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the file to open.
|
||||
"""
|
||||
print(f'[Reading TXT file from {file_path}]')
|
||||
try:
|
||||
content = charset_normalizer.from_path(str(file_path)).best()
|
||||
return str(content)
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error reading TXT file: {e}')
|
||||
return ''
|
||||
|
||||
|
||||
__all__ = [
|
||||
# file operation
|
||||
'open_file',
|
||||
@ -1115,6 +1139,7 @@ __all__ = [
|
||||
'parse_docx',
|
||||
'parse_latex',
|
||||
'parse_pptx',
|
||||
'parse_txt',
|
||||
]
|
||||
|
||||
if OPENAI_API_KEY and OPENAI_BASE_URL:
|
||||
|
||||
1596
poetry.lock
generated
1596
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -38,6 +38,10 @@ zope-interface = "6.4.post2"
|
||||
pathspec = "^0.12.1"
|
||||
google-cloud-aiplatform = "*"
|
||||
|
||||
[tool.poetry.group.gptswarm.dependencies]
|
||||
serpapi = "0.1.5"
|
||||
gptswarm = "0.1.0"
|
||||
|
||||
[tool.poetry.group.llama-index.dependencies]
|
||||
llama-index = "*"
|
||||
llama-index-vector-stores-chroma = "*"
|
||||
@ -95,4 +99,4 @@ ignore = [ "E501" ]
|
||||
|
||||
[tool.black]
|
||||
# prevent black (if installed) from changing single quotes to double quotes
|
||||
skip-string-normalization = true
|
||||
skip-string-normalization = true
|
||||
Loading…
x
Reference in New Issue
Block a user