---
license: cc-by-nc-sa-4.0
datasets:
- Blablablab/ALOE
---
### Model Description
The model classifies the *appraisal* expressed in a sentence. It is trained on the [ALOE](https://huggingface.co/datasets/Blablablab/ALOE) dataset.
**Input:** a sentence
**Labels:** No Label, Pleasantness, Anticipated Effort, Certainty, Objective Experience, Self-Other Agency, Situational Control, Advice, Trope
**Output:** one logit per label, in the order of the labels listed above
**Model architecture:** OpenPrompt + RoBERTa
**Developed by:** Jiamin Yang
### Model Performance
##### Overall performance
| Macro-F1 | Recall | Precision |
| :------: | :----: | :-------: |
| 0.56 | 0.57 | 0.58 |
##### Per-label performance
| Label | Recall | Precision |
| -------------------- | :----: | :-------: |
| No Label | 0.34 | 0.64 |
| Pleasantness | 0.69 | 0.54 |
| Anticipated Effort | 0.46 | 0.46 |
| Certainty | 0.58 | 0.47 |
| Objective Experience | 0.58 | 0.69 |
| Self-Other Agency | 0.62 | 0.55 |
| Situational Control | 0.31 | 0.55 |
| Advice | 0.72 | 0.66 |
| Trope | 0.80 | 0.67 |
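
As a quick sanity check, the overall recall and precision are consistent with unweighted (macro) averages of the per-label values in the table above. The sketch below recomputes them from the rounded table entries. Reading the overall recall and precision as macro averages is an assumption, since the card only labels F1 as macro; F1 itself cannot be recomputed here because per-label F1 is not reported.

```python
# Per-label (recall, precision) pairs copied from the table in this card.
per_label = {
    "No Label": (0.34, 0.64),
    "Pleasantness": (0.69, 0.54),
    "Anticipated Effort": (0.46, 0.46),
    "Certainty": (0.58, 0.47),
    "Objective Experience": (0.58, 0.69),
    "Self-Other Agency": (0.62, 0.55),
    "Situational Control": (0.31, 0.55),
    "Advice": (0.72, 0.66),
    "Trope": (0.80, 0.67),
}

# Macro averaging: every label counts equally, regardless of its frequency.
macro_recall = sum(r for r, _ in per_label.values()) / len(per_label)
macro_precision = sum(p for _, p in per_label.values()) / len(per_label)

print(f"macro recall    = {macro_recall:.2f}")     # 0.57
print(f"macro precision = {macro_precision:.2f}")  # 0.58
```

Because macro averaging weights all labels equally, the weaker labels (No Label and Situational Control, with recall near 0.3) pull the overall recall down as much as any frequent label would.
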
### Getting Started
```python
import torch
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
from openprompt import PromptForClassification
from openprompt.data_utils import InputExample
from openprompt import PromptDataLoader
checkpoint_file = 'your_path_to/empathy-appraisal-span.pt'
plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
template = ManualTemplate(tokenizer=tokenizer, text=template_text)
num_classes = 9
label_words = [['No Label'], ['Pleasantness'], ['Anticipated Effort'], ['Certainty'], ['Objective Experience'], ['Self-Other Agency'], ['Situational Control'], ['Advice'], ['Trope']]
verbalizer = ManualVerbalizer(tokenizer, num_classes=num_classes, label_words=label_words)
prompt_model = PromptForClassification(plm=plm,template=template, verbalizer=verbalizer, freeze_plm=False).to('cuda')
checkpoint = torch.load(checkpoint_file)
state_dict = checkpoint['model_state_dict']
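# Depending on the installed library versions, the checkpoint may contain a
# position_ids entry that the current model does not expect; drop it if present.
state_dict.pop('prompt_model.plm.roberta.embeddings.position_ids', None)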
prompt_model.load_state_dict(state_dict)
# use the model
dataset = [
    InputExample(
        guid=0,
        text_a="I am sorry for your loss",
    ),
    InputExample(
        guid=1,
        text_a="It's not your fault",
    ),
]

data_loader = PromptDataLoader(
    dataset=dataset,
    template=template,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=512,
    batch_size=2,
    shuffle=False,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method='head',
)
prompt_model.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = prompt_model(batch.to('cuda'))
        preds = torch.argmax(logits, dim=-1)
        print(preds)  # [8, 5]
```
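
The snippet above prints class indices rather than label names. Since the output follows the label order given in this card, the indices can be mapped back to names with a small helper. The sketch below is illustrative only: `LABELS`, `indices_to_labels`, and `logits_to_labels` are hypothetical names, not part of OpenPrompt, and the helpers work on plain Python lists (for example `preds.tolist()` or `logits.tolist()`).

```python
# Label order as listed in this model card; index i corresponds to logit i.
# LABELS and both helpers are hypothetical names used for illustration.
LABELS = [
    "No Label",
    "Pleasantness",
    "Anticipated Effort",
    "Certainty",
    "Objective Experience",
    "Self-Other Agency",
    "Situational Control",
    "Advice",
    "Trope",
]


def indices_to_labels(indices):
    """Map predicted class indices (e.g. preds.tolist()) to label names."""
    return [LABELS[i] for i in indices]


def logits_to_labels(logit_rows):
    """Take the argmax of each row of scores and return the label names."""
    names = []
    for row in logit_rows:
        if len(row) != len(LABELS):
            raise ValueError(f"expected {len(LABELS)} scores, got {len(row)}")
        best = max(range(len(row)), key=row.__getitem__)
        names.append(LABELS[best])
    return names


# The indices shown in the comment of the snippet above.
print(indices_to_labels([8, 5]))  # ['Trope', 'Self-Other Agency']
```

With the tensors from the snippet above, this would be called as `indices_to_labels(preds.tolist())`.
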