mirror of
https://github.com/camel-ai/owl.git
synced 2025-12-26 10:07:51 +08:00
fix: minor
This commit is contained in:
parent
562e97bb79
commit
46facf5950
6
.gitignore
vendored
6
.gitignore
vendored
@ -47,6 +47,12 @@ owl/results
|
||||
logs/
|
||||
log/
|
||||
|
||||
# Local files
|
||||
**/local_*
|
||||
**/local-*
|
||||
**/*_local*
|
||||
**/*-local*
|
||||
|
||||
# Coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, DatasetDict
|
||||
from tqdm import tqdm
|
||||
|
||||
import json
|
||||
import random
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import datasets
|
||||
|
||||
|
||||
def construct_system_message(task_prompt):
|
||||
@ -43,7 +43,6 @@ Please note that our overall task may be very complicated. Here are some tips th
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def process_and_push(input_file, repo_id):
|
||||
r"""Process dataset and push to Hugging Face Hub."""
|
||||
|
||||
@ -90,14 +89,20 @@ def process_and_push(input_file, repo_id):
|
||||
df["messages"] = df.apply(
|
||||
lambda row: build_messages(row["history"], row["question"]), axis=1)
|
||||
df["messages_camel"] = df["messages"]
|
||||
|
||||
test_df = df.sample(n=min(5, len(df)), random_state=42)
|
||||
train_df = df.drop(test_df.index)
|
||||
|
||||
hf_dataset = DatasetDict({
|
||||
"train": Dataset.from_pandas(train_df),
|
||||
"test": Dataset.from_pandas(test_df),
|
||||
})
|
||||
|
||||
hf_dataset = Dataset.from_pandas(df)
|
||||
hf_dataset.push_to_hub(repo_id)
|
||||
print(f"Dataset pushed to: https://huggingface.co/datasets/{repo_id}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Parse arguments and execute the processing."""
|
||||
parser = argparse.ArgumentParser(description="Process and push dataset.")
|
||||
parser.add_argument("--input_file", "-i", type=str, required=True,
|
||||
help="Path to the input JSON file")
|
||||
@ -105,6 +110,7 @@ def main():
|
||||
help="Hugging Face Hub repository ID")
|
||||
args = parser.parse_args()
|
||||
|
||||
random.seed(42)
|
||||
process_and_push(args.input_file, args.repo_id)
|
||||
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
|
||||
|
||||
model=${1:-Qwen/Qwen2.5-VL-7B-Instruct}
|
||||
# model=${1:-Qwen/Qwen2.5-7B-Instruct}
|
||||
max_model_len=${2:-32768}
|
||||
gpu_memory_utilization=${3:-0.9}
|
||||
devices=${4:-"0,1,2,3"}
|
||||
|
||||
96
owl/scripts/sft.py
Normal file
96
owl/scripts/sft.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""
|
||||
Supervised Fine-Tuning (SFT) with owl agents' trajectories.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
################
|
||||
# Model init kwargs & Tokenizer
|
||||
################
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=model_args.torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
# Create model
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs
|
||||
)
|
||||
|
||||
# Create tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(
|
||||
script_args.dataset_name,
|
||||
name=script_args.dataset_config
|
||||
)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
processing_class=tokenizer,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser(
|
||||
"sft",
|
||||
help="Run the SFT training script",
|
||||
dataclass_types=dataclass_types
|
||||
)
|
||||
else:
|
||||
parser = TrlParser(dataclass_types)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = make_parser()
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
||||
Loading…
x
Reference in New Issue
Block a user