---
license: apache-2.0
base_model: google/mt5-small
tags:
- generated_from_trainer
metrics:
- rouge
- bleu
- meteor
datasets:
- natural_questions
model-index:
- name: mt5-small
  results:
  - task:
      type: Question answering from context
      name: Question answering
    dataset:
      type: natural-questions
      name: Adapted Natural Questions
    metrics:
    - type: bleu
      value: 34.1596
      name: BLEU
      verified: true
    - type: rouge
      value: 44.4366
      name: ROUGE1
      verified: true
    - type: rouge
      value: 38.8202
      name: ROUGE2
      verified: true
    - type: rouge
      value: 43.113
      name: ROUGEl
      verified: true
    - type: rouge
      value: 43.1423
      name: ROUGElsum
      verified: true
    - type: meteor
      value: 0.4049
      name: METEOR
      verified: true
---
mt5-small
This model is a fine-tuned version of google/mt5-small on an enhanced version of the Natural Questions dataset. It achieves the following results on the evaluation set:
- Loss: 0.7291
- Rouge1: 44.4366
- Rouge2: 38.8202
- Rougel: 43.113
- Rougelsum: 43.1423
- Bleu: 34.1596
- Gen Len: 12.6724
- Meteor: 0.4049
- True negatives: 69.7281
- False negatives: 10.4037
- Cosine Sim: 0.763
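The card does not define how "True negatives" and "False negatives" are computed. One plausible reading, assumed here and not confirmed by the source, treats them as percentages: true negatives are the share of unanswerable questions that the model correctly declines, and false negatives are the share of answerable questions that it wrongly declines. The function `negative_rates` and the detector `looks_like_refusal` below are hypothetical, and the phrase "cannot be answered" is an assumed stand-in for the model's real refusal wording, which is not documented.

```python
from typing import Callable, Sequence


def negative_rates(
    predictions: Sequence[str],
    answerable: Sequence[bool],
    is_refusal: Callable[[str], bool],
) -> tuple[float, float]:
    """Return (true_negative_pct, false_negative_pct) under an ASSUMED definition.

    True negative: the question is unanswerable and the model declines.
    False negative: the question is answerable but the model declines.
    Each count is expressed as a percentage of its own group.
    """
    tn = fn = n_unanswerable = n_answerable = 0
    for pred, can_answer in zip(predictions, answerable):
        declined = is_refusal(pred)
        if can_answer:
            n_answerable += 1
            fn += declined  # bool counts as 0 or 1
        else:
            n_unanswerable += 1
            tn += declined
    tn_pct = 100.0 * tn / n_unanswerable if n_unanswerable else 0.0
    fn_pct = 100.0 * fn / n_answerable if n_answerable else 0.0
    return tn_pct, fn_pct


def looks_like_refusal(text: str) -> bool:
    # Hypothetical detector: the model's actual refusal wording is not documented.
    return "cannot be answered" in text.lower()


preds = [
    "This cannot be answered from the context.",
    "Paris",
    "Cannot be answered.",
    "42",
    "It cannot be answered.",
]
gold_answerable = [False, True, True, False, False]
tn, fn = negative_rates(preds, gold_answerable, looks_like_refusal)
print(f"true negatives: {tn:.1f}%, false negatives: {fn:.1f}%")
```

Under this reading, a high true-negative rate and a low false-negative rate are both desirable; the two are measured on disjoint groups, so they do not need to sum to 100.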
Model description
This model is fine-tuned for long-form, closed-domain question answering, that is, answering a question from a supplied context passage. It uses a heavily refined version of Google's Natural Questions dataset.
Answers to the questions were rewritten using OpenAI's GPT-3.5 Turbo model.
Please see the following repo for all code and adaptations.
Intended uses & limitations
The model expects each input to combine the context and the question in the following format: `[CONTEXT] </s> [QUESTION]`
It is trained to indicate when a question cannot be answered from the provided context.
It occasionally produces false negatives and false positives (see Training results), so all answers should be checked.
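The input format above can be built as a plain string. The helper `build_input` is hypothetical, a minimal sketch of the documented `[CONTEXT] </s> [QUESTION]` layout. In the Usage example, passing `context` and `question` as a pair to the tokenizer should produce the same layout at the token level, because T5-family tokenizers join a sentence pair with `</s>`; that equivalence is an assumption worth confirming on your own inputs.

```python
def build_input(context: str, question: str, sep: str = "</s>") -> str:
    """Hypothetical helper: join context and question as '[CONTEXT] </s> [QUESTION]'."""
    return f"{context.strip()} {sep} {question.strip()}"


prompt = build_input("The Eiffel Tower is in Paris.", "Where is the Eiffel Tower?")
print(prompt)  # The Eiffel Tower is in Paris. </s> Where is the Eiffel Tower?
```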
Training and evaluation data
The model is trained using the Natural Questions dataset, with answers that have been refined using GPT-3.5 Turbo. It is evaluated using a number of metrics including BLEU, ROUGE, METEOR, and cosine similarity.
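The card does not say which implementations produced its scores, so the following is only an illustration of two of the metrics. It computes a simplified ROUGE-1 F1 (lowercased whitespace tokens, clipped unigram overlap, no stemming, so results will not exactly match libraries such as `rouge_score`) and the cosine similarity between two vectors. Which embedding model produced the vectors behind the reported Cosine Sim is not stated; the toy vectors below are assumptions. Note that the card reports ROUGE and BLEU on a 0 to 100 scale, and METEOR and cosine similarity on a 0 to 1 scale.

```python
import math
from collections import Counter


def rouge1_f1(candidate: str, reference: str) -> float:
    """Simplified ROUGE-1 F1 over lowercased whitespace tokens."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((cand & ref).values())  # & keeps the minimum count of each token
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


def cosine_similarity(u: list[float], v: list[float]) -> float:
    """Cosine of the angle between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


score = rouge1_f1("the tower is in Paris", "The Eiffel Tower is in Paris")
sim = cosine_similarity([1.0, 0.0], [1.0, 1.0])
print(f"ROUGE-1 F1: {100 * score:.2f}, cosine similarity: {sim:.4f}")
```

Here all 5 candidate tokens appear in the 6-token reference, so precision is 1 and recall is 5/6, giving an F1 of 10/11.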
Usage
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load model and tokenizer
model_name = "psxjp5/mt5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Encode context and question as a pair; the tokenizer joins them with </s>
context = "Once upon a time"
question = "What is time?"
input_ids = tokenizer(context, question, return_tensors="pt").input_ids

# Generate the answer
outputs = model.generate(input_ids, max_new_tokens=150)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.001
- train_batch_size: 16
- eval_batch_size: 16
- seed: 9
- gradient_accumulation_steps: 8
- total_train_batch_size: 128
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 20
- weight_decay: 0.007
- dropout: 0.4
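The listed values imply an effective batch of 16 × 8 = 128 examples per optimizer step, which matches `total_train_batch_size` if training ran on a single device (an assumption; the device count is not stated). The sketch below also shows the shape of a linear learning-rate schedule starting at 0.001, written in the same form as the decay in Transformers' `get_linear_schedule_with_warmup`. No warmup is listed, so `warmup_steps` defaults to 0, and the 3,500 total steps used in the example (about 175 optimizer steps per epoch, per the results table, times 20 epochs) are an approximation.

```python
def linear_lr(step: int, total_steps: int, base_lr: float = 1e-3, warmup_steps: int = 0) -> float:
    """Linear warmup to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


per_device_batch = 16
grad_accum_steps = 8
num_devices = 1  # assumption: single device
effective_batch = per_device_batch * grad_accum_steps * num_devices

total_steps = 3500  # approximation: ~175 steps/epoch x 20 epochs
for step in (0, 875, 1750, 3500):
    print(step, linear_lr(step, total_steps))
```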
Training results
Training Loss | Epoch | Step | Validation Loss | Rouge1 | Rouge2 | Rougel | Rougelsum | Bleu | Gen Len | Meteor | True negatives | False negatives | Cosine Sim |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2.5724 | 1.0 | 175 | 0.9876 | 18.7781 | 15.6002 | 18.22 | 18.2686 | 7.6676 | 7.7661 | 0.1628 | 72.8701 | 56.677 | 0.4003 |
1.1469 | 1.99 | 350 | 0.8580 | 36.8209 | 31.2514 | 35.5008 | 35.5462 | 25.7137 | 12.0014 | 0.3311 | 62.8399 | 20.3934 | 0.66 |
0.9468 | 2.99 | 525 | 0.7997 | 40.4128 | 34.716 | 39.0867 | 39.0972 | 29.3028 | 12.4287 | 0.3656 | 63.4441 | 15.295 | 0.7114 |
0.8129 | 3.98 | 700 | 0.7733 | 42.6764 | 36.7266 | 41.2465 | 41.2833 | 32.0644 | 12.9002 | 0.3871 | 62.1752 | 11.413 | 0.7425 |
0.7228 | 4.98 | 875 | 0.7483 | 42.9082 | 36.957 | 41.482 | 41.5233 | 32.4942 | 12.8866 | 0.3906 | 63.3233 | 11.5166 | 0.747 |
0.6493 | 5.97 | 1050 | 0.7293 | 40.3205 | 34.9632 | 39.1111 | 39.1168 | 28.8249 | 11.6867 | 0.3674 | 73.8973 | 17.9865 | 0.7068 |
0.5883 | 6.97 | 1225 | 0.7172 | 42.7342 | 37.0855 | 41.4069 | 41.424 | 32.1296 | 12.48 | 0.3887 | 70.0302 | 12.7847 | 0.7392 |
0.5409 | 7.96 | 1400 | 0.7387 | 44.6657 | 38.8426 | 43.3276 | 43.3496 | 34.4773 | 12.9395 | 0.4084 | 66.3444 | 9.5238 | 0.7658 |
0.5035 | 8.96 | 1575 | 0.7330 | 43.4925 | 38.0013 | 42.2697 | 42.2372 | 32.6131 | 12.2789 | 0.3979 | 72.6284 | 12.8364 | 0.7…1 |
0.4652 | 9.95 | 1750 | 0.7291 | 44.4366 | 38.8202 | 43.113 | 43.1423 | 34.1596 | 12.6724 | 0.4049 | 69.7281 | 10.4037 | 0.763 |
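The evaluation results reported at the top of the card match the last logged row (step 1750, validation loss 0.7291). That row is neither the lowest-validation-loss checkpoint nor the highest-ROUGE-1 one, which is easy to confirm by querying the table. The values below are copied verbatim from three of its columns; the `Checkpoint` type is introduced only for illustration.

```python
from typing import NamedTuple


class Checkpoint(NamedTuple):
    step: int
    val_loss: float
    rouge1: float


# (Step, Validation Loss, Rouge1) copied from the training results table
rows = [
    Checkpoint(175, 0.9876, 18.7781),
    Checkpoint(350, 0.8580, 36.8209),
    Checkpoint(525, 0.7997, 40.4128),
    Checkpoint(700, 0.7733, 42.6764),
    Checkpoint(875, 0.7483, 42.9082),
    Checkpoint(1050, 0.7293, 40.3205),
    Checkpoint(1225, 0.7172, 42.7342),
    Checkpoint(1400, 0.7387, 44.6657),
    Checkpoint(1575, 0.7330, 43.4925),
    Checkpoint(1750, 0.7291, 44.4366),
]

best_loss = min(rows, key=lambda c: c.val_loss)
best_rouge1 = max(rows, key=lambda c: c.rouge1)
final = rows[-1]
print(best_loss.step, best_rouge1.step, final.step)  # 1225 1400 1750
```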
Framework versions
- Transformers 4.31.0
- Pytorch 2.0.1+cu118
- Datasets 2.13.1
- Tokenizers 0.13.3