---
language:
  - en
license: mit
library_name: peft
datasets:
  - wikitext
metrics:
  - f1
base_model: roberta-large
---

RoBERTa large fine-tuned with LoRA to predict comma placement in text. It expects input text with commas removed and classifies each token according to whether a comma should be inserted after it.

As a PEFT model, it does not seem to work well with Hugging Face pipelines, at least at the time of writing.

Usage examples and a wrapper class for text-to-text comma fixing are available in the demo.

Loading the raw model in code:

```python
from peft import PeftModel, PeftConfig
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch

# Label mapping for the token-classification head
id2label = {
    0: "O",
    1: "B-COMMA"
}
label2id = {
    "O": 0,
    "B-COMMA": 1
}

# Load the base model (roberta-large) with a 2-label head, then apply the LoRA adapter on top
peft_model_id = 'klasocki/roberta-large-lora-ner-comma-fixer'
config = PeftConfig.from_pretrained(peft_model_id)
inference_model = AutoModelForTokenClassification.from_pretrained(
    config.base_model_name_or_path, num_labels=len(id2label), id2label=id2label, label2id=label2id
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(inference_model, peft_model_id)
model.eval()  # make sure dropout is disabled for inference

# Input text with commas removed
text = "This text should have commas here here and there however it does not."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

tokens = inputs.tokens()
# Most likely label id for each token, shape (batch, seq_len)
predictions = torch.argmax(logits, dim=2)

for token, prediction in zip(tokens, predictions[0].tolist()):
    print((token, id2label[prediction]))
```

Output:

```text
('<s>', 'O')
('This', 'O')
('Ġtext', 'O')
('Ġshould', 'O')
('Ġhave', 'O')
('Ġcomm', 'O')
('as', 'O')
('Ġhere', 'B-COMMA')
('Ġhere', 'O')
('Ġand', 'O')
('Ġthere', 'B-COMMA')
('Ġhowever', 'O')
('Ġit', 'O')
('Ġdoes', 'O')
('Ġnot', 'O')
('.', 'O')
('</s>', 'O')
```
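The card points to the demo for its text-to-text wrapper and does not reproduce it here. As a rough illustration only, the hypothetical helper below turns (token, label) pairs like the ones printed above back into text. It appends a comma after every token labelled `B-COMMA`. It assumes RoBERTa's byte-level BPE tokens, where `Ġ` marks a leading space, and it only handles plain ASCII text. For general input, decoding with the tokenizer (e.g. `tokenizer.convert_tokens_to_string`) would be more robust.

```python
# Hypothetical helper, not the demo's wrapper class: rebuilds text from
# RoBERTa BPE tokens and inserts a comma after each token labelled B-COMMA.
# Assumes ASCII input, where "Ġ" is the byte-level BPE marker for a leading space.
SPECIAL_TOKENS = {"<s>", "</s>", "<pad>"}


def insert_commas(token_labels):
    pieces = []
    for token, label in token_labels:
        if token in SPECIAL_TOKENS:
            continue  # skip sentence start/end and padding markers
        pieces.append(token.replace("Ġ", " "))  # restore the space the marker encodes
        if label == "B-COMMA":
            pieces.append(",")  # the comma goes right after the labelled token
    return "".join(pieces).strip()


predicted = [
    ("<s>", "O"), ("This", "O"), ("Ġtext", "O"), ("Ġshould", "O"), ("Ġhave", "O"),
    ("Ġcomm", "O"), ("as", "O"), ("Ġhere", "B-COMMA"), ("Ġhere", "O"), ("Ġand", "O"),
    ("Ġthere", "B-COMMA"), ("Ġhowever", "O"), ("Ġit", "O"), ("Ġdoes", "O"),
    ("Ġnot", "O"), (".", "O"), ("</s>", "O"),
]
print(insert_commas(predicted))
# This text should have commas here, here and there, however it does not.
```

A real wrapper would also have to decide what to do when a comma is predicted on a non-final subword of a word. The card does not say how the demo handles that case.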

Evaluation results

Results for commas on the wikitext validation set:

| Model | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| baseline* | 0.79 | 0.72 | 0.75 | 10079 |
| ours | 0.84 | 0.84 | 0.84 | 10079 |

*The baseline is the oliverguhr/fullstop-punctuation-multilang-large model, evaluated on commas only and out of domain on wikitext. In domain, its authors report an F1 of 0.819 for English political speeches; Wikipedia text, however, appears to be more challenging for comma restoration.
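The card does not include its evaluation script, and it does not say whether scores are computed per token or per word. If gold and predicted label sequences are already aligned, precision, recall, F1 and support for the `B-COMMA` class could be computed as in this sketch. The names `comma_scores` and `f1_from` are hypothetical. The last lines check that the reported F1 values follow from the reported precision and recall, since F1 = 2PR / (P + R).

```python
# Sketch: scores for the comma class from aligned gold/predicted label sequences.
def comma_scores(gold, pred, positive="B-COMMA"):
    pairs = list(zip(gold, pred))
    tp = sum(g == positive and p == positive for g, p in pairs)  # commas found
    fp = sum(g != positive and p == positive for g, p in pairs)  # spurious commas
    fn = sum(g == positive and p != positive for g, p in pairs)  # missed commas
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    support = tp + fn  # number of commas in the gold labels
    return precision, recall, f1, support


def f1_from(precision, recall):
    # Harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall)


print(comma_scores(["O", "B-COMMA", "O", "B-COMMA"], ["O", "B-COMMA", "B-COMMA", "O"]))
# (0.5, 0.5, 0.5, 2)

# The table's F1 column is consistent with its precision and recall columns.
print(round(f1_from(0.79, 0.72), 2), round(f1_from(0.84, 0.84), 2))
# 0.75 0.84
```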

Training procedure

To compare with the baseline, we fine-tune the same model, RoBERTa large, on the English wikitext dataset. We take a similar approach and treat comma restoration as a NER-style token classification problem: for each token, we predict whether a comma should be inserted after it.
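With this framing, training examples can be produced from any text that already contains commas: remove the commas and label the word before each one. The sketch below does this at the word level with whitespace splitting. `make_example` is a hypothetical helper, and the card does not describe the exact preprocessing used (for instance, how word labels are aligned to subword tokens). Only trailing commas become labels, so commas inside tokens such as `1,000` are left untouched.

```python
# Hypothetical word-level label synthesis: remove commas from the text and
# mark the word preceding each removed comma with B-COMMA.
def make_example(text):
    words, labels = [], []
    for raw in text.split():
        word = raw.rstrip(",")
        if word:
            words.append(word)
            labels.append("O")
        # A trailing comma (or a stand-alone "," token) labels the previous word.
        if raw.endswith(",") and labels:
            labels[-1] = "B-COMMA"
    return words, labels


words, labels = make_example("This text should have commas here, here and there, however it does not.")
print([w for w, l in zip(words, labels) if l == "B-COMMA"])
# ['here', 'there']
```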

The main advantage of this approach is that it preserves the input structure and focuses only on commas. Nothing else in the text can change. Unlike a text-to-text model, the model also does not have to learn to copy the input back when no commas need to be inserted.
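The model emits only one label per position, so turning its predictions back into text can only ever add commas. Here is a minimal word-level sketch. It uses a hypothetical `apply_labels` helper and assumes the input words were obtained by whitespace splitting.

```python
# Hypothetical helper: re-attach commas to words labelled B-COMMA.
# Words are never altered, reordered, added or dropped; only commas are inserted.
def apply_labels(words, labels):
    return " ".join(w + "," if l == "B-COMMA" else w for w, l in zip(words, labels))


source = "however it does not"
words = source.split()
fixed = apply_labels(words, ["B-COMMA", "O", "O", "O"])
print(fixed)
# however, it does not

# Removing the commas again recovers the comma-free input exactly.
print(fixed.replace(",", "") == source)
# True
```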

We use LoRA to reduce training time and cost, and we synthesize the training data from wikitext. The model appears to converge after only about 15,000 training examples, so a small subset of wikitext is sufficient. Extending the model to more languages and domains is left for future work.
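For reference, this is why LoRA makes training cheaper. It keeps a pretrained weight matrix $W_0 \in \mathbb{R}^{d \times k}$ frozen and learns only a low-rank update. The card does not state the rank $r$ or the scaling factor $\alpha$ that were used. The worked numbers below assume $r = 8$ purely for illustration, with $d = k = 1024$, the hidden size of RoBERTa large.

$$
h = W_0 x + \frac{\alpha}{r} B A x, \qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k},\; r \ll \min(d, k)
$$

$$
\underbrace{d \cdot k}_{\text{full update}} = 1024 \cdot 1024 = 1{,}048{,}576
\qquad\text{vs.}\qquad
\underbrace{r\,(d + k)}_{\text{LoRA}} = 8 \cdot (1024 + 1024) = 16{,}384
$$

Under that assumed rank, each adapted 1024 × 1024 projection trains about 1.6% as many parameters as a full update would.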

Framework versions

- PEFT 0.5.0
- Transformers 4.31.0
- Torch 2.0.1