Spaces:
Runtime error
Runtime error
from transformers.integrations import TensorBoardCallback | |
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM | |
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq | |
from transformers import TrainerCallback, TrainerState, TrainerControl | |
from transformers.trainer import TRAINING_ARGS_NAME | |
from torch.utils.tensorboard import SummaryWriter | |
import datasets | |
import torch | |
import os | |
import re | |
import sys | |
import wandb | |
import argparse | |
from datetime import datetime | |
from functools import partial | |
from tqdm import tqdm | |
from utils import * | |
# LoRA | |
from peft import ( | |
TaskType, | |
LoraConfig, | |
get_peft_model, | |
get_peft_model_state_dict, | |
prepare_model_for_int8_training, | |
set_peft_model_state_dict, | |
) | |
# Weights & Biases configuration.
# The API key is NOT stored in source: export WANDB_API_KEY (or run `wandb login`)
# before launching. NOTE(review): a key was previously hard-coded here; treat it
# as leaked and rotate it.
# The project name below is only a default; an externally set WANDB_PROJECT wins.
os.environ.setdefault('WANDB_PROJECT', 'fingpt-forecaster')
class GenerationEvalCallback(TrainerCallback):
    """Score free-form generations on a held-out sample at each evaluation.

    On every ``on_evaluate`` event (once training has progressed far enough),
    the global main process generates an answer for each record of
    ``eval_dataset``, scores the answers against the references with
    ``calc_metrics`` and logs the resulting metrics to Weights & Biases.

    Args:
        eval_dataset: Iterable of records exposing raw-text ``'prompt'`` and
            ``'answer'`` fields.
        ignore_until_epoch: Generation is skipped while ``state.epoch + 1`` is
            below this value, so early evaluations don't pay the generation cost.
    """

    # Greedy match up to the last "[/INST]" tag plus trailing whitespace;
    # removing it leaves only the reply when the decoded output echoes an
    # [/INST]-delimited prompt.
    _PROMPT_PREFIX_RE = re.compile(r'.*\[/INST\]\s*', flags=re.DOTALL)

    def __init__(self, eval_dataset, ignore_until_epoch=0):
        self.eval_dataset = eval_dataset
        self.ignore_until_epoch = ignore_until_epoch

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Generate, score and log answers for ``eval_dataset``.

        Runs only on the global main process so that, in multi-node runs,
        exactly one process generates and writes to wandb.

        Raises:
            ValueError: if the Trainer was not given a tokenizer.
        """
        if state.epoch is None or state.epoch + 1 < self.ignore_until_epoch:
            return
        if not state.is_world_process_zero:
            return

        model = kwargs['model']
        # Older transformers releases pass the tokenizer to callbacks as
        # `tokenizer`; newer ones pass it as `processing_class`.
        tokenizer = kwargs.get('tokenizer')
        if tokenizer is None:
            tokenizer = kwargs.get('processing_class')
        if tokenizer is None:
            raise ValueError(
                'GenerationEvalCallback requires the Trainer to be given a tokenizer'
            )

        generated_texts, reference_texts = [], []
        for feature in tqdm(self.eval_dataset):
            # truncation=True is the strategy the tokenizer falls back to anyway
            # when only max_length is given; stating it avoids a warning.
            inputs = tokenizer(
                feature['prompt'], return_tensors='pt',
                padding=False, truncation=True, max_length=4096
            )
            inputs = {key: value.to(model.device) for key, value in inputs.items()}
            res = model.generate(**inputs, use_cache=True)
            output = tokenizer.decode(res[0], skip_special_tokens=True)
            generated_texts.append(self._PROMPT_PREFIX_RE.sub('', output))
            reference_texts.append(feature['answer'])

        metrics = calc_metrics(reference_texts, generated_texts)

        # Start a wandb run only if none is active yet.
        if wandb.run is None:
            wandb.init()
        wandb.log(metrics, step=state.global_step)
        torch.cuda.empty_cache()
def _build_datasets(args, tokenizer):
    """Load, tokenize and length-filter the train/test splits.

    Returns:
        ``(dataset, eval_dataset)``: ``dataset`` is a tokenized ``DatasetDict``
        with 'train' and 'test' splits (rows flagged ``exceed_max_length``
        removed, raw text columns dropped); ``eval_dataset`` is a raw-text
        sample of at most 50 shuffled test rows for generation-based evaluation.
    """
    dataset_list = load_dataset(args.dataset, args.from_remote)
    dataset_train = datasets.concatenate_datasets(
        [d['train'] for d in dataset_list]
    ).shuffle(seed=42)

    # Without --test_dataset, the test split comes from the training datasets.
    if args.test_dataset:
        dataset_list = load_dataset(args.test_dataset, args.from_remote)
    dataset_test = datasets.concatenate_datasets([d['test'] for d in dataset_list])

    original_dataset = datasets.DatasetDict({'train': dataset_train, 'test': dataset_test})

    # Cap the sample size so a small test split doesn't make `select` raise.
    shuffled_test = original_dataset['test'].shuffle(seed=42)
    eval_dataset = shuffled_test.select(range(min(50, len(shuffled_test))))

    dataset = original_dataset.map(partial(tokenize, args, tokenizer))
    print('original dataset length: ', len(dataset['train']))
    dataset = dataset.filter(lambda x: not x['exceed_max_length'])
    print('filtered dataset length: ', len(dataset['train']))
    dataset = dataset.remove_columns(
        ['prompt', 'answer', 'label', 'symbol', 'period', 'exceed_max_length']
    )
    return dataset, eval_dataset


def main(args):
    """Fine-tune ``args.base_model`` with a LoRA adapter and save the adapter.

    Loads the base model and tokenizer, prepares tokenized train/test splits,
    wraps the model with LoRA (r=8, alpha=16, dropout=0.1), trains it with the
    Hugging Face ``Trainer`` plus generation-based evaluation via
    ``GenerationEvalCallback``, and writes the adapter to
    ``finetuned_models/<run_name>_<YYYYmmddHHMM>``.

    Args:
        args: Parsed command-line namespace built by the ``argparse`` setup at
            the bottom of this script.
    """
    model_name = parse_model_name(args.base_model, args.from_remote)

    # Load the base model. For 8-bit training, pass load_in_8bit=True here and
    # wrap the model with prepare_model_for_int8_training() before adding LoRA.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    if args.local_rank == 0:
        print(model)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Reuse EOS as the padding token; pad on the right for training batches.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    dataset, eval_dataset = _build_datasets(args, tokenizer)

    formatted_time = datetime.now().strftime('%Y%m%d%H%M')
    training_args = TrainingArguments(
        # Checkpoints and the final adapter are written here.
        output_dir=f'finetuned_models/{args.run_name}_{formatted_time}',
        logging_steps=args.log_interval,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        dataloader_num_workers=args.num_workers,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        save_steps=args.eval_steps,
        eval_steps=args.eval_steps,
        fp16=True,
        deepspeed=args.ds_config,
        evaluation_strategy=args.evaluation_strategy,
        remove_unused_columns=False,
        report_to='wandb',
        run_name=args.run_name,
    )

    # Gradient checkpointing trades compute for memory; enable_input_require_grads
    # lets gradients reach checkpointed layers while the base weights are frozen.
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    model.is_parallelizable = True
    model.model_parallel = True
    # Disable the KV cache for training. Set it on the top-level config: the
    # backbone attribute differs per architecture (`.model` for Llama,
    # `.transformer` for ChatGLM2), so `model.model.config` fails for ChatGLM2.
    model.config.use_cache = False

    # Wrap the base model with a LoRA adapter on the modules listed for it.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=lora_module_dict[args.base_model],
        bias='none',
    )
    model = get_peft_model(model, peft_config)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, padding=True,
            return_tensors="pt"
        ),
        callbacks=[
            GenerationEvalCallback(
                eval_dataset=eval_dataset,
                ignore_until_epoch=round(0.3 * args.num_epochs)
            )
        ]
    )

    torch.cuda.empty_cache()
    trainer.train()

    # Save the LoRA adapter from a single process; every rank executes this
    # script, and concurrent writes to the same directory can clash.
    if trainer.is_world_process_zero():
        model.save_pretrained(training_args.output_dir)
def parse_bool_arg(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--from_remote False`` would enable remote loading; this converter does not.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' maps to False, matching what bool('') produced before.
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--run_name", default='local-test', type=str)
    parser.add_argument("--dataset", required=True, type=str)
    parser.add_argument("--test_dataset", type=str)
    parser.add_argument("--base_model", required=True, type=str, choices=['chatglm2', 'llama2'])
    parser.add_argument("--max_length", default=512, type=int)
    parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="The learning rate")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay")
    parser.add_argument("--num_epochs", default=8, type=float, help="The training epochs")
    parser.add_argument("--num_workers", default=8, type=int, help="dataloader workers")
    parser.add_argument("--log_interval", default=20, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=8, type=int)
    parser.add_argument("--warmup_ratio", default=0.05, type=float)
    parser.add_argument("--ds_config", default='./config_new.json', type=str)
    parser.add_argument("--scheduler", default='linear', type=str)
    parser.add_argument("--instruct_template", default='default')
    parser.add_argument("--evaluation_strategy", default='steps', type=str)
    # A value < 1 is passed through to TrainingArguments as-is (see main()).
    parser.add_argument("--eval_steps", default=0.1, type=float)
    parser.add_argument("--from_remote", default=False, type=parse_bool_arg)
    args = parser.parse_args()

    wandb.login()
    main(args)