fix: minor

Ziyu 2025-03-25 00:15:12 -05:00
parent 562e97bb79
commit 46facf5950
4 changed files with 114 additions and 5 deletions

.gitignore (vendored)

@@ -47,6 +47,12 @@ owl/results
logs/
log/
+# Local files
+**/local_*
+**/local-*
+**/*_local*
+**/*-local*
# Coverage reports
htmlcov/
.tox/

@@ -1,10 +1,10 @@
-from datasets import Dataset
+from datasets import Dataset, DatasetDict
from tqdm import tqdm
import json
import random
import argparse
import pandas as pd
import datasets


def construct_system_message(task_prompt):
@@ -43,7 +43,6 @@ Please note that our overall task may be very complicated. Here are some tips th
"""
def process_and_push(input_file, repo_id):
r"""Process dataset and push to Hugging Face Hub."""
@@ -90,14 +89,20 @@ def process_and_push(input_file, repo_id):
df["messages"] = df.apply(
lambda row: build_messages(row["history"], row["question"]), axis=1)
df["messages_camel"] = df["messages"]
test_df = df.sample(n=min(5, len(df)), random_state=42)
train_df = df.drop(test_df.index)
hf_dataset = DatasetDict({
"train": Dataset.from_pandas(train_df),
"test": Dataset.from_pandas(test_df),
})
hf_dataset = Dataset.from_pandas(df)
hf_dataset.push_to_hub(repo_id)
print(f"Dataset pushed to: https://huggingface.co/datasets/{repo_id}")

def main():
    """Parse arguments and execute the processing."""
    parser = argparse.ArgumentParser(description="Process and push dataset.")
    parser.add_argument("--input_file", "-i", type=str, required=True,
                        help="Path to the input JSON file")
@@ -105,6 +110,7 @@ def main():
help="Hugging Face Hub repository ID")
args = parser.parse_args()
random.seed(42)
process_and_push(args.input_file, args.repo_id)
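
With this change, the pushed repository carries explicit train and test splits instead of a single unsplit dataset. A minimal consumption sketch (the repo ID below is a placeholder, not one taken from this commit):

from datasets import load_dataset

# Load the DatasetDict pushed by process_and_push. The test split was drawn
# with df.sample(n=min(5, len(df)), random_state=42), so it holds at most
# five rows; "train" holds the remainder.
ds = load_dataset("your-org/owl-trajectories")  # placeholder repo ID
print(ds["train"].num_rows, ds["test"].num_rows)
print(ds["train"][0]["messages"])  # chat-format messages built by build_messages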


@@ -5,6 +5,7 @@
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
model=${1:-Qwen/Qwen2.5-VL-7B-Instruct}
+# model=${1:-Qwen/Qwen2.5-7B-Instruct}
max_model_len=${2:-32768}
gpu_memory_utilization=${3:-0.9}
devices=${4:-"0,1,2,3"}

owl/scripts/sft.py (new file)

@@ -0,0 +1,96 @@
"""
Supervised Fine-Tuning (SFT) with owl agents' trajectories.
"""
import argparse
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from trl import (
ModelConfig,
ScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
def main(script_args, training_args, model_args):
################
# Model init kwargs & Tokenizer
################
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
# Create model
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs
)
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
dataset = load_dataset(
script_args.dataset_name,
name=script_args.dataset_config
)
################
# Training
################
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
)
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
def make_parser(subparsers: argparse._SubParsersAction = None):
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
if subparsers is not None:
parser = subparsers.add_parser(
"sft",
help="Run the SFT training script",
dataclass_types=dataclass_types
)
else:
parser = TrlParser(dataclass_types)
return parser
if __name__ == "__main__":
parser = make_parser()
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
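
Since the accepted flags come from the three dataclasses handed to TrlParser (ScriptArguments, SFTConfig, ModelConfig), an invocation can be sanity-checked without launching training. A minimal sketch; the model and dataset IDs are placeholders:

from trl import ModelConfig, ScriptArguments, SFTConfig, TrlParser

# Parse a hypothetical command line against the same dataclasses the script
# registers; these are the minimal flags a real run would need.
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config([
    "--model_name_or_path", "Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    "--dataset_name", "your-org/owl-trajectories",       # placeholder repo ID
    "--output_dir", "./checkpoints/sft",
])
print(script_args.dataset_name, training_args.output_dir)

The same flags, passed to owl/scripts/sft.py on the command line (optionally under accelerate launch), drive an actual run.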