from transformers.integrations import TensorBoardCallback
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers.trainer import TRAINING_ARGS_NAME
from torch.utils.tensorboard import SummaryWriter
import datasets
import torch
import os
import re
import sys
import wandb
import argparse
from datetime import datetime
from functools import partial
from tqdm import tqdm
from utils import *

# LoRA
from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,  # deprecated in newer peft; prepare_model_for_kbit_training is the replacement
    set_peft_model_state_dict,
)

# Replace with your own api_key and project name
os.environ['WANDB_API_KEY'] = 'ecf1e5e4f47441d46822d38a3249d62e8fc94db4'
os.environ['WANDB_PROJECT'] = 'fingpt-forecaster'


class GenerationEvalCallback(TrainerCallback):
    """Generates answers for a held-out set at evaluation time and logs text metrics to wandb."""

    def __init__(self, eval_dataset, ignore_until_epoch=0):
        self.eval_dataset = eval_dataset
        self.ignore_until_epoch = ignore_until_epoch

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        # Skip the (slow) generation-based evaluation during the first epochs
        if state.epoch is None or state.epoch + 1 < self.ignore_until_epoch:
            return

        if state.is_local_process_zero:
            model = kwargs['model']
            tokenizer = kwargs['tokenizer']
            generated_texts, reference_texts = [], []

            for feature in tqdm(self.eval_dataset):
                prompt = feature['prompt']
                gt = feature['answer']
                inputs = tokenizer(
                    prompt, return_tensors='pt',
                    padding=False, max_length=4096, truncation=True
                )
                inputs = {key: value.to(model.device) for key, value in inputs.items()}

                res = model.generate(
                    **inputs,
                    use_cache=True
                )
                output = tokenizer.decode(res[0], skip_special_tokens=True)
                # Keep only the text generated after the instruction delimiter
                answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
                generated_texts.append(answer)
                reference_texts.append(gt)

                # print("GENERATED: ", answer)
                # print("REFERENCE: ", gt)

            metrics = calc_metrics(reference_texts, generated_texts)

            # Ensure wandb is initialized
            if wandb.run is None:
                wandb.init()

            wandb.log(metrics, step=state.global_step)
            torch.cuda.empty_cache()


def main(args):

    model_name = parse_model_name(args.base_model, args.from_remote)

    # load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # load_in_8bit=True,
        trust_remote_code=True
    )
    if args.local_rank == 0:
        print(model)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # load data
    dataset_list = load_dataset(args.dataset, args.from_remote)
    dataset_train = datasets.concatenate_datasets([d['train'] for d in dataset_list]).shuffle(seed=42)

    if args.test_dataset:
        dataset_list = load_dataset(args.test_dataset, args.from_remote)
    dataset_test = datasets.concatenate_datasets([d['test'] for d in dataset_list])

    original_dataset = datasets.DatasetDict({'train': dataset_train, 'test': dataset_test})
    eval_dataset = original_dataset['test'].shuffle(seed=42).select(range(50))

    dataset = original_dataset.map(partial(tokenize, args, tokenizer))
    print('original dataset length: ', len(dataset['train']))
    dataset = dataset.filter(lambda x: not x['exceed_max_length'])
    print('filtered dataset length: ', len(dataset['train']))
    dataset = dataset.remove_columns(
        ['prompt', 'answer', 'label', 'symbol', 'period', 'exceed_max_length']
    )

    current_time = datetime.now()
    formatted_time = current_time.strftime('%Y%m%d%H%M')

    training_args = TrainingArguments(
        output_dir=f'finetuned_models/{args.run_name}_{formatted_time}',    # save location
        logging_steps=args.log_interval,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        dataloader_num_workers=args.num_workers,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        save_steps=args.eval_steps,
        eval_steps=args.eval_steps,
        fp16=True,
        deepspeed=args.ds_config,
        evaluation_strategy=args.evaluation_strategy,
        remove_unused_columns=False,
        report_to='wandb',
        run_name=args.run_name
    )

    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    model.is_parallelizable = True
    model.model_parallel = True
    model.config.use_cache = False  # the KV cache is incompatible with gradient checkpointing during training
    # model = prepare_model_for_int8_training(model)

    # setup peft
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=lora_module_dict[args.base_model],
        bias='none',
    )
    model = get_peft_model(model, peft_config)

    # Train
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, padding=True,
            return_tensors="pt"
        ),
        callbacks=[
            GenerationEvalCallback(
                eval_dataset=eval_dataset,
                ignore_until_epoch=round(0.3 * args.num_epochs)
            )
        ]
    )

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    torch.cuda.empty_cache()
    trainer.train()

    # save model
    model.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--run_name", default='local-test', type=str)
    parser.add_argument("--dataset", required=True, type=str)
    parser.add_argument("--test_dataset", type=str)
    parser.add_argument("--base_model", required=True, type=str, choices=['chatglm2', 'llama2'])
    parser.add_argument("--max_length", default=512, type=int)
    parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="The learning rate")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay")
    parser.add_argument("--num_epochs", default=8, type=float, help="The training epochs")
    parser.add_argument("--num_workers", default=8, type=int, help="dataloader workers")
    parser.add_argument("--log_interval", default=20, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=8, type=int)
    parser.add_argument("--warmup_ratio", default=0.05, type=float)
    parser.add_argument("--ds_config", default='./config_new.json', type=str)
    parser.add_argument("--scheduler", default='linear', type=str)
    parser.add_argument("--instruct_template", default='default')
    parser.add_argument("--evaluation_strategy", default='steps', type=str)
    parser.add_argument("--eval_steps", default=0.1, type=float)
    # argparse's type=bool treats any non-empty string (even "False") as True, so parse the text explicitly
    parser.add_argument("--from_remote", default=False, type=lambda x: str(x).lower() in ('true', '1'))

    args = parser.parse_args()

    wandb.login()
    main(args)
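

# Example multi-GPU launch via the deepspeed launcher (a sketch only: the script file name
# "train_lora.py" and the dataset identifiers are assumptions; all flags come from the
# argparse definitions above, and --local_rank is injected by the launcher):
#
#   deepspeed train_lora.py \
#       --run_name llama2-forecaster \
#       --base_model llama2 \
#       --dataset your-train-dataset \
#       --test_dataset your-test-dataset \
#       --num_epochs 8 \
#       --batch_size 4 \
#       --ds_config ./config_new.json \
#       --from_remote True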