---
license: cc-by-nc-sa-4.0
datasets:
- Blablablab/ALOE
---

### Model Description

The model classifies the *appraisal* expressed in a sentence. It was trained on the [ALOE](https://huggingface.co/datasets/Blablablab/ALOE) dataset.

**Input:**  a sentence

**Labels:**   No Label, Pleasantness, Anticipated Effort, Certainty, Objective Experience, Self-Other Agency, Situational Control, Advice, Trope

**Output:** logits, one per label, in the order listed above

**Model architecture:** OpenPrompt + RoBERTa-large

**Developed by:** Jiamin Yang
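The model returns raw logits rather than probabilities. As a minimal sketch in plain Python, a single logit vector can be turned into a predicted label and softmax probabilities, assuming the index order matches the Labels list above. The helper `decode_logits` and the example logits are hypothetical and for illustration only.

```python
import math

# Label order copied from the model card; index i corresponds to logits[i]
LABELS = [
    "No Label", "Pleasantness", "Anticipated Effort", "Certainty",
    "Objective Experience", "Self-Other Agency", "Situational Control",
    "Advice", "Trope",
]


def decode_logits(logits):
    """Hypothetical helper: return (predicted label, per-label probabilities)."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    # Numerically stable softmax: subtract the max logit before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = {label: e / total for label, e in zip(LABELS, exps)}
    # The argmax over logits is also the argmax over probabilities
    best = LABELS[logits.index(m)]
    return best, probs


# Made-up logits for one sentence (not real model output)
example = [0.1, -1.2, 0.3, 0.0, -0.5, 0.4, -2.0, 1.1, 2.7]
label, probs = decode_logits(example)
print(label)                          # Trope
print(round(sum(probs.values()), 6))  # 1.0
```

Softmax is only needed when calibrated-looking scores are wanted; for a hard prediction, `argmax` over the logits gives the same label.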

### Model Performance

##### Overall performance

| Macro-F1 | Macro-Recall | Macro-Precision |
| :------: | :----: | :-------: |
|   0.56   |  0.57  |   0.58    |

##### Per-label performance

| Label                | Recall | Precision |
| -------------------- | :----: | :-------: |
| No Label             |  0.34  |   0.64    |
| Pleasantness         |  0.69  |   0.54    |
| Anticipated Effort   |  0.46  |   0.46    |
| Certainty            |  0.58  |   0.47    |
| Objective Experience |  0.58  |   0.69    |
| Self-Other Agency    |  0.62  |   0.55    |
| Situational Control  |  0.31  |   0.55    |
| Advice               |  0.72  |   0.66    |
| Trope                |  0.80  |   0.67    |
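The overall scores appear to be unweighted (macro) averages of the per-label scores: averaging the rows above reproduces them to two decimals, with Macro-F1 taken as the mean of the per-label F1 = 2PR / (P + R). The short check below assumes that definition; because the table values are already rounded, it is a consistency check rather than an exact recomputation.

```python
# Per-label (recall, precision) copied from the table above
per_label = {
    "No Label": (0.34, 0.64),
    "Pleasantness": (0.69, 0.54),
    "Anticipated Effort": (0.46, 0.46),
    "Certainty": (0.58, 0.47),
    "Objective Experience": (0.58, 0.69),
    "Self-Other Agency": (0.62, 0.55),
    "Situational Control": (0.31, 0.55),
    "Advice": (0.72, 0.66),
    "Trope": (0.80, 0.67),
}


def f1(recall, precision):
    """Harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


n = len(per_label)
macro_recall = sum(r for r, _ in per_label.values()) / n
macro_precision = sum(p for _, p in per_label.values()) / n
# Macro-F1 averages the per-label F1 scores, each label weighted equally
macro_f1 = sum(f1(r, p) for r, p in per_label.values()) / n

print(f"Macro-F1 {macro_f1:.2f}  Recall {macro_recall:.2f}  Precision {macro_precision:.2f}")
# Macro-F1 0.56  Recall 0.57  Precision 0.58
```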

### Getting Started

```python
import torch
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate, ManualVerbalizer

checkpoint_file = 'your_path_to/empathy-appraisal-span.pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Labels in the same order as the model's output logits
labels = ['No Label', 'Pleasantness', 'Anticipated Effort', 'Certainty',
          'Objective Experience', 'Self-Other Agency', 'Situational Control',
          'Advice', 'Trope']

# Load the RoBERTa-large backbone and build the prompt template
plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
template = ManualTemplate(tokenizer=tokenizer, text=template_text)

# One label word (the label name itself) per class
verbalizer = ManualVerbalizer(tokenizer, num_classes=len(labels),
                              label_words=[[label] for label in labels])
prompt_model = PromptForClassification(plm=plm, template=template,
                                       verbalizer=verbalizer, freeze_plm=False).to(device)

# Load the fine-tuned weights
checkpoint = torch.load(checkpoint_file, map_location=device)
state_dict = checkpoint['model_state_dict']

# Newer transformers releases register position_ids as a non-persistent buffer,
# so the saved key must be dropped before loading; remove this line on older
# releases that still expect it.
state_dict.pop('prompt_model.plm.roberta.embeddings.position_ids', None)

prompt_model.load_state_dict(state_dict)

# Use the model
dataset = [
    InputExample(guid=0, text_a="I am sorry for your loss"),
    InputExample(guid=1, text_a="It's not your fault"),
]

data_loader = PromptDataLoader(dataset=dataset,
                               template=template,
                               tokenizer=tokenizer,
                               tokenizer_wrapper_class=WrapperClass,
                               max_seq_length=512,
                               batch_size=2,
                               shuffle=False,
                               teacher_forcing=False,
                               predict_eos_token=False,
                               truncate_method='head')

prompt_model.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = prompt_model(batch.to(device))
        preds = torch.argmax(logits, dim=-1)
        print(preds.tolist())                       # [8, 5]
        print([labels[i] for i in preds.tolist()])  # ['Trope', 'Self-Other Agency']
```