Spaces:
Build error
Build error
LogiQA2.0 dataset
Browse files- .gitattributes +4 -1
- datasets/LogiQA2.0/README.md +132 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/datasource.txt +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/dev.txt +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/dev_fol.jsonl +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/dev_zh.txt +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/ood_test.jsonl +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/readme.md +7 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/statistics.py +21 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/test.txt +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/test_fol.jsonl +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/test_zh.txt +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/train.txt +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/train_fol.zip +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/train_zh.txt +3 -0
- datasets/LogiQA2.0/logiqa/DATA/LOGIQA/word_matching.py +26 -0
- datasets/LogiQA2.0/logiqa/logiqa.sh +21 -0
- datasets/LogiQA2.0/logiqa/modeling_bart.py +1416 -0
- datasets/LogiQA2.0/logiqa/multi-choice-prompt.py +56 -0
- datasets/LogiQA2.0/logiqa/run_mrc.py +552 -0
- datasets/LogiQA2.0/logiqa/utils_mrc.py +280 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/ dev_new.txt +3 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/dev.txt +3 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/readme.md +1 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/stat.py +25 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/test.txt +3 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/test_new.txt +3 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/train.txt +3 -0
- datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/train_new.txt +3 -0
- datasets/LogiQA2.0/logiqa2nli/nli-prompt.py +51 -0
- datasets/LogiQA2.0/logiqa2nli/qa2nli.sh +20 -0
- datasets/LogiQA2.0/logiqa2nli/run_nli.py +549 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/anli.sh +21 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/cood.sh +4 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/mnli.sh +21 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/multirun.sh +8 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/pnli.sh +21 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/qa2nli.sh +21 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/qnli.sh +21 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/qood.sh +20 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/rte.sh +20 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/scitail.sh +22 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/snli.sh +21 -0
- datasets/LogiQA2.0/logiqa2nli/scripts/wnli.sh +21 -0
- datasets/LogiQA2.0/logiqa2nli/utils_nli.py +1002 -0
- datasets/LogiQA2.0/requirements.yml +17 -0
.gitattributes
CHANGED
@@ -33,6 +33,8 @@ unsloth/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
35 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
36 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
37 |
datasets/mgtv/ filter=lfs diff=lfs merge=lfs -text
|
38 |
datasets/mgtv/dev.csv filter=lfs diff=lfs merge=lfs -text
|
@@ -106,5 +108,6 @@ results/test_b-results_r6.csv filter=lfs diff=lfs merge=lfs -text
|
|
106 |
mgtv_train_p1.json filter=lfs diff=lfs merge=lfs -text
|
107 |
mgtv_train_p2.json filter=lfs diff=lfs merge=lfs -text
|
108 |
datasets/mgtv/o1-mini.jsonl filter=lfs diff=lfs merge=lfs -text
|
109 |
-
datasets/mgtv/Icon
|
110 |
filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
111 |
datasets/mgtv/gpt-4o-mini.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
|
33 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
35 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jsonl filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.txt filter=lfs diff=lfs merge=lfs -text
|
38 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
39 |
datasets/mgtv/ filter=lfs diff=lfs merge=lfs -text
|
40 |
datasets/mgtv/dev.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
108 |
mgtv_train_p1.json filter=lfs diff=lfs merge=lfs -text
|
109 |
mgtv_train_p2.json filter=lfs diff=lfs merge=lfs -text
|
110 |
datasets/mgtv/o1-mini.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
|
111 |
filter=lfs diff=lfs merge=lfs -text
|
112 |
+
datasets/mgtv/Icon
|
113 |
+
filter=lfs diff=lfs merge=lfs -text
|
114 |
datasets/mgtv/gpt-4o-mini.jsonl filter=lfs diff=lfs merge=lfs -text
|
datasets/LogiQA2.0/README.md
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LogiQA2.0
|
2 |
+
Logiqa2.0 dataset - logical reasoning in MRC and NLI tasks
|
3 |
+
|
4 |
+
<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
|
5 |
+
|
6 |
+
> This repository contains the datasets and baseline codes for our paper [LogiQA2.0 - An Improved Dataset for Logic Reasoning in Question Answering and Textual Inference](https://ieeexplore.ieee.org/abstract/document/10174688)
|
7 |
+
|
8 |
+
## How to cite
|
9 |
+
```
|
10 |
+
@ARTICLE{10174688,
|
11 |
+
author={Liu, Hanmeng and Liu, Jian and Cui, Leyang and Teng, Zhiyang and Duan, Nan and Zhou, Ming and Zhang, Yue},
|
12 |
+
journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
|
13 |
+
title={LogiQA 2.0—An Improved Dataset for Logical Reasoning in Natural Language Understanding},
|
14 |
+
year={2023},
|
15 |
+
volume={31},
|
16 |
+
number={},
|
17 |
+
pages={2947-2962},
|
18 |
+
doi={10.1109/TASLP.2023.3293046}}
|
19 |
+
|
20 |
+
```
|
21 |
+
## About
|
22 |
+
This is the version 2 of the LogiQA dataset, first released as a multi-choice reading comprehension dataset by our previous paper [LogiQA: A Challenge Dataset for Machine Reading Comprehension with Logical Reasoning](https://arxiv.org/abs/2007.08124).
|
23 |
+
|
24 |
+
The dataset is collected from the [Chinese Civil Service Entrance Examination](chinagwy.org). The dataset is both in Chinese and English (by translation). you can download the version 1 of the LogiQA dataset from [here](https://github.com/lgw863/logiqa-dataset).
|
25 |
+
|
26 |
+
To construct LogiQA2.0 dataset, we:
|
27 |
+
* collect more newly released exam questions and practice questions. There are about 20 provinces in China that hold the exam annually. The exam materials are publicly available on the Internet after the exams. Besides, practice questions are provided by various sources.
|
28 |
+
* hire professional translators to re-translate the dataset from Chinese to English; verify the labels and annotations with human experts. This program is conducted by [Speechocean](en.speechocean.com), a data annotation service provider. The project is accomplished with the help of Microsoft Research Asia.
|
29 |
+
* introduce a new NLI task to the dataset. The NLI version of the dataset is converted from the MRC version of the dataset, following previous work such as [Transforming Question Answering Datasets into Natural Language Inference Datasets](https://arxiv.org/abs/1809.02922).
|
30 |
+
|
31 |
+
## Datasets
|
32 |
+
### MRC
|
33 |
+
The MRC part of LogiQA2.0 dataset can be found in the `/logiqa/DATA/LOGIQA` folder.
|
34 |
+
|
35 |
+
`train.txt`: train split of the dataset in json lines.
|
36 |
+
|
37 |
+
`dev.txt`: dev split of the dataset in json lines.
|
38 |
+
|
39 |
+
`test.txt`: test split of the dataset in json lines.
|
40 |
+
|
41 |
+
`train_zh.txt`: train split of the Chinese version of dataset in json lines.
|
42 |
+
|
43 |
+
`dev_zh.txt`: dev split of the Chinese version of dataset in json lines.
|
44 |
+
|
45 |
+
`test_zh.txt`: test split of the Chinese version of dataset in json lines.
|
46 |
+
|
47 |
+
`train_fol.zip` is the training data with AMR and FOL annotations. The file is too big so we compressed it.
|
48 |
+
|
49 |
+
`dev_fol.jsonl` is the dev data with AMR and FOL annotations.
|
50 |
+
|
51 |
+
`test_fol.jsonl` is the test data with AMR and FOL annotations.
|
52 |
+
|
53 |
+
|
54 |
+
An example:
|
55 |
+
```
|
56 |
+
{"id": 10471, "answer": 0, "text": "The medieval Arabs had many manuscripts of the ancient Greek. When needed, they translate them into Arabic. Medieval Arab philosophers were very interested in Aristotle's Theory of Poetry, which was obviously not shared by Arab poets, because a poet interested in it must want to read Homer's poems. Aristotle himself often quotes Homer's poems. However, Homer's poems were not translated into Arabic until modern times.", "question": "Which of the following options, if true, strongly supports the above argument?", "options": ["Some medieval Arab translators have manuscripts of Homer poems in ancient Greek.", "Aristotle's Theory of Poetry is often quoted and commented by modern Arab poets.", "In Aristotle's Theory of Poetry, most of the content is related to drama, and medieval Arabs also wrote plays and performed them.", "A series of medieval Arab stories, such as Arab Night, are very similar to some parts of Homer's epic."], "type": {"Sufficient Conditional Reasoning": true, "Necessry Condtional Reasoning": true, "Conjunctive Reasoning": true}}
|
57 |
+
```
|
58 |
+
An example of the Chinese dataset:
|
59 |
+
```
|
60 |
+
{"id": 8018, "answer": 0, "text": "常春藤通常指美国东部的八所大学。常春藤一词一直以来是美国名校的代名词,这八所大学不仅历史悠久,治学严谨,而且教学质量极高。这些学校的毕业生大多成为社会精英,他们中的多数人年薪超过20万美元,有很多政界领袖来自常春藤,更有为数众多的科学家毕业于长春藤。", "question": "根据以上条件,下面那个选项一定为真:", "options": ["A.有些社会精英年薪超过20万美金", "B.有些政界领袖年薪不足20万美元", "C.有些科学家年薪超过20万美元", "D.有些政界领袖是社会精英"]}
|
61 |
+
```
|
62 |
+
|
63 |
+
### NLI
|
64 |
+
The NLI part of LogiQA2.0 dataset can be found in the `/logiqa2nli/DATA/QA2NLI` folder.
|
65 |
+
|
66 |
+
`train.txt`: train split of the dataset in json lines
|
67 |
+
|
68 |
+
`dev.txt`: dev split of the dataset in json lines
|
69 |
+
|
70 |
+
`test.txt`: test split of the dataset in json lines
|
71 |
+
|
72 |
+
An example:
|
73 |
+
```
|
74 |
+
{"label": "not entailed", "major_premise": ["Among the employees of a software company, there are three Cantonese, one Beijinger and three northerners"], "minor_premise": " Four are only responsible for software development and two are only responsible for product sales", "conclusion": "There may be at least 7 people and at most 12 people."}
|
75 |
+
```
|
76 |
+
## Annotations
|
77 |
+
The translation and annotation work is outsourced to [Speechocean](en.speechocean.com), the project fund is provided by Microsoft Research Asia
|
78 |
+
### Translation
|
79 |
+
|
80 |
+
| Final Report | |
|
81 |
+
| --- | --- |
|
82 |
+
| provider | Speechocean |
|
83 |
+
| Project Duration | 2021/10/20-2021/12/3 |
|
84 |
+
| Actual Working Hour | 667 hours |
|
85 |
+
| Cost | 45000 RMB |
|
86 |
+
|
87 |
+
Translation style/method:
|
88 |
+
|
89 |
+
1. Maintain a unified style, and the translated English questions need to inherit the logic of the original questions;
|
90 |
+
|
91 |
+
2. The pronoun in the question need to be unique, the translation needs to be unique and consistent without ambiguity;
|
92 |
+
|
93 |
+
3. The translated English conforms to the form of a proper question, that is, it is a clear question from the perspective of the respondent;
|
94 |
+
|
95 |
+
### Label consistency
|
96 |
+
The label credibility is mannually verified after the translation was done to maintain the truthfulness of the original text. 3 workers run a consistency test on each example, if 2 or more workers give different answer compared to the original answer, the translation would be redone to guareentee the label is correct.
|
97 |
+
|
98 |
+
### Additional annotations
|
99 |
+
Reasoning types of each question is assigned by a total of 5 workers, each of them corresponds to one reasoning type. We give the description of reasoning types (which can be found in our paper) to the workers. The reasoning types of each question is a collection of 5 workers' decision.
|
100 |
+
## Baseline Guidance
|
101 |
+
### Requirements
|
102 |
+
* python 3.6+
|
103 |
+
* pytorch 1.0+
|
104 |
+
* transformers 2.4.1
|
105 |
+
* sklearn
|
106 |
+
* tqdm
|
107 |
+
* tensorboardX
|
108 |
+
|
109 |
+
We recommend to use conda to manage virtual environments:
|
110 |
+
|
111 |
+
```
|
112 |
+
conda env update --name logiqa --file requirements.yml
|
113 |
+
```
|
114 |
+
### Logiqa
|
115 |
+
The folder `logiqa` contains both the code and data to run baseline experiments of LogiQA2.0 MRC.
|
116 |
+
|
117 |
+
To fine-tune the dataset, type following command from the terminal in your :computer:
|
118 |
+
```
|
119 |
+
bash logiqa.sh
|
120 |
+
```
|
121 |
+
### Logiqa2nli
|
122 |
+
The folder `logiqa2nli` contains both the code and data to run baseline experiments of LogiQA2.0 NLI.
|
123 |
+
|
124 |
+
To fine-tune the dataset, type following command from the terminal in your :computer:
|
125 |
+
```
|
126 |
+
bash qa2nli.sh
|
127 |
+
```
|
128 |
+
Note: `./scripts` contains the scripts for running other NLI benchmarks.
|
129 |
+
|
130 |
+
## How to Cite
|
131 |
+
## Acknowledgment
|
132 |
+
We appreciate the suggestions and critical questions from the reviewers.
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/datasource.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c25927ef6b1229b4957b55b65e8bade028a26ba982f62ce0c7d0e9dcf447da29
|
3 |
+
size 315
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/dev.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bbefb563b7ddc02640ccdc314c1315d5727dba48539d0ecdd126fa351e511b09
|
3 |
+
size 1770764
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/dev_fol.jsonl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e9f31a32324e5147b4fe1e963476c012a2c997b46899214c7b5639c7c4ef3c2f
|
3 |
+
size 16840405
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/dev_zh.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a72a23160c9e12e15ea8c13e57af5032a7c37157573ebdd7e7c8e0ad34aef780
|
3 |
+
size 1202597
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/ood_test.jsonl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85aeb366be14f4180af021356cc07758e8a13669a4253753cd86f22f1f46dfff
|
3 |
+
size 668323
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/readme.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LogiQA 2.0 dataset
|
2 |
+
|
3 |
+
`train_fol.zip` is the training data with AMR and FOL annotations. The file is too big so we compressed it.
|
4 |
+
|
5 |
+
`dev_fol.jsonl` is the dev data with AMR and FOL annotations.
|
6 |
+
|
7 |
+
`test_fol.jsonl` is the test data with AMR and FOL annotations.
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/statistics.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
with open('test.txt', 'r') as f:
|
4 |
+
file = f.readlines()
|
5 |
+
n = 1
|
6 |
+
l = 0
|
7 |
+
max = 0
|
8 |
+
for line in file:
|
9 |
+
|
10 |
+
line = json.loads(line)
|
11 |
+
text = line['options']
|
12 |
+
for option in text:
|
13 |
+
s = 0
|
14 |
+
l = l + len(option.split(" "))
|
15 |
+
s = s + len(option.split(" "))
|
16 |
+
n += 1
|
17 |
+
if s >= max:
|
18 |
+
max = s
|
19 |
+
result = l/n
|
20 |
+
print(result)
|
21 |
+
print(max)
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/test.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:71940b37ae0184b677c253a148d57ad4e75d6113447b1563c2ca82483e4e4f8d
|
3 |
+
size 1740565
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/test_fol.jsonl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c0e736cf1bf24560ae188c507e661b6e041d3661b1975dbce561c8464f31f486
|
3 |
+
size 16807118
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/test_zh.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a8db83ccb3ebdc8d5b3886fd0ad9346c7e565722d2d592987b24dd57f251853
|
3 |
+
size 1182510
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/train.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98eb412e8ed53b3d65da5ef75b00b7a0bbdea7970c05ad699291a2a0510922de
|
3 |
+
size 14045351
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/train_fol.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:25ce1665338ee99d3aedc4835abf134ea03263861380bba6173c22fed13fcc24
|
3 |
+
size 23744378
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/train_zh.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d87a15811cda64cb021d43cb9bc1d282424a8dfb8e35e6a7f6d6a0b36b38a54e
|
3 |
+
size 9581270
|
datasets/LogiQA2.0/logiqa/DATA/LOGIQA/word_matching.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
with open('test.txt') as f:
|
3 |
+
file = f.readlines()
|
4 |
+
n = 0
|
5 |
+
l = 0
|
6 |
+
for line in file:
|
7 |
+
d = json.loads(line)
|
8 |
+
label = d['answer']
|
9 |
+
text = d['text']
|
10 |
+
options = d['options']
|
11 |
+
text_vocab = set(text.split(' '))
|
12 |
+
ratio = []
|
13 |
+
for option in options:
|
14 |
+
option_vocab = set(option.split(' '))
|
15 |
+
intersection = text_vocab.intersection(option_vocab)
|
16 |
+
ratio.append(len(intersection)/len(text_vocab))
|
17 |
+
value_prev = 0
|
18 |
+
for value in ratio:
|
19 |
+
if value >= value_prev:
|
20 |
+
value_prev = value
|
21 |
+
index = ratio.index(value_prev)
|
22 |
+
if index == label:
|
23 |
+
l += 1
|
24 |
+
n += 1
|
25 |
+
result = l/n
|
26 |
+
print(result)
|
datasets/LogiQA2.0/logiqa/logiqa.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=./DATA
|
2 |
+
export TASK_NAME=LOGIQA
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python run_mrc.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 256 \
|
13 |
+
--per_gpu_eval_batch_size=4 \
|
14 |
+
--per_gpu_train_batch_size=4 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 1e-5 \
|
17 |
+
--num_train_epochs 10.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa/modeling_bart.py
ADDED
@@ -0,0 +1,1416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch BART model, ported from the fairseq repo."""
|
16 |
+
import math
|
17 |
+
import random
|
18 |
+
import warnings
|
19 |
+
from typing import Dict, List, Optional, Tuple
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from torch import Tensor, nn
|
25 |
+
from torch.nn import CrossEntropyLoss
|
26 |
+
|
27 |
+
from transformers.activations import ACT2FN
|
28 |
+
from transformers.configuration_bart import BartConfig
|
29 |
+
from transformers.file_utils import (
|
30 |
+
add_code_sample_docstrings,
|
31 |
+
add_end_docstrings,
|
32 |
+
add_start_docstrings,
|
33 |
+
add_start_docstrings_to_callable,
|
34 |
+
replace_return_docstrings,
|
35 |
+
)
|
36 |
+
from transformers.modeling_outputs import (
|
37 |
+
BaseModelOutput,
|
38 |
+
BaseModelOutputWithPast,
|
39 |
+
Seq2SeqLMOutput,
|
40 |
+
Seq2SeqModelOutput,
|
41 |
+
Seq2SeqQuestionAnsweringModelOutput,
|
42 |
+
Seq2SeqSequenceClassifierOutput,
|
43 |
+
)
|
44 |
+
from transformers.modeling_utils import PreTrainedModel
|
45 |
+
from transformers.utils import logging
|
46 |
+
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
_CONFIG_FOR_DOC = "BartConfig"
|
51 |
+
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
52 |
+
|
53 |
+
|
54 |
+
BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
55 |
+
"facebook/bart-base",
|
56 |
+
"facebook/bart-large",
|
57 |
+
"facebook/bart-large-mnli",
|
58 |
+
"facebook/bart-large-cnn",
|
59 |
+
"facebook/bart-large-xsum",
|
60 |
+
"facebook/mbart-large-en-ro",
|
61 |
+
]
|
62 |
+
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
|
63 |
+
|
64 |
+
|
65 |
+
BART_START_DOCSTRING = r"""
|
66 |
+
|
67 |
+
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
|
68 |
+
refer to the PyTorch documentation for all matters related to general usage and behavior.
|
69 |
+
|
70 |
+
Parameters:
|
71 |
+
config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
|
72 |
+
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
73 |
+
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
74 |
+
|
75 |
+
"""
|
76 |
+
BART_GENERATION_EXAMPLE = r"""
|
77 |
+
Summarization example::
|
78 |
+
|
79 |
+
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
|
80 |
+
|
81 |
+
# see ``examples/summarization/bart/run_eval.py`` for a longer example
|
82 |
+
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
|
83 |
+
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
|
84 |
+
|
85 |
+
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
86 |
+
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
|
87 |
+
|
88 |
+
# Generate Summary
|
89 |
+
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
|
90 |
+
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
|
91 |
+
|
92 |
+
"""
|
93 |
+
|
94 |
+
BART_INPUTS_DOCSTRING = r"""
|
95 |
+
Args:
|
96 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
97 |
+
Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
|
98 |
+
Padding will be ignored by default should you provide it.
|
99 |
+
Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
|
100 |
+
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
101 |
+
Mask to avoid performing attention on padding token indices in input_ids.
|
102 |
+
Mask values selected in ``[0, 1]``:
|
103 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
104 |
+
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
|
105 |
+
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
|
106 |
+
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
|
107 |
+
Used in the cross-attention of the decoder.
|
108 |
+
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
109 |
+
Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
|
110 |
+
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
111 |
+
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
112 |
+
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
113 |
+
See diagram 1 in the paper for more info on the default strategy
|
114 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
115 |
+
Contains pre-computed key and value hidden-states of the attention blocks.
|
116 |
+
Can be used to speed up decoding.
|
117 |
+
If ``past_key_values`` are used, the user can optionally input only the last
|
118 |
+
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
|
119 |
+
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
120 |
+
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
121 |
+
If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
|
122 |
+
``past_key_values``).
|
123 |
+
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
124 |
+
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
125 |
+
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
126 |
+
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
127 |
+
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
128 |
+
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
|
129 |
+
plain tuple.
|
130 |
+
"""
|
131 |
+
|
132 |
+
|
133 |
+
def invert_mask(attention_mask):
|
134 |
+
"""Turns 1->0, 0->1, False->True, True-> False"""
|
135 |
+
assert attention_mask.dim() == 2
|
136 |
+
return attention_mask.eq(0)
|
137 |
+
|
138 |
+
|
139 |
+
def _prepare_bart_decoder_inputs(
|
140 |
+
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
|
141 |
+
):
|
142 |
+
"""Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
|
143 |
+
none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
|
144 |
+
Note: this is not called during generation
|
145 |
+
"""
|
146 |
+
pad_token_id = config.pad_token_id
|
147 |
+
if decoder_input_ids is None:
|
148 |
+
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
|
149 |
+
bsz, tgt_len = decoder_input_ids.size()
|
150 |
+
if decoder_padding_mask is None:
|
151 |
+
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
|
152 |
+
else:
|
153 |
+
decoder_padding_mask = invert_mask(decoder_padding_mask)
|
154 |
+
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
|
155 |
+
# never mask leading token, even if it is pad
|
156 |
+
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
|
157 |
+
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
|
158 |
+
dtype=causal_mask_dtype, device=decoder_input_ids.device
|
159 |
+
)
|
160 |
+
return decoder_input_ids, decoder_padding_mask, causal_mask
|
161 |
+
|
162 |
+
|
163 |
+
class PretrainedBartModel(PreTrainedModel):
|
164 |
+
config_class = BartConfig
|
165 |
+
base_model_prefix = "model"
|
166 |
+
|
167 |
+
def _init_weights(self, module):
|
168 |
+
std = self.config.init_std
|
169 |
+
if isinstance(module, nn.Linear):
|
170 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
171 |
+
if module.bias is not None:
|
172 |
+
module.bias.data.zero_()
|
173 |
+
elif isinstance(module, SinusoidalPositionalEmbedding):
|
174 |
+
pass
|
175 |
+
elif isinstance(module, nn.Embedding):
|
176 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
177 |
+
if module.padding_idx is not None:
|
178 |
+
module.weight.data[module.padding_idx].zero_()
|
179 |
+
|
180 |
+
@property
|
181 |
+
def dummy_inputs(self):
|
182 |
+
pad_token = self.config.pad_token_id
|
183 |
+
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
184 |
+
dummy_inputs = {
|
185 |
+
"attention_mask": input_ids.ne(pad_token),
|
186 |
+
"input_ids": input_ids,
|
187 |
+
}
|
188 |
+
return dummy_inputs
|
189 |
+
|
190 |
+
|
191 |
+
def _make_linear_from_emb(emb):
|
192 |
+
vocab_size, emb_size = emb.weight.shape
|
193 |
+
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
194 |
+
lin_layer.weight.data = emb.weight.data
|
195 |
+
return lin_layer
|
196 |
+
|
197 |
+
|
198 |
+
# Helper Functions, mostly for making masks
|
199 |
+
def _check_shapes(shape_1, shape2):
|
200 |
+
if shape_1 != shape2:
|
201 |
+
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
|
202 |
+
|
203 |
+
|
204 |
+
def shift_tokens_right(input_ids, pad_token_id):
|
205 |
+
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
|
206 |
+
prev_output_tokens = input_ids.clone()
|
207 |
+
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
|
208 |
+
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
|
209 |
+
prev_output_tokens[:, 1:] = input_ids[:, :-1]
|
210 |
+
return prev_output_tokens
|
211 |
+
|
212 |
+
|
213 |
+
def make_padding_mask(input_ids, padding_idx=1):
|
214 |
+
"""True for pad tokens"""
|
215 |
+
padding_mask = input_ids.eq(padding_idx)
|
216 |
+
if not padding_mask.any():
|
217 |
+
padding_mask = None
|
218 |
+
return padding_mask
|
219 |
+
|
220 |
+
|
221 |
+
# Adapter
|
222 |
+
class Adapter(nn.Module):
|
223 |
+
def __init__(self, config):
|
224 |
+
super(Adapter, self).__init__()
|
225 |
+
self.down_project = nn.Linear(config.hidden_size, config.adapter_size)
|
226 |
+
self.activation = ACT2FN[config.adapter_act] \
|
227 |
+
if isinstance(config.adapter_act, str) else config.adapter_act
|
228 |
+
self.up_project = nn.Linear(config.adapter_size, config.hidden_size)
|
229 |
+
self.init_weights(config)
|
230 |
+
|
231 |
+
def forward(self, hidden_states):
|
232 |
+
down_projected = self.down_project(hidden_states)
|
233 |
+
activated = self.activation(down_projected)
|
234 |
+
up_projected = self.up_project(activated)
|
235 |
+
return hidden_states + up_projected
|
236 |
+
|
237 |
+
def init_weights(self, config):
|
238 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
239 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
240 |
+
self.down_project.weight.data.normal_(mean=0.0, std=config.adapter_initializer_range)
|
241 |
+
self.down_project.bias.data.zero_()
|
242 |
+
self.up_project.weight.data.normal_(mean=0.0, std=config.adapter_initializer_range)
|
243 |
+
self.up_project.bias.data.zero_()
|
244 |
+
|
245 |
+
# Helper Modules
|
246 |
+
|
247 |
+
|
248 |
+
class EncoderLayer(nn.Module):
|
249 |
+
def __init__(self, config: BartConfig):
|
250 |
+
super().__init__()
|
251 |
+
self.embed_dim = config.d_model
|
252 |
+
self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
|
253 |
+
self.normalize_before = config.normalize_before
|
254 |
+
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
255 |
+
self.dropout = config.dropout
|
256 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
257 |
+
self.activation_dropout = config.activation_dropout
|
258 |
+
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
259 |
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
260 |
+
self.final_layer_norm = LayerNorm(self.embed_dim)
|
261 |
+
|
262 |
+
config.adapter_size = 256
|
263 |
+
config.adapter_act = "gelu"
|
264 |
+
config.adapter_initializer_range=0.0002
|
265 |
+
self.adapter = Adapter(config)
|
266 |
+
|
267 |
+
def forward(self, x, encoder_padding_mask, output_attentions=False):
|
268 |
+
"""
|
269 |
+
Args:
|
270 |
+
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
271 |
+
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
272 |
+
`(batch, src_len)` where padding elements are indicated by ``1``.
|
273 |
+
for t_tgt, t_src is excluded (or masked out), =0 means it is
|
274 |
+
included in attention
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
encoded output of shape `(seq_len, batch, embed_dim)`
|
278 |
+
"""
|
279 |
+
residual = x
|
280 |
+
if self.normalize_before:
|
281 |
+
x = self.self_attn_layer_norm(x)
|
282 |
+
x, attn_weights = self.self_attn(
|
283 |
+
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
|
284 |
+
)
|
285 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
286 |
+
|
287 |
+
# add adapter
|
288 |
+
x = self.adapter(x)
|
289 |
+
|
290 |
+
x = residual + x
|
291 |
+
if not self.normalize_before:
|
292 |
+
x = self.self_attn_layer_norm(x)
|
293 |
+
|
294 |
+
residual = x
|
295 |
+
if self.normalize_before:
|
296 |
+
x = self.final_layer_norm(x)
|
297 |
+
x = self.activation_fn(self.fc1(x))
|
298 |
+
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
299 |
+
x = self.fc2(x)
|
300 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
301 |
+
x = residual + x
|
302 |
+
if not self.normalize_before:
|
303 |
+
x = self.final_layer_norm(x)
|
304 |
+
return x, attn_weights
|
305 |
+
|
306 |
+
|
307 |
+
class BartEncoder(nn.Module):
|
308 |
+
"""
|
309 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
|
310 |
+
is a :class:`EncoderLayer`.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
config: BartConfig
|
314 |
+
"""
|
315 |
+
|
316 |
+
def __init__(self, config: BartConfig, embed_tokens):
|
317 |
+
super().__init__()
|
318 |
+
|
319 |
+
self.dropout = config.dropout
|
320 |
+
self.layerdrop = config.encoder_layerdrop
|
321 |
+
|
322 |
+
embed_dim = embed_tokens.embedding_dim
|
323 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
324 |
+
self.padding_idx = embed_tokens.padding_idx
|
325 |
+
self.max_source_positions = config.max_position_embeddings
|
326 |
+
|
327 |
+
self.embed_tokens = embed_tokens
|
328 |
+
if config.static_position_embeddings:
|
329 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
330 |
+
config.max_position_embeddings, embed_dim, self.padding_idx
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
self.embed_positions = LearnedPositionalEmbedding(
|
334 |
+
config.max_position_embeddings,
|
335 |
+
embed_dim,
|
336 |
+
self.padding_idx,
|
337 |
+
config.extra_pos_embeddings,
|
338 |
+
)
|
339 |
+
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
340 |
+
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
341 |
+
# mbart has one extra layer_norm
|
342 |
+
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
|
343 |
+
|
344 |
+
def forward(
|
345 |
+
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
|
346 |
+
):
|
347 |
+
"""
|
348 |
+
Args:
|
349 |
+
input_ids (LongTensor): tokens in the source language of shape
|
350 |
+
`(batch, src_len)`
|
351 |
+
attention_mask (torch.LongTensor): indicating which indices are padding tokens.
|
352 |
+
Returns:
|
353 |
+
BaseModelOutput or Tuple comprised of:
|
354 |
+
- **x** (Tensor): the last encoder layer's output of
|
355 |
+
shape `(src_len, batch, embed_dim)`
|
356 |
+
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate
|
357 |
+
hidden states of shape `(src_len, batch, embed_dim)`.
|
358 |
+
Only populated if *output_hidden_states:* is True.
|
359 |
+
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
|
360 |
+
During training might not be of length n_layers because of layer dropout.
|
361 |
+
"""
|
362 |
+
# check attention mask and invert
|
363 |
+
if attention_mask is not None:
|
364 |
+
attention_mask = invert_mask(attention_mask)
|
365 |
+
|
366 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
367 |
+
embed_pos = self.embed_positions(input_ids)
|
368 |
+
x = inputs_embeds + embed_pos
|
369 |
+
x = self.layernorm_embedding(x)
|
370 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
371 |
+
|
372 |
+
# B x T x C -> T x B x C
|
373 |
+
x = x.transpose(0, 1)
|
374 |
+
|
375 |
+
encoder_states = [] if output_hidden_states else None
|
376 |
+
all_attentions = () if output_attentions else None
|
377 |
+
for encoder_layer in self.layers:
|
378 |
+
if output_hidden_states:
|
379 |
+
encoder_states.append(x)
|
380 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
381 |
+
dropout_probability = random.uniform(0, 1)
|
382 |
+
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
383 |
+
attn = None
|
384 |
+
else:
|
385 |
+
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
|
386 |
+
|
387 |
+
if output_attentions:
|
388 |
+
all_attentions = all_attentions + (attn,)
|
389 |
+
|
390 |
+
if self.layer_norm:
|
391 |
+
x = self.layer_norm(x)
|
392 |
+
if output_hidden_states:
|
393 |
+
encoder_states.append(x)
|
394 |
+
# T x B x C -> B x T x C
|
395 |
+
encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)
|
396 |
+
|
397 |
+
# T x B x C -> B x T x C
|
398 |
+
x = x.transpose(0, 1)
|
399 |
+
|
400 |
+
if not return_dict:
|
401 |
+
return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
|
402 |
+
return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
|
403 |
+
|
404 |
+
|
405 |
+
class DecoderLayer(nn.Module):
|
406 |
+
def __init__(self, config: BartConfig):
|
407 |
+
super().__init__()
|
408 |
+
self.embed_dim = config.d_model
|
409 |
+
|
410 |
+
self.self_attn = Attention(
|
411 |
+
embed_dim=self.embed_dim,
|
412 |
+
num_heads=config.decoder_attention_heads,
|
413 |
+
dropout=config.attention_dropout,
|
414 |
+
)
|
415 |
+
self.dropout = config.dropout
|
416 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
417 |
+
self.activation_dropout = config.activation_dropout
|
418 |
+
self.normalize_before = config.normalize_before
|
419 |
+
|
420 |
+
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
421 |
+
self.encoder_attn = Attention(
|
422 |
+
self.embed_dim,
|
423 |
+
config.decoder_attention_heads,
|
424 |
+
dropout=config.attention_dropout,
|
425 |
+
encoder_decoder_attention=True,
|
426 |
+
)
|
427 |
+
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
|
428 |
+
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
429 |
+
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
430 |
+
self.final_layer_norm = LayerNorm(self.embed_dim)
|
431 |
+
config.adapter_size = 256
|
432 |
+
config.adapter_act = "gelu"
|
433 |
+
config.adapter_initializer_range=0.0002
|
434 |
+
self.adapter = Adapter(config)
|
435 |
+
|
436 |
+
def forward(
|
437 |
+
self,
|
438 |
+
x,
|
439 |
+
encoder_hidden_states,
|
440 |
+
encoder_attn_mask=None,
|
441 |
+
layer_state=None,
|
442 |
+
causal_mask=None,
|
443 |
+
decoder_padding_mask=None,
|
444 |
+
output_attentions=False,
|
445 |
+
):
|
446 |
+
residual = x
|
447 |
+
|
448 |
+
if layer_state is None:
|
449 |
+
layer_state = {}
|
450 |
+
if self.normalize_before:
|
451 |
+
x = self.self_attn_layer_norm(x)
|
452 |
+
# Self Attention
|
453 |
+
|
454 |
+
x, self_attn_weights = self.self_attn(
|
455 |
+
query=x,
|
456 |
+
key=x,
|
457 |
+
layer_state=layer_state, # adds keys to layer state
|
458 |
+
key_padding_mask=decoder_padding_mask,
|
459 |
+
attn_mask=causal_mask,
|
460 |
+
output_attentions=output_attentions,
|
461 |
+
)
|
462 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
463 |
+
|
464 |
+
x = self.adapter(x)
|
465 |
+
|
466 |
+
x = residual + x
|
467 |
+
if not self.normalize_before:
|
468 |
+
x = self.self_attn_layer_norm(x)
|
469 |
+
|
470 |
+
# Cross attention
|
471 |
+
residual = x
|
472 |
+
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
473 |
+
if self.normalize_before:
|
474 |
+
x = self.encoder_attn_layer_norm(x)
|
475 |
+
x, _ = self.encoder_attn(
|
476 |
+
query=x,
|
477 |
+
key=encoder_hidden_states,
|
478 |
+
key_padding_mask=encoder_attn_mask,
|
479 |
+
layer_state=layer_state, # mutates layer state
|
480 |
+
)
|
481 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
482 |
+
x = residual + x
|
483 |
+
if not self.normalize_before:
|
484 |
+
x = self.encoder_attn_layer_norm(x)
|
485 |
+
|
486 |
+
# Fully Connected
|
487 |
+
residual = x
|
488 |
+
if self.normalize_before:
|
489 |
+
x = self.final_layer_norm(x)
|
490 |
+
x = self.activation_fn(self.fc1(x))
|
491 |
+
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
492 |
+
x = self.fc2(x)
|
493 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
494 |
+
x = residual + x
|
495 |
+
if not self.normalize_before:
|
496 |
+
x = self.final_layer_norm(x)
|
497 |
+
return (
|
498 |
+
x,
|
499 |
+
self_attn_weights,
|
500 |
+
layer_state,
|
501 |
+
) # just self_attn weights for now, following t5, layer_state = cache for decoding
|
502 |
+
|
503 |
+
|
504 |
+
class BartDecoder(nn.Module):
|
505 |
+
"""
|
506 |
+
Transformer decoder consisting of *config.decoder_layers* layers. Each layer
|
507 |
+
is a :class:`DecoderLayer`.
|
508 |
+
Args:
|
509 |
+
config: BartConfig
|
510 |
+
embed_tokens (torch.nn.Embedding): output embedding
|
511 |
+
"""
|
512 |
+
|
513 |
+
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
|
514 |
+
super().__init__()
|
515 |
+
self.dropout = config.dropout
|
516 |
+
self.layerdrop = config.decoder_layerdrop
|
517 |
+
self.padding_idx = embed_tokens.padding_idx
|
518 |
+
self.max_target_positions = config.max_position_embeddings
|
519 |
+
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
520 |
+
self.embed_tokens = embed_tokens
|
521 |
+
if config.static_position_embeddings:
|
522 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
523 |
+
config.max_position_embeddings, config.d_model, config.pad_token_id
|
524 |
+
)
|
525 |
+
else:
|
526 |
+
self.embed_positions = LearnedPositionalEmbedding(
|
527 |
+
config.max_position_embeddings,
|
528 |
+
config.d_model,
|
529 |
+
self.padding_idx,
|
530 |
+
config.extra_pos_embeddings,
|
531 |
+
)
|
532 |
+
self.layers = nn.ModuleList(
|
533 |
+
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
534 |
+
) # type: List[DecoderLayer]
|
535 |
+
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
|
536 |
+
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
537 |
+
|
538 |
+
def forward(
|
539 |
+
self,
|
540 |
+
input_ids,
|
541 |
+
encoder_hidden_states,
|
542 |
+
encoder_padding_mask,
|
543 |
+
decoder_padding_mask,
|
544 |
+
decoder_causal_mask,
|
545 |
+
past_key_values=None,
|
546 |
+
use_cache=False,
|
547 |
+
output_attentions=False,
|
548 |
+
output_hidden_states=False,
|
549 |
+
return_dict=False,
|
550 |
+
**unused,
|
551 |
+
):
|
552 |
+
"""
|
553 |
+
Includes several features from "Jointly Learning to Align and
|
554 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
555 |
+
|
556 |
+
Args:
|
557 |
+
input_ids (LongTensor): previous decoder outputs of shape
|
558 |
+
`(batch, tgt_len)`, for teacher forcing
|
559 |
+
encoder_hidden_states: output from the encoder, used for
|
560 |
+
encoder-side attention
|
561 |
+
encoder_padding_mask: for ignoring pad tokens
|
562 |
+
past_key_values (dict or None): dictionary used for storing state during generation
|
563 |
+
|
564 |
+
Returns:
|
565 |
+
BaseModelOutputWithPast or tuple:
|
566 |
+
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
567 |
+
- the cache
|
568 |
+
- hidden states
|
569 |
+
- attentions
|
570 |
+
"""
|
571 |
+
if "decoder_cached_states" in unused:
|
572 |
+
warnings.warn(
|
573 |
+
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
574 |
+
FutureWarning,
|
575 |
+
)
|
576 |
+
past_key_values = unused.pop("decoder_cached_states")
|
577 |
+
if "decoder_past_key_values" in unused:
|
578 |
+
warnings.warn(
|
579 |
+
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
580 |
+
FutureWarning,
|
581 |
+
)
|
582 |
+
past_key_values = unused.pop("decoder_past_key_values")
|
583 |
+
|
584 |
+
# check attention mask and invert
|
585 |
+
if encoder_padding_mask is not None:
|
586 |
+
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
587 |
+
|
588 |
+
# embed positions
|
589 |
+
positions = self.embed_positions(input_ids, use_cache=use_cache)
|
590 |
+
|
591 |
+
if use_cache:
|
592 |
+
input_ids = input_ids[:, -1:]
|
593 |
+
positions = positions[:, -1:] # happens after we embed them
|
594 |
+
# assert input_ids.ne(self.padding_idx).any()
|
595 |
+
|
596 |
+
x = self.embed_tokens(input_ids) * self.embed_scale
|
597 |
+
x += positions
|
598 |
+
x = self.layernorm_embedding(x)
|
599 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
600 |
+
|
601 |
+
# Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
602 |
+
x = x.transpose(0, 1)
|
603 |
+
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
604 |
+
|
605 |
+
# decoder layers
|
606 |
+
all_hidden_states = () if output_hidden_states else None
|
607 |
+
all_self_attns = () if output_attentions else None
|
608 |
+
next_decoder_cache = []
|
609 |
+
for idx, decoder_layer in enumerate(self.layers):
|
610 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
611 |
+
if output_hidden_states:
|
612 |
+
all_hidden_states += (x,)
|
613 |
+
dropout_probability = random.uniform(0, 1)
|
614 |
+
if self.training and (dropout_probability < self.layerdrop):
|
615 |
+
continue
|
616 |
+
|
617 |
+
layer_state = past_key_values[idx] if past_key_values is not None else None
|
618 |
+
|
619 |
+
x, layer_self_attn, layer_past = decoder_layer(
|
620 |
+
x,
|
621 |
+
encoder_hidden_states,
|
622 |
+
encoder_attn_mask=encoder_padding_mask,
|
623 |
+
decoder_padding_mask=decoder_padding_mask,
|
624 |
+
layer_state=layer_state,
|
625 |
+
causal_mask=decoder_causal_mask,
|
626 |
+
output_attentions=output_attentions,
|
627 |
+
)
|
628 |
+
|
629 |
+
if use_cache:
|
630 |
+
next_decoder_cache.append(layer_past.copy())
|
631 |
+
|
632 |
+
if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART)
|
633 |
+
x = self.layer_norm(x)
|
634 |
+
if output_attentions:
|
635 |
+
all_self_attns += (layer_self_attn,)
|
636 |
+
|
637 |
+
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
638 |
+
if output_hidden_states:
|
639 |
+
all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
|
640 |
+
x = x.transpose(0, 1)
|
641 |
+
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
642 |
+
|
643 |
+
next_cache = next_decoder_cache if use_cache else None
|
644 |
+
|
645 |
+
if not return_dict:
|
646 |
+
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
647 |
+
return BaseModelOutputWithPast(
|
648 |
+
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
|
649 |
+
)
|
650 |
+
|
651 |
+
|
652 |
+
def _reorder_buffer(attn_cache, new_order):
|
653 |
+
for k, input_buffer_k in attn_cache.items():
|
654 |
+
if input_buffer_k is not None:
|
655 |
+
attn_cache[k] = input_buffer_k.index_select(0, new_order)
|
656 |
+
return attn_cache
|
657 |
+
|
658 |
+
|
659 |
+
class Attention(nn.Module):
|
660 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
661 |
+
|
662 |
+
def __init__(
|
663 |
+
self,
|
664 |
+
embed_dim,
|
665 |
+
num_heads,
|
666 |
+
dropout=0.0,
|
667 |
+
bias=True,
|
668 |
+
encoder_decoder_attention=False, # otherwise self_attention
|
669 |
+
):
|
670 |
+
super().__init__()
|
671 |
+
self.embed_dim = embed_dim
|
672 |
+
self.num_heads = num_heads
|
673 |
+
self.dropout = dropout
|
674 |
+
self.head_dim = embed_dim // num_heads
|
675 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
676 |
+
self.scaling = self.head_dim ** -0.5
|
677 |
+
|
678 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
679 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
680 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
681 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
682 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
683 |
+
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
|
684 |
+
|
685 |
+
def _shape(self, tensor, seq_len, bsz):
|
686 |
+
return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
687 |
+
|
688 |
+
def forward(
|
689 |
+
self,
|
690 |
+
query,
|
691 |
+
key: Optional[Tensor],
|
692 |
+
key_padding_mask: Optional[Tensor] = None,
|
693 |
+
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
694 |
+
attn_mask: Optional[Tensor] = None,
|
695 |
+
output_attentions=False,
|
696 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
697 |
+
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
698 |
+
static_kv: bool = self.encoder_decoder_attention
|
699 |
+
tgt_len, bsz, embed_dim = query.size()
|
700 |
+
assert embed_dim == self.embed_dim
|
701 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
702 |
+
# get here for encoder decoder cause of static_kv
|
703 |
+
if layer_state is not None: # reuse k,v and encoder_padding_mask
|
704 |
+
saved_state = layer_state.get(self.cache_key, {})
|
705 |
+
if "prev_key" in saved_state and static_kv:
|
706 |
+
# previous time steps are cached - no need to recompute key and value if they are static
|
707 |
+
key = None
|
708 |
+
else:
|
709 |
+
saved_state = None
|
710 |
+
layer_state = {}
|
711 |
+
|
712 |
+
q = self.q_proj(query) * self.scaling
|
713 |
+
if static_kv:
|
714 |
+
if key is None:
|
715 |
+
k = v = None
|
716 |
+
else:
|
717 |
+
k = self.k_proj(key)
|
718 |
+
v = self.v_proj(key)
|
719 |
+
else:
|
720 |
+
k = self.k_proj(query)
|
721 |
+
v = self.v_proj(query)
|
722 |
+
|
723 |
+
q = self._shape(q, tgt_len, bsz)
|
724 |
+
if k is not None:
|
725 |
+
k = self._shape(k, -1, bsz)
|
726 |
+
if v is not None:
|
727 |
+
v = self._shape(v, -1, bsz)
|
728 |
+
|
729 |
+
if saved_state is not None:
|
730 |
+
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
|
731 |
+
|
732 |
+
# Update cache
|
733 |
+
layer_state[self.cache_key] = {
|
734 |
+
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
|
735 |
+
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
|
736 |
+
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
|
737 |
+
}
|
738 |
+
|
739 |
+
assert k is not None
|
740 |
+
src_len = k.size(1)
|
741 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
742 |
+
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
|
743 |
+
|
744 |
+
if attn_mask is not None:
|
745 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
|
746 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
747 |
+
|
748 |
+
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
|
749 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
750 |
+
key_padding_mask = None
|
751 |
+
assert key_padding_mask is None or key_padding_mask.size()[:2] == (
|
752 |
+
bsz,
|
753 |
+
src_len,
|
754 |
+
)
|
755 |
+
|
756 |
+
if key_padding_mask is not None: # don't attend to padding symbols
|
757 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
758 |
+
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
|
759 |
+
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
|
760 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
761 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
762 |
+
attn_probs = F.dropout(
|
763 |
+
attn_weights,
|
764 |
+
p=self.dropout,
|
765 |
+
training=self.training,
|
766 |
+
)
|
767 |
+
|
768 |
+
assert v is not None
|
769 |
+
attn_output = torch.bmm(attn_probs, v)
|
770 |
+
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
|
771 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
772 |
+
attn_output = self.out_proj(attn_output)
|
773 |
+
if output_attentions:
|
774 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
775 |
+
else:
|
776 |
+
attn_weights = None
|
777 |
+
return attn_output, attn_weights
|
778 |
+
|
779 |
+
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
780 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
781 |
+
if "prev_key" in saved_state:
|
782 |
+
_prev_key = saved_state["prev_key"]
|
783 |
+
assert _prev_key is not None
|
784 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
785 |
+
if static_kv:
|
786 |
+
k = prev_key
|
787 |
+
else:
|
788 |
+
assert k is not None
|
789 |
+
k = torch.cat([prev_key, k], dim=1)
|
790 |
+
if "prev_value" in saved_state:
|
791 |
+
_prev_value = saved_state["prev_value"]
|
792 |
+
assert _prev_value is not None
|
793 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
794 |
+
if static_kv:
|
795 |
+
v = prev_value
|
796 |
+
else:
|
797 |
+
assert v is not None
|
798 |
+
v = torch.cat([prev_value, v], dim=1)
|
799 |
+
assert k is not None and v is not None
|
800 |
+
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
|
801 |
+
if prev_key_padding_mask is not None:
|
802 |
+
if static_kv:
|
803 |
+
new_key_padding_mask = prev_key_padding_mask
|
804 |
+
else:
|
805 |
+
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
|
806 |
+
else:
|
807 |
+
new_key_padding_mask = key_padding_mask
|
808 |
+
return k, v, new_key_padding_mask
|
809 |
+
|
810 |
+
|
811 |
+
class BartClassificationHead(nn.Module):
|
812 |
+
"""Head for sentence-level classification tasks."""
|
813 |
+
|
814 |
+
# This can trivially be shared with RobertaClassificationHead
|
815 |
+
|
816 |
+
def __init__(
|
817 |
+
self,
|
818 |
+
input_dim,
|
819 |
+
inner_dim,
|
820 |
+
num_classes,
|
821 |
+
pooler_dropout,
|
822 |
+
):
|
823 |
+
super().__init__()
|
824 |
+
self.dense = nn.Linear(input_dim, inner_dim)
|
825 |
+
self.dropout = nn.Dropout(p=pooler_dropout)
|
826 |
+
self.out_proj = nn.Linear(inner_dim, num_classes)
|
827 |
+
|
828 |
+
def forward(self, x):
|
829 |
+
x = self.dropout(x)
|
830 |
+
x = self.dense(x)
|
831 |
+
x = torch.tanh(x)
|
832 |
+
x = self.dropout(x)
|
833 |
+
x = self.out_proj(x)
|
834 |
+
return x
|
835 |
+
|
836 |
+
|
837 |
+
class LearnedPositionalEmbedding(nn.Embedding):
|
838 |
+
"""
|
839 |
+
This module learns positional embeddings up to a fixed maximum size.
|
840 |
+
Padding ids are ignored by either offsetting based on padding_idx
|
841 |
+
or by setting padding_idx to None and ensuring that the appropriate
|
842 |
+
position ids are passed to the forward function.
|
843 |
+
"""
|
844 |
+
|
845 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
|
846 |
+
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
847 |
+
# and adjust num_embeddings appropriately. Other models dont have this hack
|
848 |
+
self.offset = offset
|
849 |
+
assert padding_idx is not None
|
850 |
+
num_embeddings += offset
|
851 |
+
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
852 |
+
|
853 |
+
def forward(self, input_ids, use_cache=False):
|
854 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
855 |
+
bsz, seq_len = input_ids.shape[:2]
|
856 |
+
if use_cache:
|
857 |
+
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
|
858 |
+
else:
|
859 |
+
# starts at 0, ends at 1-seq_len
|
860 |
+
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
|
861 |
+
return super().forward(positions + self.offset)
|
862 |
+
|
863 |
+
|
864 |
+
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
865 |
+
if torch.cuda.is_available():
|
866 |
+
try:
|
867 |
+
from apex.normalization import FusedLayerNorm
|
868 |
+
|
869 |
+
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
870 |
+
except ImportError:
|
871 |
+
pass
|
872 |
+
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
873 |
+
|
874 |
+
|
875 |
+
def fill_with_neg_inf(t):
|
876 |
+
"""FP16-compatible function that fills a input_ids with -inf."""
|
877 |
+
return t.float().fill_(float("-inf")).type_as(t)
|
878 |
+
|
879 |
+
|
880 |
+
# Public API
|
881 |
+
def _get_shape(t):
|
882 |
+
return getattr(t, "shape", None)
|
883 |
+
|
884 |
+
|
885 |
+
@add_start_docstrings(
|
886 |
+
"The bare BART Model outputting raw hidden-states without any specific head on top.",
|
887 |
+
BART_START_DOCSTRING,
|
888 |
+
)
|
889 |
+
class BartModel(PretrainedBartModel):
|
890 |
+
def __init__(self, config: BartConfig):
|
891 |
+
super().__init__(config)
|
892 |
+
|
893 |
+
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
894 |
+
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
895 |
+
|
896 |
+
self.encoder = BartEncoder(config, self.shared)
|
897 |
+
self.decoder = BartDecoder(config, self.shared)
|
898 |
+
|
899 |
+
self.init_weights()
|
900 |
+
for param in self.parameters():
|
901 |
+
param.requires_grad = False
|
902 |
+
for name, sub_module in self.named_modules():
|
903 |
+
if isinstance(sub_module, (Adapter, torch.nn.LayerNorm,
|
904 |
+
)):
|
905 |
+
for param_name, param in sub_module.named_parameters():
|
906 |
+
param.requires_grad = True
|
907 |
+
|
908 |
+
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
909 |
+
@add_code_sample_docstrings(
|
910 |
+
tokenizer_class=_TOKENIZER_FOR_DOC,
|
911 |
+
checkpoint="facebook/bart-large",
|
912 |
+
output_type=BaseModelOutputWithPast,
|
913 |
+
config_class=_CONFIG_FOR_DOC,
|
914 |
+
)
|
915 |
+
def forward(
|
916 |
+
self,
|
917 |
+
input_ids,
|
918 |
+
attention_mask=None,
|
919 |
+
decoder_input_ids=None,
|
920 |
+
encoder_outputs: Optional[Tuple] = None,
|
921 |
+
decoder_attention_mask=None,
|
922 |
+
past_key_values=None,
|
923 |
+
use_cache=None,
|
924 |
+
output_attentions=None,
|
925 |
+
output_hidden_states=None,
|
926 |
+
return_dict=None,
|
927 |
+
**kwargs,
|
928 |
+
):
|
929 |
+
if "decoder_past_key_values" in kwargs:
|
930 |
+
warnings.warn(
|
931 |
+
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
932 |
+
FutureWarning,
|
933 |
+
)
|
934 |
+
past_key_values = kwargs.pop("decoder_past_key_values")
|
935 |
+
|
936 |
+
if decoder_input_ids is None:
|
937 |
+
use_cache = False
|
938 |
+
|
939 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
940 |
+
output_hidden_states = (
|
941 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
942 |
+
)
|
943 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
944 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
945 |
+
|
946 |
+
# make masks if user doesn't supply
|
947 |
+
if not use_cache:
|
948 |
+
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
|
949 |
+
self.config,
|
950 |
+
input_ids,
|
951 |
+
decoder_input_ids=decoder_input_ids,
|
952 |
+
decoder_padding_mask=decoder_attention_mask,
|
953 |
+
causal_mask_dtype=self.shared.weight.dtype,
|
954 |
+
)
|
955 |
+
else:
|
956 |
+
decoder_padding_mask, causal_mask = None, None
|
957 |
+
|
958 |
+
assert decoder_input_ids is not None
|
959 |
+
|
960 |
+
if encoder_outputs is None:
|
961 |
+
encoder_outputs = self.encoder(
|
962 |
+
input_ids=input_ids,
|
963 |
+
attention_mask=attention_mask,
|
964 |
+
output_attentions=output_attentions,
|
965 |
+
output_hidden_states=output_hidden_states,
|
966 |
+
return_dict=return_dict,
|
967 |
+
)
|
968 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOuput when return_dict=False
|
969 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
970 |
+
encoder_outputs = BaseModelOutput(
|
971 |
+
last_hidden_state=encoder_outputs[0],
|
972 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
973 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
974 |
+
)
|
975 |
+
|
976 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
977 |
+
decoder_outputs = self.decoder(
|
978 |
+
decoder_input_ids,
|
979 |
+
encoder_outputs[0],
|
980 |
+
attention_mask,
|
981 |
+
decoder_padding_mask,
|
982 |
+
decoder_causal_mask=causal_mask,
|
983 |
+
past_key_values=past_key_values,
|
984 |
+
use_cache=use_cache,
|
985 |
+
output_attentions=output_attentions,
|
986 |
+
output_hidden_states=output_hidden_states,
|
987 |
+
return_dict=return_dict,
|
988 |
+
)
|
989 |
+
|
990 |
+
if not return_dict:
|
991 |
+
return decoder_outputs + encoder_outputs
|
992 |
+
|
993 |
+
return Seq2SeqModelOutput(
|
994 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
995 |
+
past_key_values=decoder_outputs.past_key_values,
|
996 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
997 |
+
decoder_attentions=decoder_outputs.attentions,
|
998 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
999 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1000 |
+
encoder_attentions=encoder_outputs.attentions,
|
1001 |
+
)
|
1002 |
+
|
1003 |
+
def get_input_embeddings(self):
|
1004 |
+
return self.shared
|
1005 |
+
|
1006 |
+
def set_input_embeddings(self, value):
|
1007 |
+
self.shared = value
|
1008 |
+
self.encoder.embed_tokens = self.shared
|
1009 |
+
self.decoder.embed_tokens = self.shared
|
1010 |
+
|
1011 |
+
def get_output_embeddings(self):
|
1012 |
+
return _make_linear_from_emb(self.shared) # make it on the fly
|
1013 |
+
|
1014 |
+
|
1015 |
+
@add_start_docstrings(
|
1016 |
+
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
|
1017 |
+
)
|
1018 |
+
class BartAdapterForConditionalGeneration(PretrainedBartModel):
|
1019 |
+
base_model_prefix = "model"
|
1020 |
+
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
|
1021 |
+
|
1022 |
+
def __init__(self, config: BartConfig):
|
1023 |
+
super().__init__(config)
|
1024 |
+
base_model = BartModel(config)
|
1025 |
+
self.model = base_model
|
1026 |
+
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
1027 |
+
|
1028 |
+
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
|
1029 |
+
old_num_tokens = self.model.shared.num_embeddings
|
1030 |
+
new_embeddings = super().resize_token_embeddings(new_num_tokens)
|
1031 |
+
self.model.shared = new_embeddings
|
1032 |
+
self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
|
1033 |
+
return new_embeddings
|
1034 |
+
|
1035 |
+
def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
|
1036 |
+
if new_num_tokens <= old_num_tokens:
|
1037 |
+
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
1038 |
+
else:
|
1039 |
+
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
|
1040 |
+
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
1041 |
+
self.register_buffer("final_logits_bias", new_bias)
|
1042 |
+
|
1043 |
+
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
1044 |
+
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
1045 |
+
@add_end_docstrings(BART_GENERATION_EXAMPLE)
|
1046 |
+
def forward(
|
1047 |
+
self,
|
1048 |
+
input_ids,
|
1049 |
+
attention_mask=None,
|
1050 |
+
encoder_outputs=None,
|
1051 |
+
decoder_input_ids=None,
|
1052 |
+
decoder_attention_mask=None,
|
1053 |
+
past_key_values=None,
|
1054 |
+
labels=None,
|
1055 |
+
use_cache=None,
|
1056 |
+
output_attentions=None,
|
1057 |
+
output_hidden_states=None,
|
1058 |
+
return_dict=None,
|
1059 |
+
**unused,
|
1060 |
+
):
|
1061 |
+
r"""
|
1062 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
1063 |
+
Labels for computing the masked language modeling loss.
|
1064 |
+
Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
|
1065 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
|
1066 |
+
with labels in ``[0, ..., config.vocab_size]``.
|
1067 |
+
|
1068 |
+
Returns:
|
1069 |
+
|
1070 |
+
Conditional generation example::
|
1071 |
+
|
1072 |
+
# Mask filling only works for bart-large
|
1073 |
+
from transformers import BartTokenizer, BartForConditionalGeneration
|
1074 |
+
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
1075 |
+
TXT = "My friends are <mask> but they eat too many carbs."
|
1076 |
+
|
1077 |
+
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
|
1078 |
+
input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
|
1079 |
+
logits = model(input_ids).logits
|
1080 |
+
|
1081 |
+
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
1082 |
+
probs = logits[0, masked_index].softmax(dim=0)
|
1083 |
+
values, predictions = probs.topk(5)
|
1084 |
+
|
1085 |
+
tokenizer.decode(predictions).split()
|
1086 |
+
# ['good', 'great', 'all', 'really', 'very']
|
1087 |
+
"""
|
1088 |
+
if "lm_labels" in unused:
|
1089 |
+
warnings.warn(
|
1090 |
+
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
1091 |
+
FutureWarning,
|
1092 |
+
)
|
1093 |
+
labels = unused.pop("lm_labels")
|
1094 |
+
if "decoder_cached_states" in unused:
|
1095 |
+
warnings.warn(
|
1096 |
+
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
1097 |
+
FutureWarning,
|
1098 |
+
)
|
1099 |
+
past_key_values = unused.pop("decoder_cached_states")
|
1100 |
+
if "decoder_past_key_values" in unused:
|
1101 |
+
warnings.warn(
|
1102 |
+
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
1103 |
+
FutureWarning,
|
1104 |
+
)
|
1105 |
+
past_key_values = unused.pop("decoder_past_key_values")
|
1106 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1107 |
+
|
1108 |
+
if labels is not None:
|
1109 |
+
use_cache = False
|
1110 |
+
if decoder_input_ids is None:
|
1111 |
+
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
1112 |
+
|
1113 |
+
outputs = self.model(
|
1114 |
+
input_ids,
|
1115 |
+
attention_mask=attention_mask,
|
1116 |
+
decoder_input_ids=decoder_input_ids,
|
1117 |
+
encoder_outputs=encoder_outputs,
|
1118 |
+
decoder_attention_mask=decoder_attention_mask,
|
1119 |
+
past_key_values=past_key_values,
|
1120 |
+
use_cache=use_cache,
|
1121 |
+
output_attentions=output_attentions,
|
1122 |
+
output_hidden_states=output_hidden_states,
|
1123 |
+
return_dict=return_dict,
|
1124 |
+
)
|
1125 |
+
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
|
1126 |
+
|
1127 |
+
masked_lm_loss = None
|
1128 |
+
if labels is not None:
|
1129 |
+
loss_fct = CrossEntropyLoss()
|
1130 |
+
# TODO(SS): do we need to ignore pad tokens in labels?
|
1131 |
+
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
1132 |
+
|
1133 |
+
if not return_dict:
|
1134 |
+
output = (lm_logits,) + outputs[1:]
|
1135 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1136 |
+
|
1137 |
+
return Seq2SeqLMOutput(
|
1138 |
+
loss=masked_lm_loss,
|
1139 |
+
logits=lm_logits,
|
1140 |
+
past_key_values=outputs.past_key_values,
|
1141 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1142 |
+
decoder_attentions=outputs.decoder_attentions,
|
1143 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1144 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1145 |
+
encoder_attentions=outputs.encoder_attentions,
|
1146 |
+
)
|
1147 |
+
|
1148 |
+
def prepare_inputs_for_generation(
|
1149 |
+
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
|
1150 |
+
):
|
1151 |
+
return {
|
1152 |
+
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
1153 |
+
"encoder_outputs": encoder_outputs,
|
1154 |
+
"past_key_values": past,
|
1155 |
+
"decoder_input_ids": decoder_input_ids,
|
1156 |
+
"attention_mask": attention_mask,
|
1157 |
+
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
1158 |
+
}
|
1159 |
+
|
1160 |
+
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
1161 |
+
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
1162 |
+
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
1163 |
+
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
1164 |
+
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
1165 |
+
return logits
|
1166 |
+
|
1167 |
+
def _force_token_ids_generation(self, scores, token_id) -> None:
|
1168 |
+
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
1169 |
+
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
|
1170 |
+
|
1171 |
+
@staticmethod
|
1172 |
+
def _reorder_cache(past, beam_idx):
|
1173 |
+
reordered_past = []
|
1174 |
+
for layer_past in past:
|
1175 |
+
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
1176 |
+
layer_past_new = {
|
1177 |
+
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
1178 |
+
}
|
1179 |
+
reordered_past.append(layer_past_new)
|
1180 |
+
return reordered_past
|
1181 |
+
|
1182 |
+
def get_encoder(self):
|
1183 |
+
return self.model.encoder
|
1184 |
+
|
1185 |
+
def get_output_embeddings(self):
|
1186 |
+
return _make_linear_from_emb(self.model.shared) # make it on the fly
|
1187 |
+
|
1188 |
+
|
1189 |
+
@add_start_docstrings(
|
1190 |
+
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
|
1191 |
+
BART_START_DOCSTRING,
|
1192 |
+
)
|
1193 |
+
class BartForSequenceClassification(PretrainedBartModel):
|
1194 |
+
def __init__(self, config: BartConfig, **kwargs):
|
1195 |
+
super().__init__(config, **kwargs)
|
1196 |
+
self.model = BartModel(config)
|
1197 |
+
self.classification_head = BartClassificationHead(
|
1198 |
+
config.d_model,
|
1199 |
+
config.d_model,
|
1200 |
+
config.num_labels,
|
1201 |
+
config.classif_dropout,
|
1202 |
+
)
|
1203 |
+
self.model._init_weights(self.classification_head.dense)
|
1204 |
+
self.model._init_weights(self.classification_head.out_proj)
|
1205 |
+
|
1206 |
+
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
1207 |
+
@add_code_sample_docstrings(
|
1208 |
+
tokenizer_class=_TOKENIZER_FOR_DOC,
|
1209 |
+
checkpoint="facebook/bart-large",
|
1210 |
+
output_type=Seq2SeqSequenceClassifierOutput,
|
1211 |
+
config_class=_CONFIG_FOR_DOC,
|
1212 |
+
)
|
1213 |
+
def forward(
|
1214 |
+
self,
|
1215 |
+
input_ids,
|
1216 |
+
attention_mask=None,
|
1217 |
+
encoder_outputs=None,
|
1218 |
+
decoder_input_ids=None,
|
1219 |
+
decoder_attention_mask=None,
|
1220 |
+
labels=None,
|
1221 |
+
use_cache=None,
|
1222 |
+
output_attentions=None,
|
1223 |
+
output_hidden_states=None,
|
1224 |
+
return_dict=None,
|
1225 |
+
):
|
1226 |
+
r"""
|
1227 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1228 |
+
Labels for computing the sequence classification/regression loss.
|
1229 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
1230 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1231 |
+
"""
|
1232 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1233 |
+
if labels is not None:
|
1234 |
+
use_cache = False
|
1235 |
+
|
1236 |
+
outputs = self.model(
|
1237 |
+
input_ids,
|
1238 |
+
attention_mask=attention_mask,
|
1239 |
+
decoder_input_ids=decoder_input_ids,
|
1240 |
+
decoder_attention_mask=decoder_attention_mask,
|
1241 |
+
encoder_outputs=encoder_outputs,
|
1242 |
+
use_cache=use_cache,
|
1243 |
+
output_attentions=output_attentions,
|
1244 |
+
output_hidden_states=output_hidden_states,
|
1245 |
+
return_dict=return_dict,
|
1246 |
+
)
|
1247 |
+
x = outputs[0] # last hidden state
|
1248 |
+
eos_mask = input_ids.eq(self.config.eos_token_id)
|
1249 |
+
if len(torch.unique(eos_mask.sum(1))) > 1:
|
1250 |
+
raise ValueError("All examples must have the same number of <eos> tokens.")
|
1251 |
+
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
|
1252 |
+
logits = self.classification_head(sentence_representation)
|
1253 |
+
|
1254 |
+
loss = None
|
1255 |
+
if labels is not None:
|
1256 |
+
loss_fct = CrossEntropyLoss()
|
1257 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
1258 |
+
|
1259 |
+
if not return_dict:
|
1260 |
+
output = (logits,) + outputs[1:]
|
1261 |
+
return ((loss,) + output) if loss is not None else output
|
1262 |
+
|
1263 |
+
return Seq2SeqSequenceClassifierOutput(
|
1264 |
+
loss=loss,
|
1265 |
+
logits=logits,
|
1266 |
+
past_key_values=outputs.past_key_values,
|
1267 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1268 |
+
decoder_attentions=outputs.decoder_attentions,
|
1269 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1270 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1271 |
+
encoder_attentions=outputs.encoder_attentions,
|
1272 |
+
)
|
1273 |
+
|
1274 |
+
|
1275 |
+
@add_start_docstrings(
|
1276 |
+
"""BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
|
1277 |
+
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
1278 |
+
BART_START_DOCSTRING,
|
1279 |
+
)
|
1280 |
+
class BartForQuestionAnswering(PretrainedBartModel):
|
1281 |
+
def __init__(self, config):
|
1282 |
+
super().__init__(config)
|
1283 |
+
|
1284 |
+
config.num_labels = 2
|
1285 |
+
self.num_labels = config.num_labels
|
1286 |
+
|
1287 |
+
self.model = BartModel(config)
|
1288 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
1289 |
+
|
1290 |
+
self.model._init_weights(self.qa_outputs)
|
1291 |
+
|
1292 |
+
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
1293 |
+
@add_code_sample_docstrings(
|
1294 |
+
tokenizer_class=_TOKENIZER_FOR_DOC,
|
1295 |
+
checkpoint="facebook/bart-large",
|
1296 |
+
output_type=Seq2SeqQuestionAnsweringModelOutput,
|
1297 |
+
config_class=_CONFIG_FOR_DOC,
|
1298 |
+
)
|
1299 |
+
def forward(
|
1300 |
+
self,
|
1301 |
+
input_ids,
|
1302 |
+
attention_mask=None,
|
1303 |
+
encoder_outputs=None,
|
1304 |
+
decoder_input_ids=None,
|
1305 |
+
decoder_attention_mask=None,
|
1306 |
+
start_positions=None,
|
1307 |
+
end_positions=None,
|
1308 |
+
use_cache=None,
|
1309 |
+
output_attentions=None,
|
1310 |
+
output_hidden_states=None,
|
1311 |
+
return_dict=None,
|
1312 |
+
):
|
1313 |
+
r"""
|
1314 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1315 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1316 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1317 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
1318 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1319 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1320 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1321 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
1322 |
+
"""
|
1323 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1324 |
+
if start_positions is not None and end_positions is not None:
|
1325 |
+
use_cache = False
|
1326 |
+
|
1327 |
+
outputs = self.model(
|
1328 |
+
input_ids,
|
1329 |
+
attention_mask=attention_mask,
|
1330 |
+
decoder_input_ids=decoder_input_ids,
|
1331 |
+
decoder_attention_mask=decoder_attention_mask,
|
1332 |
+
encoder_outputs=encoder_outputs,
|
1333 |
+
use_cache=use_cache,
|
1334 |
+
output_attentions=output_attentions,
|
1335 |
+
output_hidden_states=output_hidden_states,
|
1336 |
+
return_dict=return_dict,
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
sequence_output = outputs[0]
|
1340 |
+
|
1341 |
+
logits = self.qa_outputs(sequence_output)
|
1342 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1343 |
+
start_logits = start_logits.squeeze(-1)
|
1344 |
+
end_logits = end_logits.squeeze(-1)
|
1345 |
+
|
1346 |
+
total_loss = None
|
1347 |
+
if start_positions is not None and end_positions is not None:
|
1348 |
+
# If we are on multi-GPU, split add a dimension
|
1349 |
+
if len(start_positions.size()) > 1:
|
1350 |
+
start_positions = start_positions.squeeze(-1)
|
1351 |
+
if len(end_positions.size()) > 1:
|
1352 |
+
end_positions = end_positions.squeeze(-1)
|
1353 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1354 |
+
ignored_index = start_logits.size(1)
|
1355 |
+
start_positions.clamp_(0, ignored_index)
|
1356 |
+
end_positions.clamp_(0, ignored_index)
|
1357 |
+
|
1358 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1359 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1360 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1361 |
+
total_loss = (start_loss + end_loss) / 2
|
1362 |
+
|
1363 |
+
if not return_dict:
|
1364 |
+
output = (
|
1365 |
+
start_logits,
|
1366 |
+
end_logits,
|
1367 |
+
) + outputs[1:]
|
1368 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1369 |
+
|
1370 |
+
return Seq2SeqQuestionAnsweringModelOutput(
|
1371 |
+
loss=total_loss,
|
1372 |
+
start_logits=start_logits,
|
1373 |
+
end_logits=end_logits,
|
1374 |
+
past_key_values=outputs.past_key_values,
|
1375 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1376 |
+
decoder_attentions=outputs.decoder_attentions,
|
1377 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1378 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1379 |
+
encoder_attentions=outputs.encoder_attentions,
|
1380 |
+
)
|
1381 |
+
|
1382 |
+
|
1383 |
+
class SinusoidalPositionalEmbedding(nn.Embedding):
|
1384 |
+
"""This module produces sinusoidal positional embeddings of any length."""
|
1385 |
+
|
1386 |
+
def __init__(self, num_positions, embedding_dim, padding_idx=None):
|
1387 |
+
super().__init__(num_positions, embedding_dim)
|
1388 |
+
if embedding_dim % 2 != 0:
|
1389 |
+
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
|
1390 |
+
self.weight = self._init_weight(self.weight)
|
1391 |
+
|
1392 |
+
@staticmethod
|
1393 |
+
def _init_weight(out: nn.Parameter):
|
1394 |
+
"""Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
|
1395 |
+
The cos features are in the 2nd half of the vector. [dim // 2:]
|
1396 |
+
"""
|
1397 |
+
n_pos, dim = out.shape
|
1398 |
+
position_enc = np.array(
|
1399 |
+
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
|
1400 |
+
)
|
1401 |
+
out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos
|
1402 |
+
out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
|
1403 |
+
out.detach_()
|
1404 |
+
out.requires_grad = False
|
1405 |
+
return out
|
1406 |
+
|
1407 |
+
@torch.no_grad()
|
1408 |
+
def forward(self, input_ids, use_cache=False):
|
1409 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
1410 |
+
bsz, seq_len = input_ids.shape[:2]
|
1411 |
+
if use_cache:
|
1412 |
+
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
|
1413 |
+
else:
|
1414 |
+
# starts at 0, ends at 1-seq_len
|
1415 |
+
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
|
1416 |
+
return super().forward(positions)
|
datasets/LogiQA2.0/logiqa/multi-choice-prompt.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
import openai
|
4 |
+
import sklearn
|
5 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
6 |
+
openai.api_key = ''
|
7 |
+
|
8 |
+
incontext = "Input\nWrite a multi-choice question for the following article:\nArticle: David knows Mr. Zhang's friend Jack, and Jack knows David's friend Ms. Lin. Everyone of them who knows Jack has a master's degree, and everyone of them who knows Ms. Lin is from Shanghai.\nQuestion: \nWho is from Shanghai and has a master's degree?\nOptions:\nA David\nB Jack\nC Mr Zhang\nD Ms. Lin\nAnswer:\nA\nInput\nWrite a multi-choice question for the following article:\nArticle: Jimmy asked Hank to go to the mall the next day. Hank said, If it doesn't rain tomorrow, I'll go climbing. The next day, there was a drizzle. Jimmy thought that Hank would not go climbing, so he went to pick up Henry to the mall. Nevertheless, Hank went climbing the mountain. When the two met again, Jimmy blamed Hank for not keeping his word.\nQuestion: \nWhich of the following comments is appropriate?\nOptions:\nA This argument between Jimmy and Hank is meaningless\nB Jimmy's reasoning is illogical\nC Two people have different understandings of a drizzle\nD Hank broke his promise and caused the debate\nAnswer:\nB\nInput\nWrite a multi-choice question for the following article:\nArticle: Only if the government reinforce basic education can we improve our nation's education to a new stage. In order to stand out among other nations, we need to have a strong educational enterprise.\nQuestion: \nWhich can be inferred from the statement above?\nOptions:\nA The whole society should be focused on education\nB In order to stand out among nations, we should reinforce basic education\nC In order to improve our education to a new stage, it is necessary to increase the salary of college teachers\nD In order to reinforce basic education, all primary school teachers must have a bachelor degree or above.\nAnswer:\nB\nInput\nWrite a multi-choice question for the following article:\nArticle: Last night, Mark either went to play in the gym or visited his teacher Tony. If Mark drove last night, he didn't go to play in the gym. Mark would go visit his teacher Tony only if he and his teacher had an appointment. In fact, Mark had no appointment with his teacher Tony in advance.\nQuestion: \nWhich is true based on the above statement?\nOptions:\nA Mark went to the gym with his teacher Tony last night\nB Mark visited his teacher Tony last night\nC Mark didn't drive last night\nD Mark didn't go to the gym last night.\nAnswer:\nC\nInput\nWrite a multi-choice question for the following article:\nArticle: The coach of a national football team found that the best cooperative arrangement of the players U, V, W, X, Y, and Z during the training are: (1) V and X cannot be on the field at the same time, and neither can be off the field the same time. (2) V is not on the field only if U is not on the field. (3) If W is on the field, then X is on the field. (4) If Y and Z are on the field, then W must be on the field. This arrangement can yield the best performance.\nQuestion: \nIf U and Z are both on the field, for best performance, which of the following arrangement is appropriate?\nOptions:\nA X is on the eld and Y is not on the field\nB V is on the eld and Y is not on the field\nC V and W are both on the field\nD V and Y are not on the field\nAnswer:\nB\n"
|
9 |
+
label_map = {0: "A", 1: "B", 2: "C", 3: "D"}
|
10 |
+
|
11 |
+
def gpt3_api(prompt):
|
12 |
+
response = openai.Completion.create(
|
13 |
+
model="text-davinci-002",
|
14 |
+
prompt=incontext + prompt,
|
15 |
+
temperature=0,
|
16 |
+
max_tokens=60,
|
17 |
+
top_p=1.0,
|
18 |
+
frequency_penalty=0.0,
|
19 |
+
presence_penalty=0.0
|
20 |
+
)
|
21 |
+
return response
|
22 |
+
|
23 |
+
with open('test.json') as f:
|
24 |
+
y_true = []
|
25 |
+
y_pred = []
|
26 |
+
lines = f.readlines()
|
27 |
+
for i, line in enumerate(lines):
|
28 |
+
line_dict = json.loads(line)
|
29 |
+
article = line_dict['text']
|
30 |
+
answer = line_dict['answer']
|
31 |
+
label = label_map[answer]
|
32 |
+
y_true.append(label)
|
33 |
+
question = line_dict['question']
|
34 |
+
options_old = line_dict['options']
|
35 |
+
options = ""
|
36 |
+
for j, option in enumerate(options_old):
|
37 |
+
options += label_map[j] + " " + option + "\n"
|
38 |
+
prompt_input = "Write a multi-choice question for the following article:\nArticle: " + article + "\nQuestion: " + question + "\nOptions: " + options + "\nAnswer: \n"
|
39 |
+
prompt = prompt_input
|
40 |
+
output = gpt3_api(prompt)
|
41 |
+
time.sleep(5)
|
42 |
+
pred = output.choices[0].text
|
43 |
+
y_pred.append(pred)
|
44 |
+
|
45 |
+
print(y_true)
|
46 |
+
print(y_pred)
|
47 |
+
|
48 |
+
f_score = f1_score(y_true, y_pred, average='binary')
|
49 |
+
p_score = precision_score(y_true, y_pred, average='binary')
|
50 |
+
r_score = recall_score(y_true, y_pred, average='binary')
|
51 |
+
acc = accuracy_score(y_true, y_pred)
|
52 |
+
|
53 |
+
print(f_score)
|
54 |
+
print(p_score)
|
55 |
+
print(r_score)
|
56 |
+
print(acc)
|
datasets/LogiQA2.0/logiqa/run_mrc.py
ADDED
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
This Script is Modified for Multi-Choice Reading Comprehension Fine-tuning.
|
18 |
+
All the datasets can be downloaded from this repo.
|
19 |
+
"""
|
20 |
+
|
21 |
+
from __future__ import absolute_import, division, print_function
|
22 |
+
|
23 |
+
import argparse
|
24 |
+
import glob
|
25 |
+
import logging
|
26 |
+
import os
|
27 |
+
import random
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
32 |
+
TensorDataset)
|
33 |
+
from torch.utils.data.distributed import DistributedSampler
|
34 |
+
|
35 |
+
try:
|
36 |
+
from torch.utils.tensorboard import SummaryWriter
|
37 |
+
except:
|
38 |
+
from tensorboardX import SummaryWriter
|
39 |
+
|
40 |
+
from tqdm import tqdm, trange
|
41 |
+
|
42 |
+
from transformers import (WEIGHTS_NAME, BertConfig,
|
43 |
+
BertForMultipleChoice, BertTokenizer,
|
44 |
+
RobertaConfig,
|
45 |
+
RobertaForMultipleChoice,
|
46 |
+
RobertaTokenizer,
|
47 |
+
XLNetConfig,
|
48 |
+
XLNetForMultipleChoice,
|
49 |
+
XLNetTokenizer,
|
50 |
+
)
|
51 |
+
|
52 |
+
from transformers import AdamW, get_linear_schedule_with_warmup
|
53 |
+
|
54 |
+
from utils_mrc import compute_metrics
|
55 |
+
from utils_mrc import output_modes
|
56 |
+
from utils_mrc import processors
|
57 |
+
from utils_mrc import convert_examples_to_features
|
58 |
+
|
59 |
+
|
60 |
+
logger = logging.getLogger(__name__)
|
61 |
+
|
62 |
+
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig,
|
63 |
+
RobertaConfig,)), ())
|
64 |
+
|
65 |
+
MODEL_CLASSES = {
|
66 |
+
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
|
67 |
+
'xlnet': (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
|
68 |
+
'roberta': (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),
|
69 |
+
}
|
70 |
+
|
71 |
+
def select_field(features, field):
|
72 |
+
return [[choice[field] for choice in feature.choices_features] for feature in features]
|
73 |
+
|
74 |
+
def set_seed(args):
|
75 |
+
random.seed(args.seed)
|
76 |
+
np.random.seed(args.seed)
|
77 |
+
torch.manual_seed(args.seed)
|
78 |
+
if args.n_gpu > 0:
|
79 |
+
torch.cuda.manual_seed_all(args.seed)
|
80 |
+
|
81 |
+
|
82 |
+
def train(args, train_dataset, model, tokenizer):
|
83 |
+
""" Train the model """
|
84 |
+
if args.local_rank in [-1, 0]:
|
85 |
+
tb_writer = SummaryWriter()
|
86 |
+
|
87 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
88 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
89 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
90 |
+
|
91 |
+
if args.max_steps > 0:
|
92 |
+
t_total = args.max_steps
|
93 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
94 |
+
else:
|
95 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
96 |
+
|
97 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
98 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
99 |
+
optimizer_grouped_parameters = [
|
100 |
+
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
101 |
+
'weight_decay': args.weight_decay},
|
102 |
+
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
103 |
+
]
|
104 |
+
|
105 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
106 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
|
107 |
+
num_training_steps=t_total)
|
108 |
+
if args.fp16:
|
109 |
+
try:
|
110 |
+
from apex import amp
|
111 |
+
except ImportError:
|
112 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
113 |
+
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
114 |
+
|
115 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
116 |
+
if args.n_gpu > 1:
|
117 |
+
model = torch.nn.DataParallel(model)
|
118 |
+
|
119 |
+
# Distributed training (should be after apex fp16 initialization)
|
120 |
+
if args.local_rank != -1:
|
121 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
122 |
+
output_device=args.local_rank,
|
123 |
+
find_unused_parameters=True)
|
124 |
+
|
125 |
+
# Train!
|
126 |
+
logger.info("***** Running training *****")
|
127 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
128 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
129 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
130 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
131 |
+
args.train_batch_size * args.gradient_accumulation_steps * (
|
132 |
+
torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
133 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
134 |
+
logger.info(" Total optimization steps = %d", t_total)
|
135 |
+
|
136 |
+
global_step = 0
|
137 |
+
tr_loss, logging_loss = 0.0, 0.0
|
138 |
+
model.zero_grad()
|
139 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
140 |
+
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
141 |
+
for _ in train_iterator:
|
142 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
143 |
+
for step, batch in enumerate(epoch_iterator):
|
144 |
+
model.train()
|
145 |
+
batch = tuple(t.to(args.device) for t in batch)
|
146 |
+
inputs = {'input_ids': batch[0],
|
147 |
+
'attention_mask': batch[1],
|
148 |
+
'labels': batch[3]}
|
149 |
+
if args.model_type != 'distilbert':
|
150 |
+
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert',
|
151 |
+
'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
152 |
+
outputs = model(**inputs)
|
153 |
+
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
154 |
+
|
155 |
+
if args.n_gpu > 1:
|
156 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
157 |
+
if args.gradient_accumulation_steps > 1:
|
158 |
+
loss = loss / args.gradient_accumulation_steps
|
159 |
+
|
160 |
+
if args.fp16:
|
161 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
162 |
+
scaled_loss.backward()
|
163 |
+
else:
|
164 |
+
loss.backward()
|
165 |
+
|
166 |
+
tr_loss += loss.item()
|
167 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
168 |
+
if args.fp16:
|
169 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
170 |
+
else:
|
171 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
172 |
+
|
173 |
+
optimizer.step()
|
174 |
+
scheduler.step() # Update learning rate schedule
|
175 |
+
model.zero_grad()
|
176 |
+
global_step += 1
|
177 |
+
|
178 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
179 |
+
# Log metrics
|
180 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
181 |
+
results = evaluate(args, model, tokenizer)
|
182 |
+
for key, value in results.items():
|
183 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
184 |
+
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
185 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
|
186 |
+
logging_loss = tr_loss
|
187 |
+
|
188 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
189 |
+
# Save model checkpoint
|
190 |
+
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
191 |
+
if not os.path.exists(output_dir):
|
192 |
+
os.makedirs(output_dir)
|
193 |
+
model_to_save = model.module if hasattr(model,
|
194 |
+
'module') else model # Take care of distributed/parallel training
|
195 |
+
model_to_save.save_pretrained(output_dir)
|
196 |
+
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
197 |
+
logger.info("Saving model checkpoint to %s", output_dir)
|
198 |
+
|
199 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
200 |
+
epoch_iterator.close()
|
201 |
+
break
|
202 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
203 |
+
train_iterator.close()
|
204 |
+
break
|
205 |
+
|
206 |
+
if args.local_rank in [-1, 0]:
|
207 |
+
tb_writer.close()
|
208 |
+
|
209 |
+
return global_step, tr_loss / global_step
|
210 |
+
|
211 |
+
|
212 |
+
def evaluate(args, model, tokenizer, prefix=""):
|
213 |
+
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
214 |
+
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
215 |
+
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
|
216 |
+
|
217 |
+
results = {}
|
218 |
+
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
219 |
+
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
220 |
+
|
221 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
222 |
+
os.makedirs(eval_output_dir)
|
223 |
+
|
224 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
225 |
+
# Note that DistributedSampler samples randomly
|
226 |
+
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
227 |
+
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
228 |
+
|
229 |
+
# multi-gpu eval
|
230 |
+
if args.n_gpu > 1:
|
231 |
+
model = torch.nn.DataParallel(model)
|
232 |
+
|
233 |
+
# Eval!
|
234 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
235 |
+
logger.info(" Num examples = %d", len(eval_dataset))
|
236 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
237 |
+
eval_loss = 0.0
|
238 |
+
nb_eval_steps = 0
|
239 |
+
preds = None
|
240 |
+
out_label_ids = None
|
241 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
242 |
+
model.eval()
|
243 |
+
batch = tuple(t.to(args.device) for t in batch)
|
244 |
+
|
245 |
+
with torch.no_grad():
|
246 |
+
inputs = {'input_ids': batch[0],
|
247 |
+
'attention_mask': batch[1],
|
248 |
+
'labels': batch[3]}
|
249 |
+
if args.model_type != 'distilbert':
|
250 |
+
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert',
|
251 |
+
'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
252 |
+
outputs = model(**inputs)
|
253 |
+
tmp_eval_loss, logits = outputs[:2]
|
254 |
+
|
255 |
+
eval_loss += tmp_eval_loss.mean().item()
|
256 |
+
nb_eval_steps += 1
|
257 |
+
if preds is None:
|
258 |
+
preds = logits.detach().cpu().numpy()
|
259 |
+
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
260 |
+
else:
|
261 |
+
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
262 |
+
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
263 |
+
|
264 |
+
eval_loss = eval_loss / nb_eval_steps
|
265 |
+
if args.output_mode == "classification":
|
266 |
+
preds = np.argmax(preds, axis=1)
|
267 |
+
elif args.output_mode == "regression":
|
268 |
+
preds = np.squeeze(preds)
|
269 |
+
result = {"eval": compute_metrics(eval_task, preds, out_label_ids), "loss": eval_loss}
|
270 |
+
results.update(result)
|
271 |
+
|
272 |
+
output_pred_file = os.path.join(eval_output_dir, prefix, "pred_results.txt")
|
273 |
+
with open(output_pred_file, "a") as writer:
|
274 |
+
for pred in preds:
|
275 |
+
writer.write(str(pred)+"\n")
|
276 |
+
|
277 |
+
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
278 |
+
with open(output_eval_file, "w") as writer:
|
279 |
+
logger.info("***** Eval results {} *****".format(prefix))
|
280 |
+
for key in sorted(result.keys()):
|
281 |
+
logger.info(" %s = %s", key, str(result[key]))
|
282 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
283 |
+
|
284 |
+
return results
|
285 |
+
|
286 |
+
|
287 |
+
def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
288 |
+
if args.local_rank not in [-1, 0]:
|
289 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
290 |
+
|
291 |
+
processor = processors[task]()
|
292 |
+
# Load data features from cache or dataset file
|
293 |
+
if evaluate:
|
294 |
+
cached_mode = "dev"
|
295 |
+
elif test:
|
296 |
+
cached_mode = "test"
|
297 |
+
else:
|
298 |
+
cached_mode = "train"
|
299 |
+
assert not (evaluate and test)
|
300 |
+
cached_features_file = os.path.join(
|
301 |
+
args.data_dir,
|
302 |
+
"cached_{}_{}_{}_{}".format(
|
303 |
+
cached_mode,
|
304 |
+
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
305 |
+
str(args.max_seq_length),
|
306 |
+
str(task),
|
307 |
+
),
|
308 |
+
)
|
309 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
310 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
311 |
+
features = torch.load(cached_features_file)
|
312 |
+
else:
|
313 |
+
logger.info("Creating features from dataset file at %s", args.data_dir)
|
314 |
+
label_list = processor.get_labels()
|
315 |
+
if evaluate:
|
316 |
+
examples = processor.get_dev_examples(args.data_dir)
|
317 |
+
elif test:
|
318 |
+
examples = processor.get_test_examples(args.data_dir)
|
319 |
+
else:
|
320 |
+
examples = processor.get_train_examples(args.data_dir)
|
321 |
+
logger.info("Training number: %s", str(len(examples)))
|
322 |
+
features = convert_examples_to_features(
|
323 |
+
examples,
|
324 |
+
label_list,
|
325 |
+
args.max_seq_length,
|
326 |
+
tokenizer,
|
327 |
+
pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet
|
328 |
+
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
329 |
+
)
|
330 |
+
if args.local_rank in [-1, 0]:
|
331 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
332 |
+
torch.save(features, cached_features_file)
|
333 |
+
|
334 |
+
if args.local_rank == 0:
|
335 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
336 |
+
|
337 |
+
# Convert to Tensors and build dataset
|
338 |
+
all_input_ids = torch.tensor(select_field(features, "input_ids"), dtype=torch.long)
|
339 |
+
all_input_mask = torch.tensor(select_field(features, "input_mask"), dtype=torch.long)
|
340 |
+
all_segment_ids = torch.tensor(select_field(features, "segment_ids"), dtype=torch.long)
|
341 |
+
all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
|
342 |
+
|
343 |
+
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
344 |
+
return dataset
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
def main():
|
349 |
+
parser = argparse.ArgumentParser()
|
350 |
+
|
351 |
+
## Required parameters
|
352 |
+
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
353 |
+
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
354 |
+
parser.add_argument("--model_type", default=None, type=str, required=True,
|
355 |
+
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
356 |
+
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
357 |
+
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(
|
358 |
+
ALL_MODELS))
|
359 |
+
parser.add_argument("--task_name", default=None, type=str, required=True,
|
360 |
+
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
361 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
362 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
363 |
+
|
364 |
+
## Other parameters
|
365 |
+
parser.add_argument("--config_name", default="", type=str,
|
366 |
+
help="Pretrained config name or path if not the same as model_name")
|
367 |
+
parser.add_argument("--tokenizer_name", default="", type=str,
|
368 |
+
help="Pretrained tokenizer name or path if not the same as model_name")
|
369 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
370 |
+
help="Where do you want to store the pre-trained models downloaded from s3")
|
371 |
+
parser.add_argument("--max_seq_length", default=128, type=int,
|
372 |
+
help="The maximum total input sequence length after tokenization. Sequences longer "
|
373 |
+
"than this will be truncated, sequences shorter will be padded.")
|
374 |
+
parser.add_argument("--do_train", action='store_true',
|
375 |
+
help="Whether to run training.")
|
376 |
+
parser.add_argument("--do_eval", action='store_true',
|
377 |
+
help="Whether to run eval on the dev set.")
|
378 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
379 |
+
help="Rul evaluation during training at each logging step.")
|
380 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
381 |
+
help="Set this flag if you are using an uncased model.")
|
382 |
+
|
383 |
+
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
384 |
+
help="Batch size per GPU/CPU for training.")
|
385 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
386 |
+
help="Batch size per GPU/CPU for evaluation.")
|
387 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
388 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
389 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
390 |
+
help="The initial learning rate for Adam.")
|
391 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
392 |
+
help="Weight deay if we apply some.")
|
393 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
394 |
+
help="Epsilon for Adam optimizer.")
|
395 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
396 |
+
help="Max gradient norm.")
|
397 |
+
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
398 |
+
help="Total number of training epochs to perform.")
|
399 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
400 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
401 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
402 |
+
help="Linear warmup over warmup_steps.")
|
403 |
+
|
404 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
405 |
+
help="Log every X updates steps.")
|
406 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
407 |
+
help="Save checkpoint every X updates steps.")
|
408 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
409 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
410 |
+
parser.add_argument("--no_cuda", action='store_true',
|
411 |
+
help="Avoid using CUDA when available")
|
412 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
413 |
+
help="Overwrite the content of the output directory")
|
414 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
415 |
+
help="Overwrite the cached training and evaluation sets")
|
416 |
+
parser.add_argument('--seed', type=int, default=42,
|
417 |
+
help="random seed for initialization")
|
418 |
+
|
419 |
+
parser.add_argument('--fp16', action='store_true',
|
420 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
421 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
422 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
423 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
424 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
425 |
+
help="For distributed training: local_rank")
|
426 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
427 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
428 |
+
args = parser.parse_args()
|
429 |
+
|
430 |
+
if os.path.exists(args.output_dir) and os.listdir(
|
431 |
+
args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
432 |
+
raise ValueError(
|
433 |
+
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
434 |
+
args.output_dir))
|
435 |
+
|
436 |
+
# Setup distant debugging if needed
|
437 |
+
if args.server_ip and args.server_port:
|
438 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
439 |
+
import ptvsd
|
440 |
+
print("Waiting for debugger attach")
|
441 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
442 |
+
ptvsd.wait_for_attach()
|
443 |
+
|
444 |
+
# Setup CUDA, GPU & distributed training
|
445 |
+
if args.local_rank == -1 or args.no_cuda:
|
446 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
447 |
+
args.n_gpu = torch.cuda.device_count()
|
448 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
449 |
+
torch.cuda.set_device(args.local_rank)
|
450 |
+
device = torch.device("cuda", args.local_rank)
|
451 |
+
torch.distributed.init_process_group(backend='nccl')
|
452 |
+
args.n_gpu = 1
|
453 |
+
args.device = device
|
454 |
+
|
455 |
+
# Setup logging
|
456 |
+
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
457 |
+
datefmt='%m/%d/%Y %H:%M:%S',
|
458 |
+
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
459 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
460 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
461 |
+
|
462 |
+
# Set seed
|
463 |
+
set_seed(args)
|
464 |
+
|
465 |
+
# Prepare GLUE task
|
466 |
+
args.task_name = args.task_name.lower()
|
467 |
+
print(processors)
|
468 |
+
if args.task_name not in processors:
|
469 |
+
raise ValueError("Task not found: %s" % (args.task_name))
|
470 |
+
processor = processors[args.task_name]()
|
471 |
+
args.output_mode = output_modes[args.task_name]
|
472 |
+
label_list = processor.get_labels()
|
473 |
+
num_labels = len(label_list)
|
474 |
+
|
475 |
+
# Load pretrained model and tokenizer
|
476 |
+
if args.local_rank not in [-1, 0]:
|
477 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
478 |
+
|
479 |
+
args.model_type = args.model_type.lower()
|
480 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
481 |
+
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
482 |
+
num_labels=num_labels,
|
483 |
+
finetuning_task=args.task_name,
|
484 |
+
cache_dir=args.cache_dir if args.cache_dir else None)
|
485 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
486 |
+
do_lower_case=args.do_lower_case,
|
487 |
+
cache_dir=args.cache_dir if args.cache_dir else None)
|
488 |
+
model = model_class.from_pretrained(args.model_name_or_path,
|
489 |
+
from_tf=bool('.ckpt' in args.model_name_or_path),
|
490 |
+
config=config,
|
491 |
+
cache_dir=args.cache_dir if args.cache_dir else None)
|
492 |
+
|
493 |
+
if args.local_rank == 0:
|
494 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
495 |
+
|
496 |
+
model.to(args.device)
|
497 |
+
|
498 |
+
logger.info("Training/evaluation parameters %s", args)
|
499 |
+
|
500 |
+
# Training
|
501 |
+
if args.do_train:
|
502 |
+
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
503 |
+
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
504 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
505 |
+
|
506 |
+
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
507 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
508 |
+
# Create output directory if needed
|
509 |
+
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
510 |
+
os.makedirs(args.output_dir)
|
511 |
+
|
512 |
+
logger.info("Saving model checkpoint to %s", args.output_dir)
|
513 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
514 |
+
# They can then be reloaded using `from_pretrained()`
|
515 |
+
model_to_save = model.module if hasattr(model,
|
516 |
+
'module') else model # Take care of distributed/parallel training
|
517 |
+
model_to_save.save_pretrained(args.output_dir)
|
518 |
+
tokenizer.save_pretrained(args.output_dir)
|
519 |
+
|
520 |
+
# Good practice: save your training arguments together with the trained model
|
521 |
+
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
522 |
+
|
523 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
524 |
+
model = model_class.from_pretrained(args.output_dir)
|
525 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
526 |
+
model.to(args.device)
|
527 |
+
|
528 |
+
# Evaluation
|
529 |
+
results = {}
|
530 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
531 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
532 |
+
checkpoints = [args.output_dir]
|
533 |
+
if args.eval_all_checkpoints:
|
534 |
+
checkpoints = list(
|
535 |
+
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
536 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
537 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
538 |
+
for checkpoint in checkpoints:
|
539 |
+
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
540 |
+
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
541 |
+
|
542 |
+
model = model_class.from_pretrained(checkpoint)
|
543 |
+
model.to(args.device)
|
544 |
+
result = evaluate(args, model, tokenizer, prefix=prefix)
|
545 |
+
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
546 |
+
results.update(result)
|
547 |
+
|
548 |
+
return results
|
549 |
+
|
550 |
+
|
551 |
+
if __name__ == "__main__":
|
552 |
+
main()
|
datasets/LogiQA2.0/logiqa/utils_mrc.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
This Script is Modified for Natural Language Inference Datasets fine-tuning.
|
18 |
+
All the datasets can be downloaded from this repo.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import logging
|
22 |
+
import os
|
23 |
+
import sys
|
24 |
+
import json
|
25 |
+
from typing import List
|
26 |
+
|
27 |
+
import tqdm
|
28 |
+
|
29 |
+
from transformers import PreTrainedTokenizer
|
30 |
+
from transformers.file_utils import is_tf_available
|
31 |
+
|
32 |
+
if is_tf_available():
|
33 |
+
import tensorflow as tf
|
34 |
+
|
35 |
+
logger = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
class InputExample(object):
|
39 |
+
"""A single training/test example for multiple choice"""
|
40 |
+
|
41 |
+
def __init__(self, example_id, question, contexts, endings, label=None):
|
42 |
+
"""Constructs a InputExample.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
example_id: Unique id for the example.
|
46 |
+
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
|
47 |
+
question: string. The untokenized text of the second sequence (question).
|
48 |
+
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
|
49 |
+
label: (Optional) string. The label of the example. This should be
|
50 |
+
specified for train and dev examples, but not for test examples.
|
51 |
+
"""
|
52 |
+
self.example_id = example_id
|
53 |
+
self.question = question
|
54 |
+
self.contexts = contexts
|
55 |
+
self.endings = endings
|
56 |
+
self.label = label
|
57 |
+
|
58 |
+
|
59 |
+
class InputFeatures(object):
|
60 |
+
def __init__(self, example_id, choices_features, label):
|
61 |
+
self.example_id = example_id
|
62 |
+
self.choices_features = [
|
63 |
+
{"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
|
64 |
+
for input_ids, input_mask, segment_ids in choices_features
|
65 |
+
]
|
66 |
+
self.label = label
|
67 |
+
|
68 |
+
class DataProcessor(object):
|
69 |
+
"""Base class for data converters for multiple choice data sets."""
|
70 |
+
|
71 |
+
def get_train_examples(self, data_dir):
|
72 |
+
"""Gets a collection of `InputExample`s for the train set."""
|
73 |
+
raise NotImplementedError()
|
74 |
+
|
75 |
+
def get_dev_examples(self, data_dir):
|
76 |
+
"""Gets a collection of `InputExample`s for the dev set."""
|
77 |
+
raise NotImplementedError()
|
78 |
+
|
79 |
+
def get_test_examples(self, data_dir):
|
80 |
+
"""Gets a collection of `InputExample`s for the test set."""
|
81 |
+
raise NotImplementedError()
|
82 |
+
|
83 |
+
def get_labels(self):
|
84 |
+
"""Gets the list of labels for this data set."""
|
85 |
+
raise NotImplementedError()
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
def convert_examples_to_features(
|
90 |
+
examples: List[InputExample],
|
91 |
+
label_list: List[str],
|
92 |
+
max_length: int,
|
93 |
+
tokenizer: PreTrainedTokenizer,
|
94 |
+
pad_token_segment_id=0,
|
95 |
+
pad_on_left=False,
|
96 |
+
pad_token=0,
|
97 |
+
mask_padding_with_zero=True,
|
98 |
+
) -> List[InputFeatures]:
|
99 |
+
"""
|
100 |
+
Loads a data file into a list of `InputFeatures`
|
101 |
+
"""
|
102 |
+
|
103 |
+
label_map = {label: i for i, label in enumerate(label_list)}
|
104 |
+
|
105 |
+
features = []
|
106 |
+
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
107 |
+
if ex_index % 10000 == 0:
|
108 |
+
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
109 |
+
choices_features = []
|
110 |
+
for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
|
111 |
+
text_a = context
|
112 |
+
if example.question.find("_") != -1:
|
113 |
+
# this is for cloze question
|
114 |
+
text_b = example.question.replace("_", ending)
|
115 |
+
else:
|
116 |
+
text_b = example.question + " " + ending
|
117 |
+
|
118 |
+
inputs = tokenizer.encode_plus(text_a, text_b, add_special_tokens=True, max_length=max_length, return_token_type_ids=True)
|
119 |
+
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
|
120 |
+
logger.info(
|
121 |
+
"Attention! you are cropping tokens (swag task is ok). "
|
122 |
+
"If you are training ARC and RACE and you are poping question + options,"
|
123 |
+
"you need to try to use a bigger max seq length!"
|
124 |
+
)
|
125 |
+
|
126 |
+
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
127 |
+
|
128 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
129 |
+
# tokens are attended to.
|
130 |
+
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
131 |
+
|
132 |
+
# Zero-pad up to the sequence length.
|
133 |
+
padding_length = max_length - len(input_ids)
|
134 |
+
if pad_on_left:
|
135 |
+
input_ids = ([pad_token] * padding_length) + input_ids
|
136 |
+
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
|
137 |
+
token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
|
138 |
+
else:
|
139 |
+
input_ids = input_ids + ([pad_token] * padding_length)
|
140 |
+
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
141 |
+
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
|
142 |
+
|
143 |
+
assert len(input_ids) == max_length
|
144 |
+
assert len(attention_mask) == max_length
|
145 |
+
assert len(token_type_ids) == max_length
|
146 |
+
choices_features.append((input_ids, attention_mask, token_type_ids))
|
147 |
+
|
148 |
+
label = label_map[example.label]
|
149 |
+
|
150 |
+
if ex_index < 2:
|
151 |
+
logger.info("*** Example ***")
|
152 |
+
logger.info("race_id: {}".format(example.example_id))
|
153 |
+
for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
|
154 |
+
logger.info("choice: {}".format(choice_idx))
|
155 |
+
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
|
156 |
+
logger.info("attention_mask: {}".format(" ".join(map(str, attention_mask))))
|
157 |
+
logger.info("token_type_ids: {}".format(" ".join(map(str, token_type_ids))))
|
158 |
+
logger.info("label: {}".format(label))
|
159 |
+
|
160 |
+
features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label,))
|
161 |
+
|
162 |
+
return features
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
class LogiProcessor(DataProcessor):
|
168 |
+
"""Processor for the ReClor data set."""
|
169 |
+
|
170 |
+
def get_train_examples(self, data_dir):
|
171 |
+
"""See base class."""
|
172 |
+
logger.info("LOOKING AT {} train".format(data_dir))
|
173 |
+
return self._create_examples(self._read_json(os.path.join(data_dir, "train.txt")), "train")
|
174 |
+
|
175 |
+
def get_dev_examples(self, data_dir):
|
176 |
+
"""See base class."""
|
177 |
+
logger.info("LOOKING AT {} dev".format(data_dir))
|
178 |
+
return self._create_examples(self._read_json(os.path.join(data_dir, "dev.txt")), "dev")
|
179 |
+
|
180 |
+
def get_test_examples(self, data_dir):
|
181 |
+
logger.info("LOOKING AT {} test".format(data_dir))
|
182 |
+
return self._create_examples(self._read_json(os.path.join(data_dir, "test.txt")), "test")
|
183 |
+
|
184 |
+
def get_labels(self):
|
185 |
+
"""See base class."""
|
186 |
+
return [0, 1, 2, 3]
|
187 |
+
|
188 |
+
def _read_json(self, input_file):
|
189 |
+
with open(input_file, 'r') as f:
|
190 |
+
lines = []
|
191 |
+
file = f.readlines()
|
192 |
+
for line in file:
|
193 |
+
line = json.loads(line)
|
194 |
+
lines.append(line)
|
195 |
+
return lines
|
196 |
+
|
197 |
+
# def _read_json(self, input_file):
|
198 |
+
# with open(input_file, "r") as f:
|
199 |
+
# lines = json.load(f)
|
200 |
+
# return lines
|
201 |
+
|
202 |
+
def _create_examples(self, lines, type):
|
203 |
+
"""Creates examples for the training and dev sets."""
|
204 |
+
examples = []
|
205 |
+
for d in lines:
|
206 |
+
context = d['text']
|
207 |
+
question = d['question']
|
208 |
+
answers = d['options']
|
209 |
+
label = 0 if type == "test" else d['answer'] # for test set, there is no label. Just use 0 for convenience.
|
210 |
+
id_string = d['id']
|
211 |
+
examples.append(
|
212 |
+
InputExample(
|
213 |
+
example_id = id_string,
|
214 |
+
question = question,
|
215 |
+
contexts=[context, context, context, context], # this is not efficient but convenient
|
216 |
+
endings=[answers[0], answers[1], answers[2], answers[3]],
|
217 |
+
label = label
|
218 |
+
)
|
219 |
+
)
|
220 |
+
return examples
|
221 |
+
|
222 |
+
|
223 |
+
try:
|
224 |
+
from scipy.stats import pearsonr, spearmanr
|
225 |
+
from sklearn.metrics import matthews_corrcoef, f1_score, confusion_matrix
|
226 |
+
|
227 |
+
_has_sklearn = True
|
228 |
+
except (AttributeError, ImportError):
|
229 |
+
_has_sklearn = False
|
230 |
+
|
231 |
+
|
232 |
+
def is_sklearn_available():
|
233 |
+
return _has_sklearn
|
234 |
+
|
235 |
+
|
236 |
+
if _has_sklearn:
|
237 |
+
|
238 |
+
def simple_accuracy(preds, labels):
|
239 |
+
return (preds == labels).mean()
|
240 |
+
|
241 |
+
def acc_and_f1(preds, labels):
|
242 |
+
acc = simple_accuracy(preds, labels)
|
243 |
+
f1 = f1_score(y_true=labels, y_pred=preds)
|
244 |
+
return {
|
245 |
+
"acc": acc,
|
246 |
+
"f1": f1,
|
247 |
+
"acc_and_f1": (acc + f1) / 2,
|
248 |
+
}
|
249 |
+
|
250 |
+
def pearson_and_spearman(preds, labels):
|
251 |
+
pearson_corr = pearsonr(preds, labels)[0]
|
252 |
+
spearman_corr = spearmanr(preds, labels)[0]
|
253 |
+
return {
|
254 |
+
"pearson": pearson_corr,
|
255 |
+
"spearmanr": spearman_corr,
|
256 |
+
"corr": (pearson_corr + spearman_corr) / 2,
|
257 |
+
}
|
258 |
+
|
259 |
+
def compute_metrics(task_name, preds, labels):
|
260 |
+
assert len(preds) == len(labels)
|
261 |
+
if task_name == "logiqa":
|
262 |
+
return {"acc": simple_accuracy(labels, preds)}
|
263 |
+
else:
|
264 |
+
raise KeyError(task_name)
|
265 |
+
|
266 |
+
|
267 |
+
tasks_num_labels = {
|
268 |
+
"logiqa": 4,
|
269 |
+
|
270 |
+
}
|
271 |
+
|
272 |
+
processors = {
|
273 |
+
"logiqa": LogiProcessor,
|
274 |
+
|
275 |
+
}
|
276 |
+
|
277 |
+
output_modes = {
|
278 |
+
"logiqa": "classification",
|
279 |
+
|
280 |
+
}
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/ dev_new.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c1542bf081f8020b99efade38d67552f3687e47bc47636e5a4ea90f439c508b3
|
3 |
+
size 2978446
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/dev.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924ac0ea84135e8ebcc2332e093dc84bf85be722e67df77ee20ec508ecf12f36
|
3 |
+
size 2453230
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/readme.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# LogiQA 2.0 NLI version
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/stat.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
with open('test.txt', 'r') as f:
|
4 |
+
file = f.readlines()
|
5 |
+
n = 1
|
6 |
+
l = 0
|
7 |
+
for line in file:
|
8 |
+
line = json.loads(line)
|
9 |
+
text1 = line['major_premise']
|
10 |
+
text2 = line['minor_premise']
|
11 |
+
if type(text1) == str:
|
12 |
+
l = l + len(text1.split(" "))
|
13 |
+
else:
|
14 |
+
for text in text1:
|
15 |
+
l = l + len(text.split(" "))
|
16 |
+
if type(text2) == str:
|
17 |
+
l = l + len(text2.split(" "))
|
18 |
+
else:
|
19 |
+
for text in text2:
|
20 |
+
l = l + len(text.split(" "))
|
21 |
+
|
22 |
+
n += 1
|
23 |
+
|
24 |
+
result = l/n
|
25 |
+
print(result)
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/test.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cf5716861d7611d7d2aeed831c4651b1da9e717d318fcee043153bba34af8cfb
|
3 |
+
size 2458217
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/test_new.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a1d1a4db8ea93c979f7209dfa5db03598c1b645a66dfc7db89421461e85013ec
|
3 |
+
size 2484632
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/train.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2ade56a6c7a4ee9ff68a7193b07e0630df0955f305412ef561c50cb5b0a601f2
|
3 |
+
size 19684892
|
datasets/LogiQA2.0/logiqa2nli/DATA/QA2NLI/train_new.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6669ad14ec82772471807f1231ae2b244d2be3983642c09cf00601e5c29d522
|
3 |
+
size 5325690
|
datasets/LogiQA2.0/logiqa2nli/nli-prompt.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
import openai
|
4 |
+
import sklearn
|
5 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
6 |
+
openai.api_key = ''
|
7 |
+
|
8 |
+
incontext = "Given the fact: All Cantonese are southerners. Some Cantonese don't like chili. Does it follow that: Some southerners don't like chili. Yes or no? yes\nGiven the fact: It is difficult for cactus to survive in humid climates; citrus is difficult to grow in cold climates. In most parts of a province, at least one species is not difficult to survive and grow between cactus and citrus. Does it follow that: Half of the province is humid and cold. Yes or no? no\nGiven the fact: It is difficult for cactus to survive in humid climates; citrus is difficult to grow in cold climates. In most parts of a province, at least one species is not difficult to survive and grow between cactus and citrus. Does it follow that: Most of the province is hot. Yes or no? no\nGiven the fact: It is difficult for cactus to survive in humid climates; citrus is difficult to grow in cold climates. In most parts of a province, at least one species is not difficult to survive and grow between cactus and citrus. Does it follow that: Most of the province is either dry or warm. Yes or no? yes\n"
|
9 |
+
def gpt3_api(prompt):
|
10 |
+
response = openai.Completion.create(
|
11 |
+
model="text-davinci-002",
|
12 |
+
prompt=incontext + prompt,
|
13 |
+
temperature=0,
|
14 |
+
max_tokens=60,
|
15 |
+
top_p=1.0,
|
16 |
+
frequency_penalty=0.0,
|
17 |
+
presence_penalty=0.0
|
18 |
+
)
|
19 |
+
return response
|
20 |
+
|
21 |
+
with open('test1.txt') as f:
|
22 |
+
c = 0
|
23 |
+
y_true = []
|
24 |
+
y_pred = []
|
25 |
+
lines = f.readlines()
|
26 |
+
for i, line in enumerate(lines):
|
27 |
+
line_dict = json.loads(line)
|
28 |
+
|
29 |
+
label = 0 if line_dict['label']=="not entailed" else 1
|
30 |
+
maj_premise = ' '.join(line_dict['major_premise'])
|
31 |
+
min_premise = ' '.join(line_dict['minor_premise'])
|
32 |
+
hypo = line_dict['conclusion']
|
33 |
+
prompt_input = "Given the fact: " + maj_premise + ' ' + min_premise + " Does it follow that: " + hypo + " Yes or no?"
|
34 |
+
|
35 |
+
y_true.append(label)
|
36 |
+
prompt = prompt_input
|
37 |
+
output = gpt3_api(prompt)
|
38 |
+
time.sleep(5)
|
39 |
+
pred = output.choices[0].text.lower()
|
40 |
+
y_pred.append(pred)
|
41 |
+
|
42 |
+
print(y_true)
|
43 |
+
print(y_pred)
|
44 |
+
f_score = f1_score(y_true, y_pred, average='binary')
|
45 |
+
p_score = precision_score(y_true, y_pred, average='binary')
|
46 |
+
r_score = recall_score(y_true, y_pred, average='binary')
|
47 |
+
acc = accuracy_score(y_true, y_pred)
|
48 |
+
print(f_score)
|
49 |
+
print(p_score)
|
50 |
+
print(r_score)
|
51 |
+
print(acc)
|
datasets/LogiQA2.0/logiqa2nli/qa2nli.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=./DATA
|
2 |
+
export TASK_NAME=QA2NLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_eval \
|
9 |
+
--do_lower_case \
|
10 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
11 |
+
--max_seq_length 128 \
|
12 |
+
--per_gpu_eval_batch_size=64 \
|
13 |
+
--per_gpu_train_batch_size=64 \
|
14 |
+
--gradient_accumulation_steps 2\
|
15 |
+
--learning_rate 1e-5 \
|
16 |
+
--num_train_epochs 10.0 \
|
17 |
+
--logging_steps 5000 \
|
18 |
+
--save_steps 5000 \
|
19 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
20 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/run_nli.py
ADDED
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import, division, print_function
|
19 |
+
|
20 |
+
import argparse
|
21 |
+
import glob
|
22 |
+
import logging
|
23 |
+
import os
|
24 |
+
import random
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
29 |
+
TensorDataset)
|
30 |
+
from torch.utils.data.distributed import DistributedSampler
|
31 |
+
|
32 |
+
try:
|
33 |
+
from torch.utils.tensorboard import SummaryWriter
|
34 |
+
except:
|
35 |
+
from tensorboardX import SummaryWriter
|
36 |
+
|
37 |
+
from tqdm import tqdm, trange
|
38 |
+
|
39 |
+
from transformers import (WEIGHTS_NAME, BertConfig,
|
40 |
+
BertForSequenceClassification, BertTokenizer,
|
41 |
+
RobertaConfig,
|
42 |
+
RobertaForSequenceClassification,
|
43 |
+
RobertaTokenizer,
|
44 |
+
XLMConfig, XLMForSequenceClassification,
|
45 |
+
XLMTokenizer, XLNetConfig,
|
46 |
+
XLNetForSequenceClassification,
|
47 |
+
XLNetTokenizer,
|
48 |
+
DistilBertConfig,
|
49 |
+
DistilBertForSequenceClassification,
|
50 |
+
DistilBertTokenizer,
|
51 |
+
AlbertConfig,
|
52 |
+
AlbertForSequenceClassification,
|
53 |
+
AlbertTokenizer,
|
54 |
+
)
|
55 |
+
|
56 |
+
from transformers import AdamW, get_linear_schedule_with_warmup
|
57 |
+
|
58 |
+
from utils_nli import compute_metrics
|
59 |
+
from utils_nli import output_modes
|
60 |
+
from utils_nli import processors
|
61 |
+
from utils_nli import convert_examples_to_features
|
62 |
+
|
63 |
+
|
64 |
+
logger = logging.getLogger(__name__)
|
65 |
+
|
66 |
+
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig,
|
67 |
+
RobertaConfig, DistilBertConfig)), ())
|
68 |
+
|
69 |
+
MODEL_CLASSES = {
|
70 |
+
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
71 |
+
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
72 |
+
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
73 |
+
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
74 |
+
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
75 |
+
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer)
|
76 |
+
}
|
77 |
+
|
78 |
+
|
79 |
+
def set_seed(args):
|
80 |
+
random.seed(args.seed)
|
81 |
+
np.random.seed(args.seed)
|
82 |
+
torch.manual_seed(args.seed)
|
83 |
+
if args.n_gpu > 0:
|
84 |
+
torch.cuda.manual_seed_all(args.seed)
|
85 |
+
|
86 |
+
|
87 |
+
def train(args, train_dataset, model, tokenizer):
|
88 |
+
""" Train the model """
|
89 |
+
if args.local_rank in [-1, 0]:
|
90 |
+
tb_writer = SummaryWriter()
|
91 |
+
|
92 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
93 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
94 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
95 |
+
|
96 |
+
if args.max_steps > 0:
|
97 |
+
t_total = args.max_steps
|
98 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
99 |
+
else:
|
100 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
101 |
+
|
102 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
103 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
104 |
+
optimizer_grouped_parameters = [
|
105 |
+
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
106 |
+
'weight_decay': args.weight_decay},
|
107 |
+
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
108 |
+
]
|
109 |
+
|
110 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
111 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
|
112 |
+
num_training_steps=t_total)
|
113 |
+
if args.fp16:
|
114 |
+
try:
|
115 |
+
from apex import amp
|
116 |
+
except ImportError:
|
117 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
118 |
+
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
119 |
+
|
120 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
121 |
+
if args.n_gpu > 1:
|
122 |
+
model = torch.nn.DataParallel(model)
|
123 |
+
|
124 |
+
# Distributed training (should be after apex fp16 initialization)
|
125 |
+
if args.local_rank != -1:
|
126 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
127 |
+
output_device=args.local_rank,
|
128 |
+
find_unused_parameters=True)
|
129 |
+
|
130 |
+
# Train!
|
131 |
+
logger.info("***** Running training *****")
|
132 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
133 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
134 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
135 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
136 |
+
args.train_batch_size * args.gradient_accumulation_steps * (
|
137 |
+
torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
138 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
139 |
+
logger.info(" Total optimization steps = %d", t_total)
|
140 |
+
|
141 |
+
global_step = 0
|
142 |
+
tr_loss, logging_loss = 0.0, 0.0
|
143 |
+
model.zero_grad()
|
144 |
+
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
145 |
+
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
146 |
+
for _ in train_iterator:
|
147 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
148 |
+
for step, batch in enumerate(epoch_iterator):
|
149 |
+
model.train()
|
150 |
+
batch = tuple(t.to(args.device) for t in batch)
|
151 |
+
inputs = {'input_ids': batch[0],
|
152 |
+
'attention_mask': batch[1],
|
153 |
+
'labels': batch[3]}
|
154 |
+
if args.model_type != 'distilbert':
|
155 |
+
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert',
|
156 |
+
'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
157 |
+
outputs = model(**inputs)
|
158 |
+
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
159 |
+
|
160 |
+
if args.n_gpu > 1:
|
161 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
162 |
+
if args.gradient_accumulation_steps > 1:
|
163 |
+
loss = loss / args.gradient_accumulation_steps
|
164 |
+
|
165 |
+
if args.fp16:
|
166 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
167 |
+
scaled_loss.backward()
|
168 |
+
else:
|
169 |
+
loss.backward()
|
170 |
+
|
171 |
+
tr_loss += loss.item()
|
172 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
173 |
+
if args.fp16:
|
174 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
175 |
+
else:
|
176 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
177 |
+
|
178 |
+
optimizer.step()
|
179 |
+
scheduler.step() # Update learning rate schedule
|
180 |
+
model.zero_grad()
|
181 |
+
global_step += 1
|
182 |
+
|
183 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
184 |
+
# Log metrics
|
185 |
+
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
186 |
+
results = evaluate(args, model, tokenizer)
|
187 |
+
for key, value in results.items():
|
188 |
+
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
189 |
+
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
190 |
+
tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
|
191 |
+
logging_loss = tr_loss
|
192 |
+
|
193 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
194 |
+
# Save model checkpoint
|
195 |
+
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
196 |
+
if not os.path.exists(output_dir):
|
197 |
+
os.makedirs(output_dir)
|
198 |
+
model_to_save = model.module if hasattr(model,
|
199 |
+
'module') else model # Take care of distributed/parallel training
|
200 |
+
model_to_save.save_pretrained(output_dir)
|
201 |
+
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
202 |
+
logger.info("Saving model checkpoint to %s", output_dir)
|
203 |
+
|
204 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
205 |
+
epoch_iterator.close()
|
206 |
+
break
|
207 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
208 |
+
train_iterator.close()
|
209 |
+
break
|
210 |
+
|
211 |
+
if args.local_rank in [-1, 0]:
|
212 |
+
tb_writer.close()
|
213 |
+
|
214 |
+
return global_step, tr_loss / global_step
|
215 |
+
|
216 |
+
|
217 |
+
def evaluate(args, model, tokenizer, prefix=""):
|
218 |
+
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
219 |
+
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
220 |
+
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
|
221 |
+
|
222 |
+
results = {}
|
223 |
+
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
224 |
+
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
225 |
+
|
226 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
227 |
+
os.makedirs(eval_output_dir)
|
228 |
+
|
229 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
230 |
+
# Note that DistributedSampler samples randomly
|
231 |
+
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
232 |
+
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
233 |
+
|
234 |
+
# multi-gpu eval
|
235 |
+
if args.n_gpu > 1:
|
236 |
+
model = torch.nn.DataParallel(model)
|
237 |
+
|
238 |
+
# Eval!
|
239 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
240 |
+
logger.info(" Num examples = %d", len(eval_dataset))
|
241 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
242 |
+
eval_loss = 0.0
|
243 |
+
nb_eval_steps = 0
|
244 |
+
preds = None
|
245 |
+
out_label_ids = None
|
246 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
247 |
+
model.eval()
|
248 |
+
batch = tuple(t.to(args.device) for t in batch)
|
249 |
+
|
250 |
+
with torch.no_grad():
|
251 |
+
inputs = {'input_ids': batch[0],
|
252 |
+
'attention_mask': batch[1],
|
253 |
+
'labels': batch[3]}
|
254 |
+
if args.model_type != 'distilbert':
|
255 |
+
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert',
|
256 |
+
'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
257 |
+
outputs = model(**inputs)
|
258 |
+
tmp_eval_loss, logits = outputs[:2]
|
259 |
+
|
260 |
+
eval_loss += tmp_eval_loss.mean().item()
|
261 |
+
nb_eval_steps += 1
|
262 |
+
if preds is None:
|
263 |
+
preds = logits.detach().cpu().numpy()
|
264 |
+
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
265 |
+
else:
|
266 |
+
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
267 |
+
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
268 |
+
|
269 |
+
eval_loss = eval_loss / nb_eval_steps
|
270 |
+
if args.output_mode == "classification":
|
271 |
+
preds = np.argmax(preds, axis=1)
|
272 |
+
elif args.output_mode == "regression":
|
273 |
+
preds = np.squeeze(preds)
|
274 |
+
result = {"eval": compute_metrics(eval_task, preds, out_label_ids), "loss": eval_loss}
|
275 |
+
results.update(result)
|
276 |
+
|
277 |
+
output_pred_file = os.path.join(eval_output_dir, prefix, "pred_results.txt")
|
278 |
+
with open(output_pred_file, "a") as writer:
|
279 |
+
for pred in preds:
|
280 |
+
writer.write(str(pred)+"\n")
|
281 |
+
|
282 |
+
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
283 |
+
with open(output_eval_file, "w") as writer:
|
284 |
+
logger.info("***** Eval results {} *****".format(prefix))
|
285 |
+
for key in sorted(result.keys()):
|
286 |
+
logger.info(" %s = %s", key, str(result[key]))
|
287 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
288 |
+
|
289 |
+
return results
|
290 |
+
|
291 |
+
|
292 |
+
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
293 |
+
if args.local_rank not in [-1, 0] and not evaluate:
|
294 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
295 |
+
|
296 |
+
processor = processors[task]()
|
297 |
+
output_mode = output_modes[task]
|
298 |
+
# Load data features from cache or dataset file
|
299 |
+
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
300 |
+
'dev' if evaluate else 'train',
|
301 |
+
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
302 |
+
str(args.max_seq_length),
|
303 |
+
str(task)))
|
304 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
305 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
306 |
+
features = torch.load(cached_features_file)
|
307 |
+
else:
|
308 |
+
logger.info("Creating features from dataset file at %s", args.data_dir)
|
309 |
+
label_list = processor.get_labels()
|
310 |
+
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
|
311 |
+
# HACK(label indices are swapped in RoBERTa pretrained model)
|
312 |
+
label_list[1], label_list[2] = label_list[2], label_list[1]
|
313 |
+
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(
|
314 |
+
args.data_dir)
|
315 |
+
features = convert_examples_to_features(examples,
|
316 |
+
tokenizer,
|
317 |
+
label_list=label_list,
|
318 |
+
max_length=args.max_seq_length,
|
319 |
+
output_mode=output_mode,
|
320 |
+
pad_on_left=bool(args.model_type in ['xlnet']),
|
321 |
+
# pad on the left for xlnet
|
322 |
+
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
323 |
+
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
324 |
+
)
|
325 |
+
if args.local_rank in [-1, 0]:
|
326 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
327 |
+
torch.save(features, cached_features_file)
|
328 |
+
|
329 |
+
if args.local_rank == 0 and not evaluate:
|
330 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
331 |
+
|
332 |
+
# Convert to Tensors and build dataset
|
333 |
+
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
334 |
+
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
335 |
+
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
336 |
+
if output_mode == "classification":
|
337 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
338 |
+
elif output_mode == "regression":
|
339 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
340 |
+
|
341 |
+
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
342 |
+
return dataset
|
343 |
+
|
344 |
+
|
345 |
+
def main():
|
346 |
+
parser = argparse.ArgumentParser()
|
347 |
+
|
348 |
+
## Required parameters
|
349 |
+
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
350 |
+
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
351 |
+
parser.add_argument("--model_type", default=None, type=str, required=True,
|
352 |
+
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
353 |
+
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
354 |
+
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(
|
355 |
+
ALL_MODELS))
|
356 |
+
parser.add_argument("--task_name", default=None, type=str, required=True,
|
357 |
+
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
358 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
359 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
360 |
+
|
361 |
+
## Other parameters
|
362 |
+
parser.add_argument("--config_name", default="", type=str,
|
363 |
+
help="Pretrained config name or path if not the same as model_name")
|
364 |
+
parser.add_argument("--tokenizer_name", default="", type=str,
|
365 |
+
help="Pretrained tokenizer name or path if not the same as model_name")
|
366 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
367 |
+
help="Where do you want to store the pre-trained models downloaded from s3")
|
368 |
+
parser.add_argument("--max_seq_length", default=128, type=int,
|
369 |
+
help="The maximum total input sequence length after tokenization. Sequences longer "
|
370 |
+
"than this will be truncated, sequences shorter will be padded.")
|
371 |
+
parser.add_argument("--do_train", action='store_true',
|
372 |
+
help="Whether to run training.")
|
373 |
+
parser.add_argument("--do_eval", action='store_true',
|
374 |
+
help="Whether to run eval on the dev set.")
|
375 |
+
parser.add_argument("--evaluate_during_training", action='store_true',
|
376 |
+
help="Rul evaluation during training at each logging step.")
|
377 |
+
parser.add_argument("--do_lower_case", action='store_true',
|
378 |
+
help="Set this flag if you are using an uncased model.")
|
379 |
+
|
380 |
+
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
381 |
+
help="Batch size per GPU/CPU for training.")
|
382 |
+
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
383 |
+
help="Batch size per GPU/CPU for evaluation.")
|
384 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
385 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
386 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
387 |
+
help="The initial learning rate for Adam.")
|
388 |
+
parser.add_argument("--weight_decay", default=0.0, type=float,
|
389 |
+
help="Weight deay if we apply some.")
|
390 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
391 |
+
help="Epsilon for Adam optimizer.")
|
392 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
393 |
+
help="Max gradient norm.")
|
394 |
+
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
395 |
+
help="Total number of training epochs to perform.")
|
396 |
+
parser.add_argument("--max_steps", default=-1, type=int,
|
397 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
398 |
+
parser.add_argument("--warmup_steps", default=0, type=int,
|
399 |
+
help="Linear warmup over warmup_steps.")
|
400 |
+
|
401 |
+
parser.add_argument('--logging_steps', type=int, default=50,
|
402 |
+
help="Log every X updates steps.")
|
403 |
+
parser.add_argument('--save_steps', type=int, default=50,
|
404 |
+
help="Save checkpoint every X updates steps.")
|
405 |
+
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
406 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
407 |
+
parser.add_argument("--no_cuda", action='store_true',
|
408 |
+
help="Avoid using CUDA when available")
|
409 |
+
parser.add_argument('--overwrite_output_dir', action='store_true',
|
410 |
+
help="Overwrite the content of the output directory")
|
411 |
+
parser.add_argument('--overwrite_cache', action='store_true',
|
412 |
+
help="Overwrite the cached training and evaluation sets")
|
413 |
+
parser.add_argument('--seed', type=int, default=42,
|
414 |
+
help="random seed for initialization")
|
415 |
+
|
416 |
+
parser.add_argument('--fp16', action='store_true',
|
417 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
418 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
419 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
420 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
421 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
422 |
+
help="For distributed training: local_rank")
|
423 |
+
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
424 |
+
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
425 |
+
args = parser.parse_args()
|
426 |
+
|
427 |
+
if os.path.exists(args.output_dir) and os.listdir(
|
428 |
+
args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
429 |
+
raise ValueError(
|
430 |
+
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
431 |
+
args.output_dir))
|
432 |
+
|
433 |
+
# Setup distant debugging if needed
|
434 |
+
if args.server_ip and args.server_port:
|
435 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
436 |
+
import ptvsd
|
437 |
+
print("Waiting for debugger attach")
|
438 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
439 |
+
ptvsd.wait_for_attach()
|
440 |
+
|
441 |
+
# Setup CUDA, GPU & distributed training
|
442 |
+
if args.local_rank == -1 or args.no_cuda:
|
443 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
444 |
+
args.n_gpu = torch.cuda.device_count()
|
445 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
446 |
+
torch.cuda.set_device(args.local_rank)
|
447 |
+
device = torch.device("cuda", args.local_rank)
|
448 |
+
torch.distributed.init_process_group(backend='nccl')
|
449 |
+
args.n_gpu = 1
|
450 |
+
args.device = device
|
451 |
+
|
452 |
+
# Setup logging
|
453 |
+
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
454 |
+
datefmt='%m/%d/%Y %H:%M:%S',
|
455 |
+
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
456 |
+
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
457 |
+
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
458 |
+
|
459 |
+
# Set seed
|
460 |
+
set_seed(args)
|
461 |
+
|
462 |
+
# Prepare GLUE task
|
463 |
+
args.task_name = args.task_name.lower()
|
464 |
+
print(processors)
|
465 |
+
if args.task_name not in processors:
|
466 |
+
raise ValueError("Task not found: %s" % (args.task_name))
|
467 |
+
processor = processors[args.task_name]()
|
468 |
+
args.output_mode = output_modes[args.task_name]
|
469 |
+
label_list = processor.get_labels()
|
470 |
+
num_labels = len(label_list)
|
471 |
+
|
472 |
+
# Load pretrained model and tokenizer
|
473 |
+
if args.local_rank not in [-1, 0]:
|
474 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
475 |
+
|
476 |
+
args.model_type = args.model_type.lower()
|
477 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
478 |
+
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
479 |
+
num_labels=num_labels,
|
480 |
+
finetuning_task=args.task_name,
|
481 |
+
cache_dir=args.cache_dir if args.cache_dir else None)
|
482 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
483 |
+
do_lower_case=args.do_lower_case,
|
484 |
+
cache_dir=args.cache_dir if args.cache_dir else None)
|
485 |
+
model = model_class.from_pretrained(args.model_name_or_path,
|
486 |
+
from_tf=bool('.ckpt' in args.model_name_or_path),
|
487 |
+
config=config,
|
488 |
+
cache_dir=args.cache_dir if args.cache_dir else None)
|
489 |
+
|
490 |
+
if args.local_rank == 0:
|
491 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
492 |
+
|
493 |
+
model.to(args.device)
|
494 |
+
|
495 |
+
logger.info("Training/evaluation parameters %s", args)
|
496 |
+
|
497 |
+
# Training
|
498 |
+
if args.do_train:
|
499 |
+
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
500 |
+
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
501 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
502 |
+
|
503 |
+
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
504 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
505 |
+
# Create output directory if needed
|
506 |
+
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
507 |
+
os.makedirs(args.output_dir)
|
508 |
+
|
509 |
+
logger.info("Saving model checkpoint to %s", args.output_dir)
|
510 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
511 |
+
# They can then be reloaded using `from_pretrained()`
|
512 |
+
model_to_save = model.module if hasattr(model,
|
513 |
+
'module') else model # Take care of distributed/parallel training
|
514 |
+
model_to_save.save_pretrained(args.output_dir)
|
515 |
+
tokenizer.save_pretrained(args.output_dir)
|
516 |
+
|
517 |
+
# Good practice: save your training arguments together with the trained model
|
518 |
+
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
519 |
+
|
520 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
521 |
+
model = model_class.from_pretrained(args.output_dir)
|
522 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
523 |
+
model.to(args.device)
|
524 |
+
|
525 |
+
# Evaluation
|
526 |
+
results = {}
|
527 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
528 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
529 |
+
checkpoints = [args.output_dir]
|
530 |
+
if args.eval_all_checkpoints:
|
531 |
+
checkpoints = list(
|
532 |
+
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
533 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
534 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
535 |
+
for checkpoint in checkpoints:
|
536 |
+
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
537 |
+
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
538 |
+
|
539 |
+
model = model_class.from_pretrained(checkpoint)
|
540 |
+
model.to(args.device)
|
541 |
+
result = evaluate(args, model, tokenizer, prefix=prefix)
|
542 |
+
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
543 |
+
results.update(result)
|
544 |
+
|
545 |
+
return results
|
546 |
+
|
547 |
+
|
548 |
+
if __name__ == "__main__":
|
549 |
+
main()
|
datasets/LogiQA2.0/logiqa2nli/scripts/anli.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=ANLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 128 \
|
13 |
+
--per_gpu_eval_batch_size=64 \
|
14 |
+
--per_gpu_train_batch_size=64 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 1e-5 \
|
17 |
+
--num_train_epochs 10.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/cood.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=COOD
|
3 |
+
|
4 |
+
python ../run_nli.py --model_type bert --model_name_or_path bert-base-uncased --task_name $TASK_NAME --do_train --do_eval --do_lower_case --data_dir $DATA_DIR/$TASK_NAME --max_seq_length 128 --per_gpu_eval_batch_size=16 --per_gpu_train_batch_size=16 --gradient_accumulation_steps 2 --logging_steps 1000 --save_steps 1000 --learning_rate 2e-5 --eval_all_checkpoints --num_train_epochs 10.0 --output_dir ./tmp/$TASK_NAME/bert-base/
|
datasets/LogiQA2.0/logiqa2nli/scripts/mnli.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=MNLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 128 \
|
13 |
+
--per_gpu_eval_batch_size=64 \
|
14 |
+
--per_gpu_train_batch_size=64 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 2e-5 \
|
17 |
+
--num_train_epochs 2.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/multirun.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=QA2NLI
|
3 |
+
python ../run_nli.py --model_type bert --model_name_or_path bert-base-uncased --task_name $TASK_NAME --do_train --eval_all_checkpoints --do_eval --do_lower_case --data_dir $DATA_DIR/$TASK_NAME --max_seq_length 128 --per_gpu_eval_batch_size=16 --per_gpu_train_batch_size=8 --gradient_accumulation_steps 3 --logging_steps 5000 --save_steps 5000 --eval_all_checkpoints --learning_rate 2e-5 --num_train_epochs 2.0 --output_dir ./tmp/$TASK_NAME/bert-base/
|
4 |
+
#python ../run_nli.py --model_type bert --model_name_or_path bert-large-uncased --task_name $TASK_NAME --do_train --evaluate_during_training --do_eval --do_lower_case --data_dir $DATA_DIR/$TASK_NAME --max_seq_length 128 --per_gpu_eval_batch_size=16 --per_gpu_train_batch_size=8 --gradient_accumulation_steps 1 --learning_rate 2e-5 --save_steps 200 --adam_epsilon 1e-6 --no_clip_grad_norm --warmup_proportion 0.1 --num_train_epochs 5.0 --output_dir ./tmp/$TASK_NAME/bertlarge/
|
5 |
+
python ../run_nli.py --model_type roberta --model_name_or_path roberta-base --task_name $TASK_NAME --do_train --do_eval --eval_all_checkpoints --do_lower_case --data_dir $DATA_DIR/$TASK_NAME --max_seq_length 256 --per_gpu_eval_batch_size=16 --per_gpu_train_batch_size=8 --gradient_accumulation_steps 3 --logging_steps 5000 --save_steps 5000 --eval_all_checkpoints --learning_rate 1e-5 --num_train_epochs 2.0 --output_dir ./tmp/$TASK_NAME/roberta/
|
6 |
+
#python ../run_nli.py --model_type roberta --model_name_or_path /home/bimu/PycharmProjects/liu_nli/tmp-1/QNLI/roberta/ --task_name $TASK_NAME --do_train --do_eval --do_lower_case --data_dir $DATA_DIR/$TASK_NAME --max_seq_length 128 --per_gpu_eval_batch_size=16 --per_gpu_train_batch_size=16 --gradient_accumulation_steps 2 --learning_rate 2e-5 --num_train_epochs 5.0 --output_dir ./tmp/$TASK_NAME/roberta/
|
7 |
+
#python ../run_nli.py --model_type xlnet --model_name_or_path xlnet-base-cased --task_name $TASK_NAME --do_train --do_eval --eval_all_checkpoints --do_lower_case --data_dir $DATA_DIR/$TASK_NAME --max_seq_length 128 --per_gpu_eval_batch_size=16 --per_gpu_train_batch_size=8 --gradient_accumulation_steps 3 --logging_steps 500 --save_steps 500 --eval_all_checkpoints --learning_rate 2e-5 --adam_epsilon 1e-6 --num_train_epochs 5.0 --output_dir ./tmp/$TASK_NAME/xlnet/
|
8 |
+
#python ../run_nli.py --model_type bert --model_name_or_path bert-base-uncased --task_name $TASK_NAME --do_train --do_eval --do_lower_case --data_dir $DATA_DIR/$TASK_NAME --max_seq_length 128 --per_gpu_eval_batch_size=16 --per_gpu_train_batch_size=16 --gradient_accumulation_steps 2 --logging_steps 500 --save_steps 500 --eval_all_checkpoints --learning_rate 2e-5 --num_train_epochs 5.0 --output_dir ./tmp/$TASK_NAME/bert-base/
|
datasets/LogiQA2.0/logiqa2nli/scripts/pnli.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=PNLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 512 \
|
13 |
+
--per_gpu_eval_batch_size=8 \
|
14 |
+
--per_gpu_train_batch_size=8 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 1e-5 \
|
17 |
+
--num_train_epochs 10.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/qa2nli.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=QA2NLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 128 \
|
13 |
+
--per_gpu_eval_batch_size=64 \
|
14 |
+
--per_gpu_train_batch_size=64 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 1e-5 \
|
17 |
+
--num_train_epochs 10.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/qnli.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=QNLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 128 \
|
13 |
+
--per_gpu_eval_batch_size=64 \
|
14 |
+
--per_gpu_train_batch_size=64 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 1e-5 \
|
17 |
+
--num_train_epochs 10.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/qood.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=QOOD
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_eval \
|
9 |
+
--do_lower_case \
|
10 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
11 |
+
--max_seq_length 128 \
|
12 |
+
--per_gpu_eval_batch_size=64 \
|
13 |
+
--per_gpu_train_batch_size=64 \
|
14 |
+
--gradient_accumulation_steps 2\
|
15 |
+
--learning_rate 1e-5 \
|
16 |
+
--num_train_epochs 10.0 \
|
17 |
+
--logging_steps 5000 \
|
18 |
+
--save_steps 5000 \
|
19 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
20 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/rte.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=RTE
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_DATA.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_eval \
|
9 |
+
--do_lower_case \
|
10 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
11 |
+
--max_seq_length 256 \
|
12 |
+
--per_gpu_eval_batch_size=16 \
|
13 |
+
--per_gpu_train_batch_size=16 \
|
14 |
+
--gradient_accumulation_steps 2\
|
15 |
+
--learning_rate 1e-5 \
|
16 |
+
--num_train_epochs 10.0 \
|
17 |
+
--logging_steps 5000 \
|
18 |
+
--save_steps 5000 \
|
19 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
20 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/scitail.sh
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=SCITAIL
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 128 \
|
13 |
+
--per_gpu_eval_batch_size=64 \
|
14 |
+
--per_gpu_train_batch_size=64 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--evaluate_during_training \
|
17 |
+
--learning_rate 2e-5 \
|
18 |
+
--num_train_epochs 10.0 \
|
19 |
+
--logging_steps 5000 \
|
20 |
+
--save_steps 5000 \
|
21 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
22 |
+
--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/snli.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=./DATA
|
2 |
+
export TASK_NAME=SNLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 128 \
|
13 |
+
--per_gpu_eval_batch_size=64 \
|
14 |
+
--per_gpu_train_batch_size=64 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 2e-5 \
|
17 |
+
--num_train_epochs 2.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/scripts/wnli.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export DATA_DIR=../DATA
|
2 |
+
export TASK_NAME=WNLI
|
3 |
+
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python ../run_nli.py \
|
5 |
+
--model_type bert \
|
6 |
+
--model_name_or_path bert-base-uncased \
|
7 |
+
--task_name $TASK_NAME \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--do_lower_case \
|
11 |
+
--data_dir $DATA_DIR/$TASK_NAME \
|
12 |
+
--max_seq_length 128 \
|
13 |
+
--per_gpu_eval_batch_size=64 \
|
14 |
+
--per_gpu_train_batch_size=64 \
|
15 |
+
--gradient_accumulation_steps 2\
|
16 |
+
--learning_rate 1e-5 \
|
17 |
+
--num_train_epochs 10.0 \
|
18 |
+
--logging_steps 5000 \
|
19 |
+
--save_steps 5000 \
|
20 |
+
--output_dir ./tmp/$TASK_NAME/ \
|
21 |
+
#--overwrite_output_dir \
|
datasets/LogiQA2.0/logiqa2nli/utils_nli.py
ADDED
@@ -0,0 +1,1002 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
This Script is Modified for Natural Language Inference Datasets fine-tuning.
|
18 |
+
All the datasets can be downloaded from this repo.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import logging
|
22 |
+
import os
|
23 |
+
import sys
|
24 |
+
import json
|
25 |
+
|
26 |
+
from transformers.data.processors.utils import DataProcessor, InputExample, InputFeatures
|
27 |
+
from transformers.file_utils import is_tf_available
|
28 |
+
|
29 |
+
if is_tf_available():
|
30 |
+
import tensorflow as tf
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def convert_examples_to_features(examples, tokenizer,
|
36 |
+
max_length=512,
|
37 |
+
task=None,
|
38 |
+
label_list=None,
|
39 |
+
output_mode=None,
|
40 |
+
pad_on_left=False,
|
41 |
+
pad_token=0,
|
42 |
+
pad_token_segment_id=0,
|
43 |
+
mask_padding_with_zero=True):
|
44 |
+
|
45 |
+
is_tf_dataset = False
|
46 |
+
if is_tf_available() and isinstance(examples, tf.data.Dataset):
|
47 |
+
is_tf_dataset = True
|
48 |
+
|
49 |
+
if task is not None:
|
50 |
+
processor = glue_processors[task]()
|
51 |
+
if label_list is None:
|
52 |
+
label_list = processor.get_labels()
|
53 |
+
logger.info("Using label list %s for task %s" % (label_list, task))
|
54 |
+
if output_mode is None:
|
55 |
+
output_mode = glue_output_modes[task]
|
56 |
+
logger.info("Using output mode %s for task %s" % (output_mode, task))
|
57 |
+
|
58 |
+
label_map = {label: i for i, label in enumerate(label_list)}
|
59 |
+
|
60 |
+
features = []
|
61 |
+
for (ex_index, example) in enumerate(examples):
|
62 |
+
if ex_index % 10000 == 0:
|
63 |
+
logger.info("Writing example %d" % (ex_index))
|
64 |
+
if is_tf_dataset:
|
65 |
+
example = processor.get_example_from_tensor_dict(example)
|
66 |
+
example = processor.tfds_map(example)
|
67 |
+
|
68 |
+
inputs = tokenizer.encode_plus(
|
69 |
+
example.text_a,
|
70 |
+
example.text_b,
|
71 |
+
add_special_tokens=True,
|
72 |
+
max_length=max_length,
|
73 |
+
)
|
74 |
+
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
75 |
+
|
76 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
77 |
+
# tokens are attended to.
|
78 |
+
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
79 |
+
|
80 |
+
# Zero-pad up to the sequence length.
|
81 |
+
padding_length = max_length - len(input_ids)
|
82 |
+
if pad_on_left:
|
83 |
+
input_ids = ([pad_token] * padding_length) + input_ids
|
84 |
+
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
|
85 |
+
token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
|
86 |
+
else:
|
87 |
+
input_ids = input_ids + ([pad_token] * padding_length)
|
88 |
+
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
89 |
+
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
|
90 |
+
|
91 |
+
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
|
92 |
+
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
|
93 |
+
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)
|
94 |
+
|
95 |
+
if output_mode == "classification":
|
96 |
+
label = label_map[example.label]
|
97 |
+
elif output_mode == "regression":
|
98 |
+
label = float(example.label)
|
99 |
+
else:
|
100 |
+
raise KeyError(output_mode)
|
101 |
+
|
102 |
+
if ex_index < 5:
|
103 |
+
logger.info("*** Example ***")
|
104 |
+
logger.info("guid: %s" % (example.guid))
|
105 |
+
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
106 |
+
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
|
107 |
+
logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
|
108 |
+
logger.info("label: %s (id = %d)" % (example.label, label))
|
109 |
+
|
110 |
+
features.append(
|
111 |
+
InputFeatures(input_ids=input_ids,
|
112 |
+
attention_mask=attention_mask,
|
113 |
+
token_type_ids=token_type_ids,
|
114 |
+
label=label))
|
115 |
+
|
116 |
+
if is_tf_available() and is_tf_dataset:
|
117 |
+
def gen():
|
118 |
+
for ex in features:
|
119 |
+
yield ({'input_ids': ex.input_ids,
|
120 |
+
'attention_mask': ex.attention_mask,
|
121 |
+
'token_type_ids': ex.token_type_ids},
|
122 |
+
ex.label)
|
123 |
+
|
124 |
+
return tf.data.Dataset.from_generator(gen,
|
125 |
+
({'input_ids': tf.int32,
|
126 |
+
'attention_mask': tf.int32,
|
127 |
+
'token_type_ids': tf.int32},
|
128 |
+
tf.int64),
|
129 |
+
({'input_ids': tf.TensorShape([None]),
|
130 |
+
'attention_mask': tf.TensorShape([None]),
|
131 |
+
'token_type_ids': tf.TensorShape([None])},
|
132 |
+
tf.TensorShape([])))
|
133 |
+
|
134 |
+
return features
|
135 |
+
|
136 |
+
|
137 |
+
class SnliProcessor(DataProcessor):
|
138 |
+
"""Processor for the SNLI dataset (converted)."""
|
139 |
+
|
140 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
141 |
+
"""See base class."""
|
142 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
143 |
+
tensor_dict['premise'].numpy().decode('utf-8'),
|
144 |
+
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
145 |
+
str(tensor_dict['label'].numpy()))
|
146 |
+
|
147 |
+
def get_train_examples(self, data_dir):
|
148 |
+
"""See base class."""
|
149 |
+
return self._create_examples(
|
150 |
+
self._read_txt(os.path.join(data_dir, "train.jsonl")), "train")
|
151 |
+
|
152 |
+
def get_dev_examples(self, data_dir):
|
153 |
+
"""See base class."""
|
154 |
+
return self._create_examples(
|
155 |
+
self._read_txt(os.path.join(data_dir, "dev.jsonl")), "dev")
|
156 |
+
|
157 |
+
def get_labels(self):
|
158 |
+
"""See base class."""
|
159 |
+
return ["e", "n", "c"]
|
160 |
+
|
161 |
+
def _read_txt(self, dir):
|
162 |
+
with open(dir, "r", encoding="utf-8") as f:
|
163 |
+
lines = []
|
164 |
+
for line in f.readlines():
|
165 |
+
if sys.version_info[0] == 2:
|
166 |
+
line = list(unicode(cell, 'utf-8') for cell in line)
|
167 |
+
lines.append(line)
|
168 |
+
return lines
|
169 |
+
|
170 |
+
def _create_examples(self, lines, set_type):
|
171 |
+
"""Creates examples for the training and dev sets."""
|
172 |
+
examples = []
|
173 |
+
for (i, line) in enumerate(lines):
|
174 |
+
dict_line = json.loads(line)
|
175 |
+
guid = "%s-%s" % (set_type, i)
|
176 |
+
label = dict_line['label']
|
177 |
+
text_a = dict_line['premise'].strip()
|
178 |
+
text_b = dict_line['hypothesis'].strip()
|
179 |
+
examples.append(
|
180 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
|
181 |
+
)
|
182 |
+
return examples
|
183 |
+
|
184 |
+
|
185 |
+
class MnliProcessor(DataProcessor):
|
186 |
+
"""Processor for the MultiNLI data set (GLUE version)."""
|
187 |
+
|
188 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
189 |
+
"""See base class."""
|
190 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
191 |
+
tensor_dict['premise'].numpy().decode('utf-8'),
|
192 |
+
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
193 |
+
str(tensor_dict['label'].numpy()))
|
194 |
+
|
195 |
+
def get_train_examples(self, data_dir):
|
196 |
+
"""See base class."""
|
197 |
+
return self._create_examples(
|
198 |
+
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
199 |
+
|
200 |
+
def get_dev_examples(self, data_dir):
|
201 |
+
"""See base class."""
|
202 |
+
return self._create_examples(
|
203 |
+
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
|
204 |
+
"dev_matched")
|
205 |
+
|
206 |
+
def get_labels(self):
|
207 |
+
"""See base class."""
|
208 |
+
return ["contradiction", "entailment", "neutral"]
|
209 |
+
|
210 |
+
def _create_examples(self, lines, set_type):
|
211 |
+
"""Creates examples for the training and dev sets."""
|
212 |
+
examples = []
|
213 |
+
for (i, line) in enumerate(lines):
|
214 |
+
if i == 0:
|
215 |
+
continue
|
216 |
+
guid = "%s-%s" % (set_type, line[0])
|
217 |
+
text_a = line[8]
|
218 |
+
text_b = line[9]
|
219 |
+
label = line[-1]
|
220 |
+
examples.append(
|
221 |
+
InputExample(guid=guid, text_a=text_b, text_b=text_a, label=label))
|
222 |
+
return examples
|
223 |
+
|
224 |
+
|
225 |
+
class MnliMismatchedProcessor(MnliProcessor):
|
226 |
+
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
|
227 |
+
|
228 |
+
def get_dev_examples(self, data_dir):
|
229 |
+
"""See base class."""
|
230 |
+
return self._create_examples(
|
231 |
+
self._read_tsv(os.path.join(data_dir, "short/dev_mismatched.tsv")),
|
232 |
+
"dev_matched")
|
233 |
+
|
234 |
+
|
235 |
+
class ColaProcessor(DataProcessor):
|
236 |
+
"""Processor for the CoLA data set (GLUE version). <Linguistic Acceptability>"""
|
237 |
+
|
238 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
239 |
+
"""See base class."""
|
240 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
241 |
+
tensor_dict['sentence'].numpy().decode('utf-8'),
|
242 |
+
None,
|
243 |
+
str(tensor_dict['label'].numpy()))
|
244 |
+
|
245 |
+
def get_train_examples(self, data_dir):
|
246 |
+
"""See base class."""
|
247 |
+
return self._create_examples(
|
248 |
+
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
249 |
+
|
250 |
+
def get_dev_examples(self, data_dir):
|
251 |
+
"""See base class."""
|
252 |
+
return self._create_examples(
|
253 |
+
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
254 |
+
|
255 |
+
def get_labels(self):
|
256 |
+
"""See base class."""
|
257 |
+
return ["0", "1"]
|
258 |
+
|
259 |
+
def _create_examples(self, lines, set_type):
|
260 |
+
"""Creates examples for the training and dev sets."""
|
261 |
+
examples = []
|
262 |
+
for (i, line) in enumerate(lines):
|
263 |
+
guid = "%s-%s" % (set_type, i)
|
264 |
+
text_a = line[3]
|
265 |
+
label = line[1]
|
266 |
+
examples.append(
|
267 |
+
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
268 |
+
return examples
|
269 |
+
|
270 |
+
class CoodProcessor(DataProcessor):
|
271 |
+
"""Processor for the CoLA-ood data set. <Linguistic Acceptability>"""
|
272 |
+
|
273 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
274 |
+
"""See base class."""
|
275 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
276 |
+
tensor_dict['sentence'].numpy().decode('utf-8'),
|
277 |
+
None,
|
278 |
+
str(tensor_dict['label'].numpy()))
|
279 |
+
|
280 |
+
def _read_txt(self, dir):
|
281 |
+
with open(dir, "r", encoding="utf-8") as f:
|
282 |
+
lines = []
|
283 |
+
for line in f.readlines():
|
284 |
+
if sys.version_info[0] == 2:
|
285 |
+
line = list(unicode(cell, 'utf-8') for cell in line)
|
286 |
+
lines.append(line)
|
287 |
+
return lines
|
288 |
+
|
289 |
+
def get_train_examples(self, data_dir):
|
290 |
+
"""See base class."""
|
291 |
+
return self._create_examples(
|
292 |
+
self._read_txt(os.path.join(data_dir, "binary_train.txt")), "train")
|
293 |
+
|
294 |
+
def get_dev_examples(self, data_dir):
|
295 |
+
"""See base class."""
|
296 |
+
return self._create_examples(
|
297 |
+
self._read_txt(os.path.join(data_dir, "binary_dev.txt")), "dev")
|
298 |
+
|
299 |
+
def get_labels(self):
|
300 |
+
"""See base class."""
|
301 |
+
return [0, 1]
|
302 |
+
|
303 |
+
def _create_examples(self, lines, set_type):
|
304 |
+
"""Creates examples for the training and dev sets."""
|
305 |
+
examples = []
|
306 |
+
for (i, line) in enumerate(lines):
|
307 |
+
guid = "%s-%s" % (set_type, i)
|
308 |
+
dict_line = eval(line)
|
309 |
+
print(i)
|
310 |
+
text_a = dict_line['text']
|
311 |
+
label = dict_line['label']
|
312 |
+
examples.append(
|
313 |
+
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
314 |
+
return examples
|
315 |
+
|
316 |
+
class Sst2Processor(DataProcessor):
|
317 |
+
"""Processor for the SST-2 data set (GLUE version). <Sentiment Analysis>"""
|
318 |
+
|
319 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
320 |
+
"""See base class."""
|
321 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
322 |
+
tensor_dict['sentence'].numpy().decode('utf-8'),
|
323 |
+
None,
|
324 |
+
str(tensor_dict['label'].numpy()))
|
325 |
+
|
326 |
+
def get_train_examples(self, data_dir):
|
327 |
+
"""See base class."""
|
328 |
+
return self._create_examples(
|
329 |
+
self._read_tsv(os.path.join(data_dir, "short/train.tsv")), "train")
|
330 |
+
|
331 |
+
def get_dev_examples(self, data_dir):
|
332 |
+
"""See base class."""
|
333 |
+
return self._create_examples(
|
334 |
+
self._read_tsv(os.path.join(data_dir, "short/dev.tsv")), "dev")
|
335 |
+
|
336 |
+
def get_labels(self):
|
337 |
+
"""See base class."""
|
338 |
+
return ["0", "1"]
|
339 |
+
|
340 |
+
def _create_examples(self, lines, set_type):
|
341 |
+
"""Creates examples for the training and dev sets."""
|
342 |
+
examples = []
|
343 |
+
for (i, line) in enumerate(lines):
|
344 |
+
if i == 0:
|
345 |
+
continue
|
346 |
+
guid = "%s-%s" % (set_type, i)
|
347 |
+
text_a = line[0]
|
348 |
+
label = line[1]
|
349 |
+
examples.append(
|
350 |
+
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
351 |
+
return examples
|
352 |
+
|
353 |
+
|
354 |
+
class StsbProcessor(DataProcessor):
|
355 |
+
"""Processor for the STS-B data set (GLUE version). <Text Similarity>"""
|
356 |
+
|
357 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
358 |
+
"""See base class."""
|
359 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
360 |
+
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
361 |
+
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
362 |
+
str(tensor_dict['label'].numpy()))
|
363 |
+
|
364 |
+
def get_train_examples(self, data_dir):
|
365 |
+
"""See base class."""
|
366 |
+
return self._create_examples(
|
367 |
+
self._read_tsv(os.path.join(data_dir, "short/train.tsv")), "train")
|
368 |
+
|
369 |
+
def get_dev_examples(self, data_dir):
|
370 |
+
"""See base class."""
|
371 |
+
return self._create_examples(
|
372 |
+
self._read_tsv(os.path.join(data_dir, "short/dev.tsv")), "dev")
|
373 |
+
|
374 |
+
def get_labels(self):
|
375 |
+
"""See base class."""
|
376 |
+
return [None]
|
377 |
+
|
378 |
+
def _create_examples(self, lines, set_type):
|
379 |
+
"""Creates examples for the training and dev sets."""
|
380 |
+
examples = []
|
381 |
+
for (i, line) in enumerate(lines):
|
382 |
+
if i == 0:
|
383 |
+
continue
|
384 |
+
guid = "%s-%s" % (set_type, line[0])
|
385 |
+
text_a = line[7]
|
386 |
+
text_b = line[8]
|
387 |
+
label = line[-1]
|
388 |
+
examples.append(
|
389 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
390 |
+
return examples
|
391 |
+
|
392 |
+
|
393 |
+
class QqpProcessor(DataProcessor):
|
394 |
+
"""Processor for the QQP data set (GLUE version). <Paraphrase>"""
|
395 |
+
|
396 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
397 |
+
"""See base class."""
|
398 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
399 |
+
tensor_dict['question1'].numpy().decode('utf-8'),
|
400 |
+
tensor_dict['question2'].numpy().decode('utf-8'),
|
401 |
+
str(tensor_dict['label'].numpy()))
|
402 |
+
|
403 |
+
def get_train_examples(self, data_dir):
|
404 |
+
"""See base class."""
|
405 |
+
return self._create_examples(
|
406 |
+
self._read_tsv(os.path.join(data_dir, "short/train.tsv")), "train")
|
407 |
+
|
408 |
+
def get_dev_examples(self, data_dir):
|
409 |
+
"""See base class."""
|
410 |
+
return self._create_examples(
|
411 |
+
self._read_tsv(os.path.join(data_dir, "short/dev.tsv")), "dev")
|
412 |
+
|
413 |
+
def get_labels(self):
|
414 |
+
"""See base class."""
|
415 |
+
return ["0", "1"]
|
416 |
+
|
417 |
+
def _create_examples(self, lines, set_type):
|
418 |
+
"""Creates examples for the training and dev sets."""
|
419 |
+
examples = []
|
420 |
+
for (i, line) in enumerate(lines):
|
421 |
+
if i == 0:
|
422 |
+
continue
|
423 |
+
guid = "%s-%s" % (set_type, line[0])
|
424 |
+
try:
|
425 |
+
text_a = line[3]
|
426 |
+
text_b = line[4]
|
427 |
+
label = line[5]
|
428 |
+
except IndexError:
|
429 |
+
continue
|
430 |
+
examples.append(
|
431 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
432 |
+
return examples
|
433 |
+
|
434 |
+
|
435 |
+
class QnliProcessor(DataProcessor):
|
436 |
+
"""Processor for the QNLI data set (GLUE version). <Question>"""
|
437 |
+
|
438 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
439 |
+
"""See base class."""
|
440 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
441 |
+
tensor_dict['question'].numpy().decode('utf-8'),
|
442 |
+
tensor_dict['sentence'].numpy().decode('utf-8'),
|
443 |
+
str(tensor_dict['label'].numpy()))
|
444 |
+
|
445 |
+
def get_train_examples(self, data_dir):
|
446 |
+
"""See base class."""
|
447 |
+
return self._create_examples(
|
448 |
+
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
449 |
+
|
450 |
+
def get_dev_examples(self, data_dir):
|
451 |
+
"""See base class."""
|
452 |
+
return self._create_examples(
|
453 |
+
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
|
454 |
+
"dev_matched")
|
455 |
+
|
456 |
+
def get_labels(self):
|
457 |
+
"""See base class."""
|
458 |
+
return ["entailment", "not_entailment"]
|
459 |
+
|
460 |
+
def _create_examples(self, lines, set_type):
|
461 |
+
"""Creates examples for the training and dev sets."""
|
462 |
+
examples = []
|
463 |
+
for (i, line) in enumerate(lines):
|
464 |
+
if i == 0:
|
465 |
+
continue
|
466 |
+
guid = "%s-%s" % (set_type, line[0])
|
467 |
+
text_a = line[1]
|
468 |
+
text_b = line[2]
|
469 |
+
label = line[-1]
|
470 |
+
examples.append(
|
471 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
472 |
+
return examples
|
473 |
+
|
474 |
+
|
475 |
+
class RteProcessor(DataProcessor):
|
476 |
+
"""Processor for the RTE data set (GLUE version)."""
|
477 |
+
|
478 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
479 |
+
"""See base class."""
|
480 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
481 |
+
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
482 |
+
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
483 |
+
str(tensor_dict['label'].numpy()))
|
484 |
+
|
485 |
+
def get_train_examples(self, data_dir):
|
486 |
+
"""See base class."""
|
487 |
+
return self._create_examples(
|
488 |
+
self._read_tsv(os.path.join(data_dir, "short/train.tsv")), "train")
|
489 |
+
|
490 |
+
def get_dev_examples(self, data_dir):
|
491 |
+
"""See base class."""
|
492 |
+
return self._create_examples(
|
493 |
+
self._read_tsv(os.path.join(data_dir, "short/dev.tsv")), "dev")
|
494 |
+
|
495 |
+
def get_labels(self):
|
496 |
+
"""See base class."""
|
497 |
+
return ["entailment", "not_entailment"]
|
498 |
+
|
499 |
+
def _create_examples(self, lines, set_type):
|
500 |
+
"""Creates examples for the training and dev sets."""
|
501 |
+
examples = []
|
502 |
+
for (i, line) in enumerate(lines):
|
503 |
+
if i == 0:
|
504 |
+
continue
|
505 |
+
guid = "%s-%s" % (set_type, line[0])
|
506 |
+
text_a = line[1]
|
507 |
+
text_b = line[2]
|
508 |
+
label = line[-1]
|
509 |
+
examples.append(
|
510 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
511 |
+
return examples
|
512 |
+
|
513 |
+
|
514 |
+
class WnliProcessor(DataProcessor):
|
515 |
+
"""Processor for the WNLI data set (GLUE version)."""
|
516 |
+
|
517 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
518 |
+
"""See base class."""
|
519 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
520 |
+
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
521 |
+
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
522 |
+
str(tensor_dict['label'].numpy()))
|
523 |
+
|
524 |
+
def get_train_examples(self, data_dir):
|
525 |
+
"""See base class."""
|
526 |
+
return self._create_examples(
|
527 |
+
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
528 |
+
|
529 |
+
def get_dev_examples(self, data_dir):
|
530 |
+
"""See base class."""
|
531 |
+
return self._create_examples(
|
532 |
+
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
533 |
+
|
534 |
+
def get_labels(self):
|
535 |
+
"""See base class."""
|
536 |
+
return ["0", "1"]
|
537 |
+
|
538 |
+
def _create_examples(self, lines, set_type):
|
539 |
+
"""Creates examples for the training and dev sets."""
|
540 |
+
examples = []
|
541 |
+
for (i, line) in enumerate(lines):
|
542 |
+
if i == 0:
|
543 |
+
continue
|
544 |
+
guid = "%s-%s" % (set_type, line[0])
|
545 |
+
text_a = line[1]
|
546 |
+
text_b = line[2]
|
547 |
+
label = line[-1]
|
548 |
+
examples.append(
|
549 |
+
InputExample(guid=guid, text_a='', text_b=text_a, label=label))
|
550 |
+
return examples
|
551 |
+
|
552 |
+
class PnliProcessor(DataProcessor):
|
553 |
+
"""Processor for the ConTRoL dataset (multi-sentence/paragraph/passage level). """
|
554 |
+
|
555 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
556 |
+
"""See base class."""
|
557 |
+
return InputExample(tensor_dict['context'].numpy().decode('utf-8'),
|
558 |
+
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
559 |
+
str(tensor_dict['label'].numpy()))
|
560 |
+
|
561 |
+
def get_train_examples(self, data_dir):
|
562 |
+
"""See base class."""
|
563 |
+
return self._create_examples(
|
564 |
+
self._read_txt(os.path.join(data_dir, "train.jsonl")), "train")
|
565 |
+
|
566 |
+
def get_dev_examples(self, data_dir):
|
567 |
+
"""See base class."""
|
568 |
+
return self._create_examples(
|
569 |
+
self._read_txt(os.path.join(data_dir, "dev.jsonl")), "dev")
|
570 |
+
|
571 |
+
def get_labels(self):
|
572 |
+
"""See base class."""
|
573 |
+
return ["c", "e", "n"]
|
574 |
+
|
575 |
+
def _read_txt(self, dir):
|
576 |
+
with open(dir, "r", encoding="utf-8") as f:
|
577 |
+
lines = []
|
578 |
+
for line in f.readlines():
|
579 |
+
if sys.version_info[0] == 2:
|
580 |
+
line = list(unicode(cell, 'utf-8') for cell in line)
|
581 |
+
lines.append(line)
|
582 |
+
return lines
|
583 |
+
|
584 |
+
def _create_examples(self, lines, set_type):
|
585 |
+
"""Creates examples for the training and dev sets."""
|
586 |
+
examples = []
|
587 |
+
for (i, line) in enumerate(lines):
|
588 |
+
dict_line = json.loads(line)
|
589 |
+
guid = "%s-%s" % (set_type, i)
|
590 |
+
label = dict_line['label']
|
591 |
+
text_a = dict_line['premise'].strip()
|
592 |
+
text_b = dict_line['hypothesis'].strip()
|
593 |
+
examples.append(
|
594 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
|
595 |
+
)
|
596 |
+
return examples
|
597 |
+
"""Below is the data reader for long/short segmentation of the ConTRoL data"""
|
598 |
+
# def get_train_examples(self, data_dir):
|
599 |
+
# """See base class."""
|
600 |
+
# return self._create_examples(
|
601 |
+
# self._read_tsv(os.path.join(data_dir, "short/train.tsv")), "train")
|
602 |
+
|
603 |
+
# def get_dev_examples(self, data_dir):
|
604 |
+
# """See base class."""
|
605 |
+
# return self._create_examples(
|
606 |
+
# self._read_tsv(os.path.join(data_dir, "short/dev.tsv")), "dev")
|
607 |
+
|
608 |
+
# def get_labels(self):
|
609 |
+
# """See base class."""
|
610 |
+
# return ["c", "e", "n"]
|
611 |
+
# def _create_examples(self, lines, set_type):
|
612 |
+
# """Creates examples for the training and dev sets."""
|
613 |
+
# examples = []
|
614 |
+
# for (i, line) in enumerate(lines):
|
615 |
+
# if i == 0:
|
616 |
+
# continue
|
617 |
+
# if len(line) == 3:
|
618 |
+
# guid = "%s-%s" % (set_type, line[0])
|
619 |
+
# text_a = line[0]
|
620 |
+
# text_b = line[1]
|
621 |
+
# label = line[-1][-1].lower()
|
622 |
+
|
623 |
+
# examples.append(
|
624 |
+
# InputExample(guid=guid, text_a=text_b, text_b=text_a, label=label))
|
625 |
+
# return examples
|
626 |
+
|
627 |
+
class Qa2nliProcessor(DataProcessor):
|
628 |
+
"""Processor for the logiqa2nli data set."""
|
629 |
+
|
630 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
631 |
+
"""See base class."""
|
632 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
633 |
+
tensor_dict['premise_par_new'].numpy().decode('utf-8'),
|
634 |
+
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
635 |
+
str(tensor_dict['label'].numpy()))
|
636 |
+
|
637 |
+
def get_train_examples(self, data_dir):
|
638 |
+
"""See base class."""
|
639 |
+
return self._create_examples(
|
640 |
+
self._read_txt(os.path.join(data_dir, "train.txt")), "train")
|
641 |
+
|
642 |
+
def get_dev_examples(self, data_dir):
|
643 |
+
"""See base class."""
|
644 |
+
return self._create_examples(
|
645 |
+
self._read_txt(os.path.join(data_dir, "dev.txt")), "dev")
|
646 |
+
|
647 |
+
def get_labels(self):
|
648 |
+
"""See base class."""
|
649 |
+
return ['entailed', 'not entailed']
|
650 |
+
|
651 |
+
def _read_txt(self, dir):
|
652 |
+
with open(dir, "r", encoding="utf-8") as f:
|
653 |
+
lines = []
|
654 |
+
for line in f.readlines():
|
655 |
+
if sys.version_info[0] == 2:
|
656 |
+
line = list(unicode(cell, 'utf-8') for cell in line)
|
657 |
+
lines.append(line)
|
658 |
+
return lines
|
659 |
+
|
660 |
+
def _create_examples(self, lines, set_type):
|
661 |
+
"""Creates examples for the training and dev sets."""
|
662 |
+
examples = []
|
663 |
+
for (i, line) in enumerate(lines):
|
664 |
+
dict_line = json.loads(line)
|
665 |
+
guid = "%s-%s" % (set_type, i)
|
666 |
+
label = dict_line['label']
|
667 |
+
text_a = "".join(_ for _ in dict_line['major_premise']) + " " + "".join(_ for _ in dict_line['minor_premise'])
|
668 |
+
text_a = text_a.strip()
|
669 |
+
text_b = dict_line['conclusion'].strip()
|
670 |
+
examples.append(
|
671 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
|
672 |
+
)
|
673 |
+
return examples
|
674 |
+
|
675 |
+
class SciProcessor(DataProcessor):
|
676 |
+
"""Processor for the SciTail data set."""
|
677 |
+
|
678 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
679 |
+
"""See base class."""
|
680 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
681 |
+
tensor_dict['premise'].numpy().decode('utf-8'),
|
682 |
+
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
683 |
+
str(tensor_dict['label'].numpy()))
|
684 |
+
|
685 |
+
def get_train_examples(self, data_dir):
|
686 |
+
"""See base class."""
|
687 |
+
return self._create_examples(
|
688 |
+
self._read_txt(os.path.join(data_dir, "snli_format/train.txt")), "train")
|
689 |
+
|
690 |
+
def get_dev_examples(self, data_dir):
|
691 |
+
"""See base class."""
|
692 |
+
return self._create_examples(
|
693 |
+
self._read_txt(os.path.join(data_dir, "snli_format/dev.txt")), "dev")
|
694 |
+
|
695 |
+
def get_labels(self):
|
696 |
+
"""See base class."""
|
697 |
+
return ["entailment", "neutral"]
|
698 |
+
|
699 |
+
def _read_txt(self, dir):
|
700 |
+
with open(dir, "r", encoding="utf-8") as f:
|
701 |
+
lines = []
|
702 |
+
for line in f.readlines():
|
703 |
+
if sys.version_info[0] == 2:
|
704 |
+
line = list(unicode(cell, 'utf-8') for cell in line)
|
705 |
+
lines.append(line)
|
706 |
+
return lines
|
707 |
+
|
708 |
+
def _create_examples(self, lines, set_type):
|
709 |
+
"""Creates examples for the training and dev sets."""
|
710 |
+
examples = []
|
711 |
+
for (i, line) in enumerate(lines):
|
712 |
+
dict_line = json.loads(line)
|
713 |
+
guid = "%s-%s" % (set_type, i)
|
714 |
+
label = dict_line['gold_label']
|
715 |
+
text_a = dict_line['sentence1'].strip()
|
716 |
+
text_b = dict_line['sentence2'].strip()
|
717 |
+
examples.append(
|
718 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
|
719 |
+
)
|
720 |
+
return examples
|
721 |
+
|
722 |
+
|
723 |
+
class AnliProcessor(DataProcessor):
|
724 |
+
"""Processor for the ANLI data set."""
|
725 |
+
|
726 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
727 |
+
"""See base class."""
|
728 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
729 |
+
tensor_dict['premise'].numpy().decode('utf-8'),
|
730 |
+
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
731 |
+
str(tensor_dict['label'].numpy()))
|
732 |
+
|
733 |
+
def get_train_examples(self, data_dir):
|
734 |
+
"""See base class."""
|
735 |
+
return self._create_examples(
|
736 |
+
self._read_txt(os.path.join(data_dir, "r3/train.jsonl")), "train")
|
737 |
+
|
738 |
+
def get_dev_examples(self, data_dir):
|
739 |
+
"""See base class."""
|
740 |
+
return self._create_examples(
|
741 |
+
self._read_txt(os.path.join(data_dir, "r3/dev.jsonl")), "dev")
|
742 |
+
|
743 |
+
def get_labels(self):
|
744 |
+
"""See base class."""
|
745 |
+
return ["e", "n", "c"]
|
746 |
+
|
747 |
+
def _read_txt(self, dir):
|
748 |
+
with open(dir, "r", encoding="utf-8") as f:
|
749 |
+
lines = []
|
750 |
+
for line in f.readlines():
|
751 |
+
if sys.version_info[0] == 2:
|
752 |
+
line = list(unicode(cell, 'utf-8') for cell in line)
|
753 |
+
lines.append(line)
|
754 |
+
return lines
|
755 |
+
|
756 |
+
def _create_examples(self, lines, set_type):
|
757 |
+
"""Creates examples for the training and dev sets."""
|
758 |
+
examples = []
|
759 |
+
for (i, line) in enumerate(lines):
|
760 |
+
dict_line = json.loads(line)
|
761 |
+
guid = "%s-%s" % (set_type, i)
|
762 |
+
label = dict_line['label']
|
763 |
+
text_a = dict_line['premise'].strip()
|
764 |
+
text_b = dict_line['hypothesis'].strip()
|
765 |
+
examples.append(
|
766 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
|
767 |
+
)
|
768 |
+
return examples
|
769 |
+
|
770 |
+
|
771 |
+
class QoodProcessor(DataProcessor):
|
772 |
+
"""Processor for the QNLI-ood data set."""
|
773 |
+
|
774 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
775 |
+
"""See base class."""
|
776 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
777 |
+
tensor_dict['premise'].numpy().decode('utf-8'),
|
778 |
+
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
779 |
+
str(tensor_dict['label'].numpy()))
|
780 |
+
|
781 |
+
def get_train_examples(self, data_dir):
|
782 |
+
"""See base class."""
|
783 |
+
return self._create_examples(
|
784 |
+
self._read_txt(os.path.join(data_dir, "train.txt")), "train")
|
785 |
+
|
786 |
+
def get_dev_examples(self, data_dir):
|
787 |
+
"""See base class."""
|
788 |
+
return self._create_examples(
|
789 |
+
self._read_txt(os.path.join(data_dir, "dev.txt")), "dev")
|
790 |
+
|
791 |
+
def get_labels(self):
|
792 |
+
"""See base class."""
|
793 |
+
return ["entailment", "not_entailment"]
|
794 |
+
|
795 |
+
def _read_txt(self, dir):
|
796 |
+
with open(dir, "r", encoding="utf-8") as f:
|
797 |
+
lines = []
|
798 |
+
for line in f.readlines():
|
799 |
+
if sys.version_info[0] == 2:
|
800 |
+
line = list(unicode(cell, 'utf-8') for cell in line)
|
801 |
+
lines.append(line)
|
802 |
+
return lines
|
803 |
+
|
804 |
+
def _create_examples(self, lines, set_type):
|
805 |
+
"""Creates examples for the training and dev sets."""
|
806 |
+
examples = []
|
807 |
+
for (i, line) in enumerate(lines):
|
808 |
+
dict_line = json.loads(line)
|
809 |
+
guid = "%s-%s" % (set_type, i)
|
810 |
+
label = dict_line['label']
|
811 |
+
text_a = dict_line['question'].strip()
|
812 |
+
text_b = dict_line['sentence'].strip()
|
813 |
+
examples.append(
|
814 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
|
815 |
+
)
|
816 |
+
return examples
|
817 |
+
|
818 |
+
class MrpcProcessor(DataProcessor):
|
819 |
+
"""Processor for the MRPC data set (GLUE version). <Paraphrase>"""
|
820 |
+
|
821 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
822 |
+
"""See base class."""
|
823 |
+
return InputExample(tensor_dict['idx'].numpy(),
|
824 |
+
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
825 |
+
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
826 |
+
str(tensor_dict['label'].numpy()))
|
827 |
+
|
828 |
+
def get_train_examples(self, data_dir):
|
829 |
+
"""See base class."""
|
830 |
+
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "short/train.tsv")))
|
831 |
+
return self._create_examples(
|
832 |
+
self._read_tsv(os.path.join(data_dir, "short/train.tsv")), "train")
|
833 |
+
|
834 |
+
def get_dev_examples(self, data_dir):
|
835 |
+
"""See base class."""
|
836 |
+
return self._create_examples(
|
837 |
+
self._read_tsv(os.path.join(data_dir, "short/dev.tsv")), "dev")
|
838 |
+
|
839 |
+
def get_labels(self):
|
840 |
+
"""See base class."""
|
841 |
+
return ["0", "1"]
|
842 |
+
|
843 |
+
def _create_examples(self, lines, set_type):
|
844 |
+
"""Creates examples for the training and dev sets."""
|
845 |
+
examples = []
|
846 |
+
for (i, line) in enumerate(lines):
|
847 |
+
if i == 0:
|
848 |
+
continue
|
849 |
+
guid = "%s-%s" % (set_type, i)
|
850 |
+
text_a = line[3]
|
851 |
+
text_b = line[4]
|
852 |
+
label = line[0]
|
853 |
+
examples.append(
|
854 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
855 |
+
return examples
|
856 |
+
|
857 |
+
|
858 |
+
|
859 |
+
try:
|
860 |
+
from scipy.stats import pearsonr, spearmanr
|
861 |
+
from sklearn.metrics import matthews_corrcoef, f1_score, confusion_matrix
|
862 |
+
|
863 |
+
_has_sklearn = True
|
864 |
+
except (AttributeError, ImportError):
|
865 |
+
_has_sklearn = False
|
866 |
+
|
867 |
+
|
868 |
+
def is_sklearn_available():
|
869 |
+
return _has_sklearn
|
870 |
+
|
871 |
+
|
872 |
+
#if _has_sklearn:
|
873 |
+
|
874 |
+
def simple_accuracy(preds, labels):
|
875 |
+
return (preds == labels).mean()
|
876 |
+
|
877 |
+
def acc_and_f1(preds, labels):
|
878 |
+
acc = simple_accuracy(preds, labels)
|
879 |
+
f1 = f1_score(y_true=labels, y_pred=preds)
|
880 |
+
return {
|
881 |
+
"acc": acc,
|
882 |
+
"f1": f1,
|
883 |
+
"acc_and_f1": (acc + f1) / 2,
|
884 |
+
}
|
885 |
+
|
886 |
+
def pearson_and_spearman(preds, labels):
|
887 |
+
pearson_corr = pearsonr(preds, labels)[0]
|
888 |
+
spearman_corr = spearmanr(preds, labels)[0]
|
889 |
+
return {
|
890 |
+
"pearson": pearson_corr,
|
891 |
+
"spearmanr": spearman_corr,
|
892 |
+
"corr": (pearson_corr + spearman_corr) / 2,
|
893 |
+
}
|
894 |
+
|
895 |
+
def compute_metrics(task_name, preds, labels):
|
896 |
+
assert len(preds) == len(labels)
|
897 |
+
if task_name == "cola":
|
898 |
+
return {"mcc": matthews_corrcoef(labels, preds)}
|
899 |
+
elif task_name == "cood":
|
900 |
+
return {"confusion matrix": confusion_matrix(preds, labels), "mcc": matthews_corrcoef(labels, preds), "f1 score": acc_and_f1(preds, labels)}
|
901 |
+
elif task_name == "sst-2":
|
902 |
+
return {"acc": simple_accuracy(preds, labels)}
|
903 |
+
elif task_name == "mrpc":
|
904 |
+
return acc_and_f1(preds, labels)
|
905 |
+
elif task_name == "sts-b":
|
906 |
+
return pearson_and_spearman(preds, labels)
|
907 |
+
elif task_name == "qqp":
|
908 |
+
return acc_and_f1(preds, labels)
|
909 |
+
elif task_name == "mnli":
|
910 |
+
return {"acc": simple_accuracy(preds, labels)}
|
911 |
+
elif task_name == "mnli-mm":
|
912 |
+
return {"acc": simple_accuracy(preds, labels)}
|
913 |
+
elif task_name == "qnli":
|
914 |
+
return {"acc": simple_accuracy(preds, labels)}
|
915 |
+
elif task_name == "rte":
|
916 |
+
return {"acc": simple_accuracy(preds, labels)}
|
917 |
+
elif task_name == "wnli":
|
918 |
+
return {"acc": simple_accuracy(preds, labels)}
|
919 |
+
elif task_name == "hans":
|
920 |
+
return {"acc": simple_accuracy(preds, labels)}
|
921 |
+
elif task_name == "scitail":
|
922 |
+
return {"acc": simple_accuracy(preds, labels)}
|
923 |
+
elif task_name == "snli":
|
924 |
+
return {"acc": simple_accuracy(preds, labels)}
|
925 |
+
elif task_name == "qa2nli":
|
926 |
+
return {"confusion matrix": confusion_matrix(preds, labels), "mcc": matthews_corrcoef(labels, preds), "f1 score": acc_and_f1(preds, labels)}
|
927 |
+
elif task_name == "anli":
|
928 |
+
return {"acc": simple_accuracy(preds, labels)}
|
929 |
+
elif task_name == "pnli":
|
930 |
+
return {"acc": simple_accuracy(preds, labels)}
|
931 |
+
elif task_name == "qood":
|
932 |
+
return {"acc": simple_accuracy(preds, labels)}
|
933 |
+
else:
|
934 |
+
raise KeyError(task_name)
|
935 |
+
|
936 |
+
def xnli_compute_metrics(task_name, preds, labels):
|
937 |
+
assert len(preds) == len(labels)
|
938 |
+
if task_name == "xnli":
|
939 |
+
return {"acc": simple_accuracy(preds, labels)}
|
940 |
+
else:
|
941 |
+
raise KeyError(task_name)
|
942 |
+
|
943 |
+
|
944 |
+
|
945 |
+
tasks_num_labels = {
|
946 |
+
"pnli": 3,
|
947 |
+
"cola": 2,
|
948 |
+
"cood": 2,
|
949 |
+
"snli": 3,
|
950 |
+
"mnli": 3,
|
951 |
+
"mrpc": 2,
|
952 |
+
"sst-2": 2,
|
953 |
+
"sts-b": 1,
|
954 |
+
"qqp": 2,
|
955 |
+
"qnli": 2,
|
956 |
+
"rte": 2,
|
957 |
+
"wnli": 2,
|
958 |
+
"qa2nli": 2,
|
959 |
+
"scitail": 2,
|
960 |
+
"anli": 3,
|
961 |
+
"qood": 2,
|
962 |
+
}
|
963 |
+
|
964 |
+
processors = {
|
965 |
+
"cola": ColaProcessor,
|
966 |
+
"cood": CoodProcessor,
|
967 |
+
"snli": SnliProcessor,
|
968 |
+
"mnli": MnliProcessor,
|
969 |
+
"mnli-mm": MnliMismatchedProcessor,
|
970 |
+
"mrpc": MrpcProcessor,
|
971 |
+
"sst-2": Sst2Processor,
|
972 |
+
"sts-b": StsbProcessor,
|
973 |
+
"qqp": QqpProcessor,
|
974 |
+
"qnli": QnliProcessor,
|
975 |
+
"rte": RteProcessor,
|
976 |
+
"wnli": WnliProcessor,
|
977 |
+
"pnli": PnliProcessor,
|
978 |
+
"qa2nli": Qa2nliProcessor,
|
979 |
+
"scitail": SciProcessor,
|
980 |
+
"anli": AnliProcessor,
|
981 |
+
"qood": QoodProcessor,
|
982 |
+
}
|
983 |
+
|
984 |
+
output_modes = {
|
985 |
+
"cola": "classification",
|
986 |
+
"cood": "classification",
|
987 |
+
"mnli": "classification",
|
988 |
+
"mnli-mm": "classification",
|
989 |
+
"mrpc": "classification",
|
990 |
+
"sst-2": "classification",
|
991 |
+
"sts-b": "regression",
|
992 |
+
"qqp": "classification",
|
993 |
+
"qnli": "classification",
|
994 |
+
"rte": "classification",
|
995 |
+
"wnli": "classification",
|
996 |
+
"pnli": "classification",
|
997 |
+
"qa2nli": "classification",
|
998 |
+
"scitail": "classification",
|
999 |
+
"snli": "classification",
|
1000 |
+
"anli": "classification",
|
1001 |
+
"qood": "classification",
|
1002 |
+
}
|
datasets/LogiQA2.0/requirements.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: logiqa
|
2 |
+
dependencies:
|
3 |
+
- nvidia::cudatoolkit=10.2.89
|
4 |
+
- numpy
|
5 |
+
- pillow
|
6 |
+
- pip
|
7 |
+
- python=3.6
|
8 |
+
- pytorch::pytorch=1.10.2=py3.6_cuda11.7_cudnn8.0.5_0
|
9 |
+
- scipy
|
10 |
+
- tqdm
|
11 |
+
- scikit-learn
|
12 |
+
- tensorboard
|
13 |
+
- tensorboardX
|
14 |
+
- pip:
|
15 |
+
- transformers==2.4.1
|
16 |
+
- nltk
|
17 |
+
- wandb
|