---
license: apache-2.0
tags:
  - generated_from_trainer
datasets:
  - cnn_dailymail
model-index:
  - name: QLoRA-Flan-T5-Small
    results: []
---

QLoRA-Flan-T5-Small

This model is a fine-tuned version of google/flan-t5-small on the cnn_dailymail dataset.

Model description

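A PEFT (LoRA) adapter for google/flan-t5-small, fine-tuned to summarize news articles from the cnn_dailymail dataset. It is loaded on top of the base model rather than used as a standalone checkpoint, as shown in How to use the model.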

Intended uses & limitations

More information needed

Training and evaluation data

Fine-tuned on the cnn_dailymail dataset and evaluated on its full test split.

How to use the model

  1. Loading the model

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the PEFT config of the adapter checkpoint; it records the base model it was trained on
peft_model_id = "emonty777/QLoRA-Flan-T5-Small"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model: in 8-bit on the first GPU if one is available
# (requires the bitsandbytes and accelerate packages), otherwise in full precision on CPU
if torch.cuda.is_available():
    model = AutoModelForSeq2SeqLM.from_pretrained(
        config.base_model_name_or_path, load_in_8bit=True, device_map={"": 0}
    )
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the LoRA adapter on top of the base model and switch to inference mode
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()
```
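
  2. Generating a summary

The loading code stops before inference. A minimal generation sketch follows. It assumes the adapter was trained with the Flan-T5 style `summarize: ` input prefix, which this card does not document, and the helper name `summarize` and its defaults (512 input tokens, 128 new tokens, 4 beams) are illustrative choices, not values reported by the author.

```python
def summarize(article, model, tokenizer, prefix="summarize: ",
              max_input_tokens=512, max_new_tokens=128, num_beams=4):
    """Generate a summary for one article (hypothetical helper).

    Assumption: the adapter expects a "summarize: " prefix. The training prompt
    format is not documented, so adjust `prefix` if outputs look off.
    """
    # Tokenize and truncate long articles to the encoder input budget
    inputs = tokenizer(prefix + article, return_tensors="pt",
                       truncation=True, max_length=max_input_tokens)
    # Move input tensors to the device the model was loaded on (CPU or GPU 0)
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    # Beam-search decoding; generation settings here are illustrative defaults
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens,
                                num_beams=num_beams)
    # Decode the first (and only) sequence in the batch
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

With `model` and `tokenizer` from step 1, `summarize(article_text, model, tokenizer)` returns the decoded summary string, where `article_text` is any news article.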

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 3e-05
  • train_batch_size: 8
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 4
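
The hyperparameters above map onto `transformers.Seq2SeqTrainingArguments` roughly as in this sketch. It assumes single-device training with no gradient accumulation (so `train_batch_size` is the per-device size) and zero warmup, since neither is listed. `TRAINING_KWARGS`, `make_training_args` and `linear_lr` are hypothetical names, and the LoRA settings (rank, alpha, target modules) are not given in this card, so they are left out.

```python
# Reported hyperparameters, keyed by their Seq2SeqTrainingArguments names.
# Assumption: single device, no gradient accumulation, no warmup.
TRAINING_KWARGS = dict(
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    seed=42,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    lr_scheduler_type="linear",
    num_train_epochs=4,
)


def make_training_args(output_dir: str):
    """Build Seq2SeqTrainingArguments from the reported values (hypothetical helper)."""
    # Imported lazily so the rest of the sketch runs without transformers installed
    from transformers import Seq2SeqTrainingArguments
    return Seq2SeqTrainingArguments(output_dir=output_dir, **TRAINING_KWARGS)


def linear_lr(step: int, total_steps: int, base_lr: float = 3e-5) -> float:
    """Learning rate at `step` under a linear schedule with zero warmup:
    decays from base_lr at step 0 to 0 at total_steps, then stays at 0."""
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps))
```

Under the same assumptions, a training set of N examples gives `total_steps = ceil(N / 8) * 4`.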

Training results

Evaluated on the full CNN/DailyMail test set:

| Metric | Recall (r) | Precision (p) | F-score (f) |
|---|---|---|---|
| ROUGE-1 | 0.3484396421841008 | 0.37845620239152916 | 0.3484265780526604 |
| ROUGE-2 | 0.1472418310455188 | 0.15418276080118026 | 0.14343059577230782 |
| ROUGE-L | 0.3280567401095563 | 0.3565504002457199 | 0.32809541498574013 |
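
The r/p/f layout appears to match the averaged output of the `rouge` Python package (`Rouge().get_scores(hyps, refs, avg=True)`), but the card does not say which implementation produced these numbers. To illustrate what the three fields mean, here is a simplified stdlib ROUGE-1 (lowercased whitespace tokens, clipped unigram overlap). It does not use any library's exact tokenization, so it will not reproduce the figures above; `rouge1` and `average_scores` are hypothetical names.

```python
from collections import Counter


def rouge1(hypothesis: str, reference: str) -> dict:
    """Simplified ROUGE-1 for one pair: clipped unigram overlap
    on lowercased whitespace tokens."""
    hyp = Counter(hypothesis.lower().split())
    ref = Counter(reference.lower().split())
    # Counter & keeps the minimum count per token, i.e. clipped matches
    overlap = sum((hyp & ref).values())
    p = overlap / max(sum(hyp.values()), 1)  # matched / hypothesis length
    r = overlap / max(sum(ref.values()), 1)  # matched / reference length
    f = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return {"r": r, "p": p, "f": f}


def average_scores(pairs):
    """Average per-example r, p, f over a non-empty list of
    (hypothesis, reference) pairs."""
    scores = [rouge1(h, ref) for h, ref in pairs]
    n = len(scores)
    return {k: sum(s[k] for s in scores) / n for k in ("r", "p", "f")}
```

Because f is computed per example and then averaged, the mean f is never above the harmonic mean of the mean r and mean p (the harmonic mean is concave). This is consistent with the ROUGE-1 row, where the reported f (0.348) is below the harmonic mean of the reported r and p (about 0.363).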

Framework versions

  • Transformers 4.27.1
  • PyTorch 2.0.1+cu118
  • Datasets 2.9.0
  • Tokenizers 0.13.3