---
license: cc-by-nc-sa-4.0
datasets:
  - Blablablab/ALOE
---

Model Description

Given a sentence, the model predicts its appraisal label. It is trained on the ALOE dataset.

Input: a sentence

Labels: No Label, Pleasantness, Anticipated Effort, Certainty, Objective Experience, Self-Other Agency, Situational Control, Advice, Trope

Output: logits, one per label, in the same order as the labels above

Model architecture: OpenPrompt + RoBERTa-large

Developed by: Jiamin Yang
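Because the model returns raw logits, one per label in the order listed above, converting them into a predicted label and a probability per label is a post-processing step left to the user. The sketch below does this in plain Python with a softmax. `logits_to_labels` is a hypothetical helper, not part of OpenPrompt, and the logit values are made up for illustration.

```python
import math

# Label order matches the order of the model's output logits
LABELS = [
    "No Label", "Pleasantness", "Anticipated Effort", "Certainty",
    "Objective Experience", "Self-Other Agency", "Situational Control",
    "Advice", "Trope",
]


def logits_to_labels(logits):
    """Hypothetical helper: return (predicted label, {label: probability})."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = {label: e / total for label, e in zip(LABELS, exps)}
    # The argmax of the logits is also the argmax of the softmax probabilities
    best = LABELS[list(logits).index(m)]
    return best, probs


# Made-up logits for a single sentence (illustration only)
example_logits = [-1.2, 0.3, -0.5, 0.1, -0.8, 0.4, -1.0, -0.2, 2.1]
label, probs = logits_to_labels(example_logits)
print(label)  # Trope
```

The softmax only rescales the scores, so the predicted label is the same one `torch.argmax` picks in the Getting Started code.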

Model Performance

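Overall performance

| Macro-F1 | Recall | Precision |
|----------|--------|-----------|
| 0.56     | 0.57   | 0.58      |

Per-label performance

| Label                | Recall | Precision |
|----------------------|--------|-----------|
| No Label             | 0.34   | 0.64      |
| Pleasantness         | 0.69   | 0.54      |
| Anticipated Effort   | 0.46   | 0.46      |
| Certainty            | 0.58   | 0.47      |
| Objective Experience | 0.58   | 0.69      |
| Self-Other Agency    | 0.62   | 0.55      |
| Situational Control  | 0.31   | 0.55      |
| Advice               | 0.72   | 0.66      |
| Trope                | 0.80   | 0.67      |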
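The overall scores are consistent with unweighted (macro) averages of the per-label scores: averaging the per-label recall and precision gives 0.57 and 0.58, and averaging each label's F1 (the harmonic mean of its precision and recall) gives 0.56. The check below reproduces this from the table. Since the per-label values are themselves rounded, treat it as a consistency check, not as the author's evaluation code.

```python
# Per-label (recall, precision) pairs copied from the table above
PER_LABEL = {
    "No Label": (0.34, 0.64),
    "Pleasantness": (0.69, 0.54),
    "Anticipated Effort": (0.46, 0.46),
    "Certainty": (0.58, 0.47),
    "Objective Experience": (0.58, 0.69),
    "Self-Other Agency": (0.62, 0.55),
    "Situational Control": (0.31, 0.55),
    "Advice": (0.72, 0.66),
    "Trope": (0.80, 0.67),
}


def f1(recall, precision):
    """Harmonic mean of precision and recall (0 when both are 0)."""
    if recall + precision == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


n = len(PER_LABEL)
macro_recall = sum(r for r, _ in PER_LABEL.values()) / n
macro_precision = sum(p for _, p in PER_LABEL.values()) / n
# Macro-F1 averages the per-label F1 scores; it is not the F1 of macro precision and recall
macro_f1 = sum(f1(r, p) for r, p in PER_LABEL.values()) / n

print(f"Macro-F1 {macro_f1:.2f}  Recall {macro_recall:.2f}  Precision {macro_precision:.2f}")
# Macro-F1 0.56  Recall 0.57  Precision 0.58
```

Taking the harmonic mean of the macro recall and precision instead would give about 0.57, which suggests the reported Macro-F1 averages the per-label F1 scores.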

Getting Started

import torch
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate, ManualVerbalizer

# Path to the downloaded fine-tuned checkpoint
checkpoint_file = 'upload_version/empathy-appraisal-span.pt'
# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the RoBERTa-large backbone and wrap each sentence in a cloze-style prompt
plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
template = ManualTemplate(tokenizer=tokenizer, text=template_text)

# One label word per class, in the same order as the output logits
num_classes = 9
label_words = [['No Label'], ['Pleasantness'], ['Anticipated Effort'], ['Certainty'], ['Objective Experience'], ['Self-Other Agency'], ['Situational Control'], ['Advice'], ['Trope']]
verbalizer = ManualVerbalizer(tokenizer, num_classes=num_classes, label_words=label_words)
prompt_model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer, freeze_plm=False).to(device)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load(checkpoint_file, map_location=device)
state_dict = checkpoint['model_state_dict']

# Depending on the installed library versions, RoBERTa may not keep position_ids
# in its state dict; drop the saved key only when the model does not expect it
position_ids_key = 'prompt_model.plm.roberta.embeddings.position_ids'
if position_ids_key not in prompt_model.state_dict():
    state_dict.pop(position_ids_key, None)

prompt_model.load_state_dict(state_dict)

# Use the model on two example sentences
dataset = [
    InputExample(guid=0, text_a="I am sorry for your loss"),
    InputExample(guid=1, text_a="It's not your fault"),
]

data_loader = PromptDataLoader(dataset=dataset,
                               template=template,
                               tokenizer=tokenizer,
                               tokenizer_wrapper_class=WrapperClass,
                               max_seq_length=512,
                               batch_size=2,
                               shuffle=False,
                               teacher_forcing=False,
                               predict_eos_token=False,
                               truncate_method='head')
prompt_model.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = prompt_model(batch.to(device))
        preds = torch.argmax(logits, dim=-1)
        print(preds)  # predicted indices: [8, 5]
        # Map indices back to label names: ['Trope', 'Self-Other Agency']
        print([label_words[i][0] for i in preds.tolist()])
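It can help to spell out the prompt the classifier actually receives. The template in the code above slots each sentence into `text_a` and turns `{"mask"}` into RoBERTa's mask token (`<mask>`). The verbalizer then scores the prediction at that position against the nine label words. The sketch below mimics the substitution with plain string formatting. `render_prompt` is a hypothetical helper for illustration only. OpenPrompt's `ManualTemplate` works at the token level and handles spacing and truncation itself, so its exact text may differ slightly.

```python
# Hypothetical plain-text re-creation of the prompt, for illustration only
TEMPLATE = "The sentence {text_a} has the label {mask}."
# RoBERTa's mask token
MASK_TOKEN = "<mask>"


def render_prompt(text_a: str, mask_token: str = MASK_TOKEN) -> str:
    """Fill the sentence and mask slots of the cloze-style template."""
    return TEMPLATE.format(text_a=text_a, mask=mask_token)


for sentence in ["I am sorry for your loss", "It's not your fault"]:
    print(render_prompt(sentence))
# The sentence I am sorry for your loss has the label <mask>.
# The sentence It's not your fault has the label <mask>.
```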