PANH committed on
Commit
ffca110
1 Parent(s): 2d8296f

Upload 15 files

alignscore/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 yuh-zha
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
alignscore/README.md ADDED
@@ -0,0 +1,216 @@
1
+ # AlignScore
2
+ This is the repository for AlignScore, a metric for automatic factual consistency evaluation of text pairs introduced in \
3
+ [AlignScore: Evaluating Factual Consistency with a Unified Alignment Function](https://arxiv.org/abs/2305.16739) \
4
+ Yuheng Zha, Yichi Yang, Ruichen Li and Zhiting Hu \
5
+ ACL 2023
6
+
7
+ **Factual consistency evaluation** is the task of evaluating whether all the information in **b** is contained in **a** (**b** does not contradict **a**). For example, this is a factually inconsistent case:
8
+
9
+ * **a**: Children smiling and waving at camera.
10
+ * **b**: The kids are frowning.
11
+
12
+ And this is a factually consistent case:
13
+
14
+ * **a**: The NBA season of 1975 -- 76 was the 30th season of the National Basketball Association.
15
+ * **b**: The 1975 -- 76 season of the National Basketball Association was the 30th season of the NBA.
16
+
17
+ Factual consistency evaluation can be applied to many tasks such as summarization, paraphrasing, and dialog. For example, large language models often generate hallucinations when summarizing documents, and we want to know whether the generated text is factually consistent with its original context.
18
+
19
+ # Leaderboards
20
+ We introduce two leaderboards that compare AlignScore with similar-sized metrics and LLM-based metrics, respectively.
21
+ ## Leaderboard --- compare with similar-sized metrics
22
+
23
+ We list the performance of AlignScore and other metrics on the SummaC benchmark (6 datasets), the TRUE benchmark (11 datasets), and other popular factual consistency datasets (6 datasets).
24
+
25
+ | Rank | Metrics | SummaC* | TRUE** | Other Datasets*** | Average**** | Paper | Code |
26
+ | ---- | :--------------- | :-----: | :----: | :------------: | :-----: | :---: | :--: |
27
+ | 1 | **AlignScore-large** | 88.6 | 83.8 | 49.3 | 73.9 | [:page\_facing\_up:(Zha et al. 2023)](https://arxiv.org/pdf/2305.16739.pdf) | [:octocat:](https://github.com/yuh-zha/AlignScore) |
28
+ | 2 | **AlignScore-base** | 87.4 | 82.5 | 44.9 | 71.6 | [:page\_facing\_up:(Zha et al. 2023)](https://arxiv.org/pdf/2305.16739.pdf) | [:octocat:](https://github.com/yuh-zha/AlignScore) |
29
+ | 3 | QAFactEval | 83.8 | 79.4 | 42.4 | 68.5 | [:page\_facing\_up:(Fabbri et al. 2022)](https://arxiv.org/abs/2112.08542) | [:octocat:](https://github.com/salesforce/QAFactEval) |
30
+ | 4 | UniEval | 84.6 | 78.0 | 41.5 | 68.0 | [:page\_facing\_up:(Zhong et al. 2022)](https://arxiv.org/abs/2210.07197) | [:octocat:](https://github.com/maszhongming/UniEval) |
31
+ | 5 | SummaC-CONV | 81.0 | 78.7 | 34.2 | 64.6 | [:page\_facing\_up:(Laban et al. 2022)](https://arxiv.org/abs/2111.09525) | [:octocat:](https://github.com/tingofurro/summac) |
32
+ | 6 | BARTScore | 80.9 | 73.4 | 34.8 | 63.0 | [:page\_facing\_up:(Yuan et al. 2022)](https://arxiv.org/abs/2106.11520) | [:octocat:](https://github.com/neulab/BARTScore) |
33
+ | 7 | CTC | 81.2 | 72.4 | 35.3 | 63.0 | [:page\_facing\_up:(Deng et al. 2022)](https://arxiv.org/abs/2109.06379) | [:octocat:](https://github.com/tanyuqian/ctc-gen-eval) |
34
+ | 8 | SummaC-ZS | 79.0 | 78.2 | 30.4 | 62.5 | [:page\_facing\_up:(Laban et al. 2022)](https://arxiv.org/abs/2111.09525) | [:octocat:](https://github.com/tingofurro/summac) |
35
+ | 9 | ROUGE-2 | 78.1 | 72.4 | 27.9 | 59.5 | [:page\_facing\_up:(Lin 2004)](https://aclanthology.org/W04-1013/) | [:octocat:](https://github.com/pltrdy/rouge) |
36
+ | 10 | ROUGE-1 | 77.4 | 72.0 | 28.6 | 59.3 | [:page\_facing\_up:(Lin 2004)](https://aclanthology.org/W04-1013/) | [:octocat:](https://github.com/pltrdy/rouge) |
37
+ | 11 | ROUGE-L | 77.3 | 71.8 | 28.3 | 59.1 | [:page\_facing\_up:(Lin 2004)](https://aclanthology.org/W04-1013/) | [:octocat:](https://github.com/pltrdy/rouge) |
38
+ | 12 | QuestEval | 72.5 | 71.4 | 25.0 | 56.3 | [:page\_facing\_up:(Scialom et al. 2021)](https://arxiv.org/abs/2103.12693) | [:octocat:](https://github.com/ThomasScialom/QuestEval) |
39
+ | 13 | BLEU | 76.3 | 67.3 | 24.6 | 56.1 | [:page\_facing\_up:(Papineni et al. 2002)](https://aclanthology.org/P02-1040/) | [:octocat:](https://www.nltk.org/_modules/nltk/translate/bleu_score.html) |
40
+ | 14 | DAE | 66.8 | 65.7 | 35.1 | 55.8 | [:page\_facing\_up:(Goyal and Durrett 2020)](https://aclanthology.org/2020.findings-emnlp.322/) | [:octocat:](https://github.com/tagoyal/dae-factuality) |
41
+ | 15 | BLEURT | 69.2 | 71.9 | 24.9 | 55.4 | [:page\_facing\_up:(Sellam et al. 2020)](https://arxiv.org/abs/2004.04696) | [:octocat:](https://github.com/google-research/bleurt) |
42
+ | 16 | BERTScore | 72.1 | 68.6 | 21.9 | 54.2 | [:page\_facing\_up:(Zhang et al. 2020)](https://arxiv.org/abs/1904.09675) | [:octocat:](https://github.com/Tiiiger/bert_score) |
43
+ | 17 | SimCSE | 67.4 | 70.3 | 23.8 | 53.8 | [:page\_facing\_up:(Gao et al. 2021)](https://arxiv.org/abs/2104.08821) | [:octocat:](https://github.com/princeton-nlp/SimCSE) |
44
+ | 18 | FactCC | 68.8 | 62.7 | 21.2 | 50.9 | [:page\_facing\_up:(Kryscinski et al. 2020)](https://arxiv.org/abs/1910.12840) | [:octocat:](https://github.com/salesforce/factCC) |
45
+ | 19 | BLANC | 65.1 | 64.0 | 14.4 | 47.8 | [:page\_facing\_up:(Vasilyev et al. 2020)](https://arxiv.org/abs/2002.09836) | [:octocat:](https://github.com/PrimerAI/blanc) |
46
+ | 20 | NER-Overlap | 60.4 | 59.3 | 18.9 | 46.2 | [:page\_facing\_up:(Laban et al. 2022)](https://arxiv.org/abs/2111.09525) | [:octocat:](https://github.com/tingofurro/summac) |
47
+ | 21 | MNLI | 47.9 | 60.4 | 3.1 | 37.2 | [:page\_facing\_up:(Williams et al. 2018)](https://arxiv.org/abs/1704.05426) | [:octocat:](https://github.com/nyu-mll/multiNLI) |
48
+ | 22 | FEQA | 48.3 | 52.2 | -1.9 | 32.9 | [:page\_facing\_up:(Durmus et al. 2020)](https://arxiv.org/abs/2005.03754) | [:octocat:](https://github.com/esdurmus/feqa) |
49
+
50
+ \* SummaC Benchmark: [\[Paper\]](https://arxiv.org/abs/2111.09525) \| [\[Github\]](https://github.com/tingofurro/summac). We report AUC ROC on the SummaC benchmark.
51
+
52
+ ** TRUE Benchmark: [\[Paper\]](https://arxiv.org/abs/2204.04991) \| [\[Github\]](https://github.com/google-research/true). We report AUC ROC on the TRUE benchmark.
53
+
54
+ *** Besides the SummaC and TRUE benchmarks, we also include other popular factual consistency evaluation datasets: [XSumFaith](https://doi.org/10.18653/v1/2020.acl-main.173), [SummEval](https://doi.org/10.1162/tacl_a_00373), [QAGS-XSum](https://doi.org/10.18653/v1/2020.acl-main.450), [QAGS-CNNDM](https://doi.org/10.18653/v1/2020.acl-main.450), [FRANK-XSum](https://doi.org/10.18653/v1/2021.naacl-main.383), [FRANK-CNNDM](https://doi.org/10.18653/v1/2021.naacl-main.383) and [SamSum](https://doi.org/10.18653/v1/D19-5409). We compute the Spearman correlation coefficients between the human-annotated scores and the metric-predicted scores, following common practice.
55
+
56
+ **** To rank these metrics, we simply average their performance on SummaC, TRUE and the other datasets.
57
+
58
+ ## Leaderboard --- compare with LLM-based metrics
59
+
60
+ We also show a performance comparison with large-language-model-based metrics below. The rank is based on the average Spearman correlation coefficient on the SummEval, QAGS-XSum and QAGS-CNNDM datasets.*
61
+
62
+ | Rank | Metrics | Base Model | SummEval | QAGS-XSUM | QAGS-CNNDM | Average | Paper | Code |
63
+ | :--- | :-------------------- | :----------------------------------------------------------- | :------: | :-------: | :--------: | :--: | :----------------------------------------------------------: | :----------------------------------------------------------: |
64
+ | 1 | **AlignScore-large** | RoBERTa-l (355M) | 46.6 | 57.2 | 73.9 | 59.3 | [:page\_facing\_up:(Zha et al. 2023)](https://arxiv.org/pdf/2305.16739.pdf) | [:octocat:](https://github.com/yuh-zha/AlignScore) |
65
+ | 2 | G-EVAL-4 | GPT4 | 50.7 | 53.7 | 68.5 | 57.6 | [:page\_facing\_up:(Liu et al. 2023)](https://arxiv.org/pdf/2303.16634.pdf) | [:octocat:](https://github.com/nlpyang/geval) |
66
+ | 3 | **AlignScore-base** | RoBERTa-b (125M) | 43.4 | 51.9 | 69.0 | 54.8 | [:page\_facing\_up:(Zha et al. 2023)](https://arxiv.org/pdf/2305.16739.pdf) | [:octocat:](https://github.com/yuh-zha/AlignScore) |
67
+ | 4 | FActScore (modified)** | GPT3.5-d03 + GPT3.5-turbo | 52.6 | 51.2 | 57.6 | 53.8 | [:page\_facing\_up:(Min et al. 2023)](https://arxiv.org/pdf/2305.14251.pdf) | [:octocat:](https://github.com/shmsw25/FActScore)* |
68
+ | 5 | ChatGPT (Chen et al. 2023) | GPT3.5-turbo | 42.7 | 53.3 | 52.7 | 49.6 | [:page\_facing\_up:(Yi Chen et al. 2023)](https://arxiv.org/pdf/2305.14069.pdf) | [:octocat:](https://github.com/SJTU-LIT/llmeval_sum_factual) |
69
+ | 6 | GPTScore | GPT3.5-d03 | 45.9 | 22.7 | 64.4 | 44.3 | [:page\_facing\_up:(Fu et al. 2023)](https://arxiv.org/pdf/2302.04166.pdf) | [:octocat:](https://github.com/jinlanfu/GPTScore) |
70
+ | 7 | GPTScore | GPT3-d01 | 46.1 | 22.3 | 63.9 | 44.1 | [:page\_facing\_up:(Fu et al. 2023)](https://arxiv.org/pdf/2302.04166.pdf) | [:octocat:](https://github.com/jinlanfu/GPTScore) |
71
+ | 8 | G-EVAL-3.5 | GPT3.5-d03 | 38.6 | 40.6 | 51.6 | 43.6 | [:page\_facing\_up:(Liu et al. 2023)](https://arxiv.org/pdf/2303.16634.pdf) | [:octocat:](https://github.com/nlpyang/geval) |
72
+ | 9 | ChatGPT (Gao et al. 2023) | GPT3.5-turbo | 41.6 | 30.4 | 48.9 | 40.3 | [:page\_facing\_up:(Gao et al. 2023)](https://arxiv.org/pdf/2304.02554.pdf) | - |
73
+
74
+ \* We note that evaluating factual consistency with GPT-based models is expensive and slow, and human labor is needed to map the responses (generally free text) to numerical scores. Therefore, we only benchmark on 3 popular factual consistency evaluation datasets: SummEval, QAGS-XSum and QAGS-CNNDM.
75
+
76
+ *\* We use a modified version of FActScore `retrieval+ChatGPT` where we skip the retrieval stage and use the context documents in SummEval, QAGS-XSUM, and QAGS-CNNDM directly. As samples in these datasets do not have "topics", we make a small modification to the original FActScore prompt and do not mention `topic` when it is not available. See [our fork of FActScore](https://github.com/yichi-yang/FActScore) for more details.
77
+
78
+ # Introduction
79
+
80
+ The AlignScore metric is an automatic factual consistency evaluation metric built from the following parts:
81
+
82
+ * A unified information alignment function between two arbitrary text pieces: it is trained on 4.7 million training examples from 7 well-established tasks (NLI, QA, paraphrasing, fact verification, information retrieval, semantic textual similarity and summarization).
83
+
84
+ * The chunk-sentence splitting method: the input context is split into chunks (roughly 350 tokens each) and the input claim is split into sentences. The alignment function gives an alignment score between each chunk and each claim sentence. We take the maximum alignment score for each sentence and then average these maxima to get the example-level factual consistency score (AlignScore); see the sketch at the end of this section.
85
+
86
+ <div align=center>
87
+ <img src="./alignscore_fig.png" alt="alignscore_fig" width="500px" />
88
+ </div>
89
+
90
+ The metric takes two inputs, `context` and `claim`, and evaluates whether the `claim` is factually consistent with the `context`. The output of AlignScore is a single numerical value indicating the degree of factual consistency.
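+
+ The following is a minimal sketch of the chunk-sentence aggregation described above. It assumes a generic `align(chunk, sentence)` callable standing in for the trained alignment function and takes pre-split chunks and sentences as input (the actual implementation chunks by tokens and splits sentences with spaCy):
+
+ ```python
+ from typing import Callable, List
+
+ def aggregate_align_score(context_chunks: List[str], claim_sentences: List[str],
+                           align: Callable[[str, str], float]) -> float:
+     """Max over context chunks for each claim sentence, then mean over sentences."""
+     per_sentence_max = [
+         max(align(chunk, sent) for chunk in context_chunks)
+         for sent in claim_sentences
+     ]
+     return sum(per_sentence_max) / len(per_sentence_max)
+ ```
+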
91
+ # Installation
92
+
93
+ Our models are trained and evaluated using PyTorch 1.12.1. We recommend using this version to reproduce the results.
94
+
95
+ 1. Please first install the right version of PyTorch before installing `alignscore`.
96
+ 2. You can install `alignscore` by cloning this repository and running `pip install .`.
97
+ 3. After installing `alignscore`, please run `python -m spacy download en_core_web_sm` to install the required spaCy model (we use spaCy for sentence splitting).
98
+
99
+ # Evaluating Factual Consistency
100
+ To evaluate the factual consistency of the `claim` w.r.t. the `context`, simply use the `score` method of `AlignScore`.
101
+ ```python
102
+ from alignscore import AlignScore
103
+
104
+ scorer = AlignScore(model='roberta-base', batch_size=32, device='cuda:0', ckpt_path='/path/to/checkpoint', evaluation_mode='nli_sp')
105
+ score = scorer.score(contexts=['hello world.'], claims=['hello world.'])
106
+ ```
107
+ `model`: the backbone model of the metric. Currently, we only provide checkpoints trained on RoBERTa.
108
+
109
+ `batch_size`: the batch size for inference
110
+
111
+ `device`: the device on which to run the metric
112
+
113
+ `ckpt_path`: the path to the checkpoint
114
+
115
+ `evaluation_mode`: choose from `'nli_sp', 'nli', 'bin_sp', 'bin'`. `nli` and `bin` refer to the 3-way and binary classification head, respectively. `sp` indicates whether the chunk-sentence splitting method is used. `nli_sp` is the default setting of AlignScore.
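+
+ As a slightly larger usage sketch mirroring the snippet above (the checkpoint path is a placeholder; point it at a checkpoint downloaded from the section below), several context-claim pairs can be scored at once:
+
+ ```python
+ from alignscore import AlignScore
+
+ scorer = AlignScore(model='roberta-large', batch_size=16, device='cuda:0',
+                     ckpt_path='/path/to/AlignScore-large.ckpt', evaluation_mode='nli_sp')
+
+ contexts = ['The 1975-76 NBA season was the 30th season of the National Basketball Association.'] * 2
+ claims = ['The 1975-76 season was the 30th season of the NBA.',
+           'The 1975-76 season was the first season of the NBA.']
+
+ scores = scorer.score(contexts=contexts, claims=claims)  # one score per context-claim pair
+ print(scores)
+ ```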
116
+
117
+
118
+ # Checkpoints
119
+ We provide two versions of the AlignScore checkpoints: `AlignScore-base` and `AlignScore-large`. The `-base` model is based on RoBERTa-base and has 125M parameters. The `-large` model is based on RoBERTa-large and has 355M parameters.
120
+
121
+ **AlignScore-base**:
122
+ https://huggingface.co/yzha/AlignScore/resolve/main/AlignScore-base.ckpt
123
+
124
+ **AlignScore-large**:
125
+ https://huggingface.co/yzha/AlignScore/resolve/main/AlignScore-large.ckpt
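+
+ Instead of downloading a checkpoint manually, one option (an assumption, not part of this repository) is to fetch it with the `huggingface_hub` library and pass the resulting local path as `ckpt_path`:
+
+ ```python
+ # Assumes `pip install huggingface_hub`; not required by alignscore itself.
+ from huggingface_hub import hf_hub_download
+
+ ckpt_path = hf_hub_download(repo_id='yzha/AlignScore', filename='AlignScore-base.ckpt')
+ print(ckpt_path)  # pass this path as ckpt_path when constructing AlignScore
+ ```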
126
+
127
+ # Training
128
+ You can use the above checkpoints directly for factual consistency evaluation. However, if you wish to train an alignment model from scratch or on your own data, use `train.py`.
129
+ ```bash
130
+ python train.py --seed 2022 --batch-size 32 \
131
+ --num-epoch 3 --devices 0 1 2 3 \
132
+ --model-name roberta-large --ckpt-save-path ./ckpt/ \
133
+ --data-path ./data/training_sets/ \
134
+ --max-samples-per-dataset 500000
135
+ ```
136
+
137
+ `--seed`: the random seed for initialization
138
+
139
+ `--batch-size`: the batch size for training
140
+
141
+ `--num-epoch`: training epochs
142
+
143
+ `--devices`: the GPU ids on which to train the metric
144
+
145
+ `--model-name`: the backbone model name of the metric, default RoBERTa-large
146
+
147
+ `--ckpt-save-path`: the path to save the checkpoint
148
+
149
+ `--training-datasets`: the names of the training datasets
150
+
151
+ `--data-path`: the path to the training datasets
152
+
153
+ `--max-samples-per-dataset`: the maximum number of samples from a dataset
154
+
155
+ # Benchmarking
156
+ Our benchmark includes the TRUE and SummaC benchmarks as well as several popular factual consistency evaluation datasets.
157
+
158
+ To run the benchmark, a few additional dependencies are required and can be installed with `pip install -r requirements.txt`.
159
+ Additionally, some dependencies are not available as packages and need to be downloaded manually (please see `python benchmark.py --help` for instructions).
160
+
161
+ Note that installing `summac` may cause dependency conflicts with `alignscore`. Please reinstall `alignscore` afterwards to force the correct dependency versions.
162
+
163
+ The relevant arguments for evaluating AlignScore are:
164
+
165
+ `--alignscore`: evaluate the AlignScore metric
166
+
167
+ `--alignscore-model`: the name of the backbone model (either 'roberta-base' or 'roberta-large')
168
+
169
+ `--alignscore-ckpt`: the path to the saved checkpoint
170
+
171
+ `--alignscore-eval-mode`: the evaluation mode, defaults to `nli_sp`
172
+
173
+ `--device`: which device to run the metric, defaults to `cuda:0`
174
+
175
+ `--tasks`: which tasks to benchmark, e.g., SummEval, QAGS-CNNDM, ...
176
+
177
+ For the baselines, please see `python benchmark.py --help` for details.
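+
+ Putting these arguments together, a typical AlignScore run might look like the following (the checkpoint path is a placeholder, and the exact task identifiers should be checked against `python benchmark.py --help`):
+
+ ```bash
+ python benchmark.py --alignscore \
+     --alignscore-model roberta-large \
+     --alignscore-ckpt ./ckpt/AlignScore-large.ckpt \
+     --alignscore-eval-mode nli_sp \
+     --device cuda:0 \
+     --tasks SummEval QAGS-CNNDM
+ ```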
178
+
179
+ ## Training datasets download
180
+ Most datasets are downloadable from Hugging Face (refer to [`generate_training_data.py`](https://github.com/yuh-zha/AlignScore/blob/main/generate_training_data.py)). Some datasets that previously needed to be imported manually are now also available on Hugging Face (see [this issue](https://github.com/yuh-zha/AlignScore/issues/6#issuecomment-1695448614)).
181
+
182
+ ## Evaluation datasets download
183
+
184
+ The following table shows the links to the evaluation datasets mentioned in the paper.
185
+
186
+ | Benchmark/Dataset | Link |
187
+ | ----------------- | ------------------------------------------------------------ |
188
+ | SummaC | https://github.com/tingofurro/summac |
189
+ | TRUE | https://github.com/google-research/true |
190
+ | XSumFaith | https://github.com/google-research-datasets/xsum_hallucination_annotations |
191
+ | SummEval | https://github.com/tanyuqian/ctc-gen-eval/blob/master/train/data/summeval.json |
192
+ | QAGS-Xsum | https://github.com/tanyuqian/ctc-gen-eval/blob/master/train/data/qags_xsum.json |
193
+ | QAGS-CNNDM | https://github.com/tanyuqian/ctc-gen-eval/blob/master/train/data/qags_cnndm.json |
194
+ | FRANK-XSum | https://github.com/artidoro/frank |
195
+ | FRANK-CNNDM | https://github.com/artidoro/frank |
196
+ | SamSum | https://github.com/skgabriel/GoFigure/blob/main/human_eval/samsum.jsonl |
197
+
198
+ # Citation
199
+ If you find the metric and this repo helpful, please consider citing:
200
+ ```
201
+ @inproceedings{zha-etal-2023-alignscore,
202
+ title = "{A}lign{S}core: Evaluating Factual Consistency with A Unified Alignment Function",
203
+ author = "Zha, Yuheng and
204
+ Yang, Yichi and
205
+ Li, Ruichen and
206
+ Hu, Zhiting",
207
+ booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
208
+ month = jul,
209
+ year = "2023",
210
+ address = "Toronto, Canada",
211
+ publisher = "Association for Computational Linguistics",
212
+ url = "https://aclanthology.org/2023.acl-long.634",
213
+ pages = "11328--11348",
214
+ abstract = "Many text generation applications require the generated text to be factually consistent with input information. Automatic evaluation of factual consistency is challenging. Previous work has developed various metrics that often depend on specific functions, such as natural language inference (NLI) or question answering (QA), trained on limited data. Those metrics thus can hardly assess diverse factual inconsistencies (e.g., contradictions, hallucinations) that occur in varying inputs/outputs (e.g., sentences, documents) from different tasks. In this paper, we propose AlignScore, a new holistic metric that applies to a variety of factual inconsistency scenarios as above. AlignScore is based on a general function of information alignment between two arbitrary text pieces. Crucially, we develop a unified training framework of the alignment function by integrating a large diversity of data sources, resulting in 4.7M training examples from 7 well-established tasks (NLI, QA, paraphrasing, fact verification, information retrieval, semantic similarity, and summarization). We conduct extensive experiments on large-scale benchmarks including 22 evaluation datasets, where 19 of the datasets were never seen in the alignment training. AlignScore achieves substantial improvement over a wide range of previous metrics. Moreover, AlignScore (355M parameters) matches or even outperforms metrics based on ChatGPT and GPT-4 that are orders of magnitude larger.",
215
+ }
216
+ ```
alignscore/alignscore_fig.png ADDED
alignscore/baselines.py ADDED
@@ -0,0 +1,704 @@
1
+ from logging import warning
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import spacy
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from nltk.tokenize import sent_tokenize
9
+ import json
+ import openai  # used by the ChatGPT-based scorers below
10
+
11
+ class CTCScorer():
12
+ def __init__(self, model_type) -> None:
13
+ self.model_type = model_type
14
+ import nltk
15
+ nltk.download('stopwords')
16
+
17
+ from ctc_score import StyleTransferScorer, SummarizationScorer, DialogScorer
18
+ if model_type == 'D-cnndm':
19
+ self.scorer = SummarizationScorer(align='D-cnndm')
20
+ elif model_type =='E-roberta':
21
+ self.scorer = SummarizationScorer(align='E-roberta')
22
+ elif model_type == 'R-cnndm':
23
+ self.scorer = SummarizationScorer(align='R-cnndm')
24
+ def score(self, premise: list, hypo: list):
25
+ assert len(premise) == len(hypo), "Premise and hypothesis should have the same length"
26
+
27
+ output_scores = []
28
+ for one_pre, one_hypo in tqdm(zip(premise, hypo), total=len(premise), desc="Evaluating by ctc"):
29
+ score_for_this_example = self.scorer.score(doc=one_pre, refs=[], hypo=one_hypo, aspect='consistency')
30
+ if score_for_this_example is not None:
31
+ output_scores.append(score_for_this_example)
32
+ else:
33
+ output_scores.append(1e-8)
34
+ output = None, torch.tensor(output_scores), None
35
+
36
+ return output
37
+
38
+ class SimCSEScorer():
39
+ def __init__(self, model_type, device) -> None:
40
+ self.model_type = model_type
41
+ self.device = device
42
+ from transformers import AutoModel, AutoTokenizer
43
+
44
+ # refer to the model list on https://github.com/princeton-nlp/SimCSE for the list of models
45
+ self.tokenizer = AutoTokenizer.from_pretrained(model_type)
46
+ self.model = AutoModel.from_pretrained(model_type).to(self.device)
47
+ self.spacy = spacy.load('en_core_web_sm')
48
+
49
+ self.batch_size = 64
50
+
51
+ def score(self, premise: list, hypo: list):
52
+ assert len(premise) == len(hypo)
53
+
54
+ output_scores = []
55
+ premise_sents = []
56
+ premise_index = [0]
57
+ hypo_sents = []
58
+ hypo_index = [0]
59
+
60
+ for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Sentenizing", total=len(premise)):
61
+ premise_sent = sent_tokenize(one_pre) #[each.text for each in self.spacy(one_pre).sents]
62
+ hypo_sent = sent_tokenize(one_hypo) #[each.text for each in self.spacy(one_hypo).sents]
63
+ premise_sents.extend(premise_sent)
64
+ premise_index.append(len(premise_sents))
65
+
66
+ hypo_sents.extend(hypo_sent)
67
+ hypo_index.append(len(hypo_sents))
68
+
69
+ all_sents = premise_sents + hypo_sents
70
+ embeddings = []
71
+ with torch.no_grad():
72
+ for batch in tqdm(self.chunks(all_sents, self.batch_size), total=int(len(all_sents)/self.batch_size), desc="Evaluating by SimCSE"):
73
+ inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
74
+ embeddings.append(self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output)
75
+ embeddings = torch.cat(embeddings)
76
+
77
+ assert len(premise_index) == len(hypo_index)
78
+ for i in range(len(premise_index)-1):
79
+ premise_embeddings = embeddings[premise_index[i]: premise_index[i+1]]
80
+ hypo_embeddings = embeddings[len(premise_sents)+hypo_index[i]:len(premise_sents)+hypo_index[i+1]]
81
+ cos_sim = cosine_similarity(premise_embeddings.cpu(), hypo_embeddings.cpu())
82
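+ # BERTScore-style greedy matching: per-sentence best cosine similarity gives precision/recall, combined into an F-score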
+ score_p = cos_sim.max(axis=0).mean()
83
+ score_r = cos_sim.max(axis=1).mean()
84
+ score_f = 2 * score_p * score_r / (score_p + score_r)
85
+ output_scores.append(score_f)
86
+
87
+ return torch.Tensor(output_scores), torch.Tensor(output_scores), None
88
+
89
+ def chunks(self, lst, n):
90
+ """Yield successive n-sized chunks from lst."""
91
+ for i in range(0, len(lst), n):
92
+ yield lst[i:i + n]
93
+
94
+ class BleurtScorer():
95
+ def __init__(self, checkpoint) -> None:
96
+ self.checkpoint = checkpoint
97
+
98
+ from bleurt import score
99
+ # BLEURT-20 can be swapped for other checkpoints to reduce runtime
100
+ # No available API to specify the CUDA device number
101
+ self.model = score.BleurtScorer(self.checkpoint)
102
+
103
+ def scorer(self, premise:list, hypo: list):
104
+ assert len(premise) == len(hypo)
105
+
106
+ output_scores = self.model.score(references=premise, candidates=hypo, batch_size=8)
107
+ output_scores = [s for s in output_scores]
108
+ return torch.Tensor(output_scores), torch.Tensor(output_scores), torch.Tensor(output_scores)
109
+
110
+ class BertScoreScorer():
111
+ def __init__(self, model_type, metric, device, batch_size) -> None:
112
+ self.model_type = model_type
113
+ self.device = device
114
+ self.metric = metric
115
+ self.batch_size = batch_size
116
+
117
+ from bert_score import score
118
+ self.model = score
119
+
120
+ def scorer(self, premise: list, hypo: list):
121
+ assert len(premise) == len(hypo)
122
+
123
+ precision, recall, f1 = self.model(premise, hypo, model_type=self.model_type, lang='en', rescale_with_baseline=True, verbose=True, device=self.device, batch_size=self.batch_size)
124
+
125
+ f1 = [f for f in f1]
126
+ precision = [p for p in precision]
127
+ recall = [r for r in recall]
128
+
129
+ if self.metric == 'f1':
130
+ return torch.Tensor(f1), torch.Tensor(f1), None
131
+ elif self.metric == 'precision':
132
+ return torch.Tensor(precision), torch.Tensor(precision), None
133
+ elif self.metric == 'recall':
134
+ return torch.Tensor(recall), torch.Tensor(recall), None
135
+ else:
136
+ raise ValueError("metric type not in f1, precision or recall.")
137
+
138
+ class BartScoreScorer():
139
+ def __init__(self, checkpoint, device) -> None:
140
+ self.checkpoint = checkpoint
141
+ self.device = device
142
+ import os, sys
143
+ sys.path.append('baselines/BARTScore')
144
+ from bart_score import BARTScorer
145
+ self.model = BARTScorer(device=self.device, checkpoint=self.checkpoint)
146
+
147
+ def scorer(self, premise: list, hypo: list):
148
+ assert len(premise) == len(hypo)
149
+
150
+ output_scores = self.model.score(premise, hypo, batch_size=4)
151
+ normed_score = torch.exp(torch.Tensor(output_scores))
152
+
153
+ return normed_score, normed_score, normed_score
154
+
155
+ ### Below are baselines in SummaC
156
+ ### MNLI, NER, FactCC, DAE, FEQA, QuestEval, SummaC-ZS, SummaC-Conv
157
+ class MNLIScorer():
158
+ def __init__(self, model="roberta-large-mnli", device='cuda:0', batch_size=32) -> None:
159
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
160
+ self.tokenizer = AutoTokenizer.from_pretrained(model)
161
+ self.model = AutoModelForSequenceClassification.from_pretrained(model).to(device)
162
+ self.device = device
163
+ self.softmax = nn.Softmax(dim=-1)
164
+ self.batch_size = batch_size
165
+
166
+ def scorer(self, premise: list, hypo: list):
167
+ if isinstance(premise, str) and isinstance(hypo, str):
168
+ premise = [premise]
169
+ hypo = [hypo]
170
+
171
+ batch = self.batch_tokenize(premise, hypo)
172
+ output_score_tri = []
173
+
174
+ for mini_batch in tqdm(batch, desc="Evaluating MNLI"):
175
+ # for mini_batch in batch:
176
+ mini_batch = mini_batch.to(self.device)
177
+ with torch.no_grad():
178
+ model_output = self.model(**mini_batch)
179
+ model_output_tri = model_output.logits
180
+ model_output_tri = self.softmax(model_output_tri).cpu()
181
+
182
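+ # index 2 is the entailment probability for roberta-large-mnli (0: contradiction, 1: neutral, 2: entailment)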
+ output_score_tri.append(model_output_tri[:,2])
183
+
184
+ output_score_tri = torch.cat(output_score_tri)
185
+
186
+ return output_score_tri, output_score_tri, output_score_tri
187
+
188
+ def batch_tokenize(self, premise, hypo):
189
+ """
190
+ input premise and hypos are lists
191
+ """
192
+ assert isinstance(premise, list) and isinstance(hypo, list)
193
+ assert len(premise) == len(hypo), "premise and hypo should be the same length."
194
+
195
+ batch = []
196
+ for mini_batch_pre, mini_batch_hypo in zip(self.chunks(premise, self.batch_size), self.chunks(hypo, self.batch_size)):
197
+ try:
198
+ mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation='only_first', padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
199
+ except:
200
+ warning('text_b too long...')
201
+ mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation=True, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
202
+ batch.append(mini_batch)
203
+
204
+ return batch
205
+
206
+ def chunks(self, lst, n):
207
+ """Yield successive n-sized chunks from lst."""
208
+ for i in range(0, len(lst), n):
209
+ yield lst[i:i + n]
210
+
211
+ class NERScorer():
212
+ def __init__(self) -> None:
213
+ import os, sys
214
+ sys.path.append('baselines/summac/summac')
215
+ from model_guardrails import NERInaccuracyPenalty
216
+ self.ner = NERInaccuracyPenalty()
217
+
218
+ def scorer(self, premise, hypo):
219
+ score_return = self.ner.score(premise, hypo)['scores']
220
+ oppo_score = [float(not each) for each in score_return]
221
+
222
+ tensor_score = torch.tensor(oppo_score)
223
+
224
+ return tensor_score, tensor_score, tensor_score
225
+ class UniEvalScorer():
226
+ def __init__(self, task='fact', device='cuda:0') -> None:
227
+ import os, sys
228
+ sys.path.append('baselines/UniEval')
229
+ from metric.evaluator import get_evaluator
230
+
231
+ self.evaluator = get_evaluator(task, device=device)
232
+
233
+ def scorer(self, premise, hypo):
234
+ from utils import convert_to_json
235
+ # Prepare data for pre-trained evaluators
236
+ data = convert_to_json(output_list=hypo, src_list=premise)
237
+ # Initialize evaluator for a specific task
238
+
239
+ # Get factual consistency scores
240
+ eval_scores = self.evaluator.evaluate(data, print_result=True)
241
+ score_list = [each['consistency'] for each in eval_scores]
242
+
243
+ return torch.tensor(score_list), torch.tensor(score_list), torch.tensor(score_list)
244
+
245
+ class FEQAScorer():
246
+ def __init__(self) -> None:
247
+ import os, sys
248
+ sys.path.append('baselines/feqa')
249
+ import benepar
250
+ import nltk
251
+
252
+ benepar.download('benepar_en3')
253
+ nltk.download('stopwords')
254
+
255
+ from feqa import FEQA
256
+ self.feqa_model = FEQA(squad_dir=os.path.abspath('baselines/feqa/qa_models/squad1.0'), bart_qa_dir=os.path.abspath('baselines/feqa/bart_qg/checkpoints/'), use_gpu=True)
257
+
258
+ def scorer(self, premise, hypo):
259
+ eval_score = self.feqa_model.compute_score(premise, hypo, aggregate=False)
260
+
261
+ return torch.tensor(eval_score), torch.tensor(eval_score), torch.tensor(eval_score)
262
+
263
+
264
+ class QuestEvalScorer():
265
+ def __init__(self) -> None:
266
+ import os, sys
267
+ sys.path.append('baselines/QuestEval')
268
+ from questeval.questeval_metric import QuestEval
269
+ self.questeval = QuestEval(no_cuda=False)
270
+
271
+ def scorer(self, premise, hypo):
272
+ score = self.questeval.corpus_questeval(
273
+ hypothesis=hypo,
274
+ sources=premise
275
+ )
276
+ final_score = score['ex_level_scores']
277
+
278
+ return torch.tensor(final_score), torch.tensor(final_score), torch.tensor(final_score)
279
+
280
+ class QAFactEvalScorer():
281
+ def __init__(self, model_folder, device='cuda:0') -> None:
282
+ import os, sys
283
+ sys.path.append('baselines/QAFactEval')
284
+ sys.path.append(os.path.abspath('baselines/qaeval/'))
285
+ from qafacteval import QAFactEval
286
+ kwargs = {"cuda_device": int(device.split(':')[-1]), "use_lerc_quip": True, \
287
+ "verbose": True, "generation_batch_size": 32, \
288
+ "answering_batch_size": 32, "lerc_batch_size": 8}
289
+
290
+ self.metric = QAFactEval(
291
+ lerc_quip_path=f"{model_folder}/quip-512-mocha",
292
+ generation_model_path=f"{model_folder}/generation/model.tar.gz",
293
+ answering_model_dir=f"{model_folder}/answering",
294
+ lerc_model_path=f"{model_folder}/lerc/model.tar.gz",
295
+ lerc_pretrained_model_path=f"{model_folder}/lerc/pretraining.tar.gz",
296
+ **kwargs
297
+ )
298
+ def scorer(self, premise, hypo):
299
+ results = self.metric.score_batch_qafacteval(premise, [[each] for each in hypo], return_qa_pairs=True)
300
+ score = [result[0]['qa-eval']['lerc_quip'] for result in results]
301
+ return torch.tensor(score), torch.tensor(score), torch.tensor(score)
302
+
303
+ class MoverScorer():
304
+ def __init__(self) -> None:
305
+ pass
306
+
307
+ class BERTScoreFFCIScorer():
308
+ def __init__(self) -> None:
309
+ pass
310
+
311
+ class DAEScorer():
312
+ def __init__(self, model_dir, device=0) -> None:
313
+ import os, sys
314
+ sys.path.insert(0, "baselines/factuality-datasets/")
315
+ from evaluate_generated_outputs import daefact
316
+ self.dae = daefact(model_dir, model_type='electra_dae', gpu_device=device)
317
+
318
+ def scorer(self, premise, hypo):
319
+ return_score = torch.tensor(self.dae.score_multi_doc(premise, hypo))
320
+
321
+ return return_score, return_score, return_score
322
+
323
+ class SummaCScorer():
324
+ def __init__(self, summac_type='conv', device='cuda:0') -> None:
325
+ self.summac_type = summac_type
326
+ import os, sys
327
+ sys.path.append("baselines/summac")
328
+ from summac.model_summac import SummaCZS, SummaCConv
329
+
330
+ if summac_type == 'conv':
331
+ self.model = SummaCConv(models=["vitc"], bins='percentile', granularity="sentence", nli_labels="e", device=device, start_file="default", agg="mean")
332
+ elif summac_type == 'zs':
333
+ self.model = SummaCZS(granularity="sentence", model_name="vitc", device=device) # If you have a GPU: switch to: device="cuda"
334
+
335
+ def scorer(self, premise, hypo):
336
+ assert len(premise) == len(hypo)
337
+ scores = self.model.score(premise, hypo)['scores']
338
+ return_score = torch.tensor(scores)
339
+
340
+ return return_score, return_score, return_score
341
+
342
+ class FactCCScorer():
343
+ def __init__(self, script_path, test_data_path,result_path) -> None:
344
+ self.script_path = script_path
345
+ self.result_path = result_path
346
+ self.test_data_path = test_data_path
347
+ def scorer(self, premise, hypo):
348
+ import subprocess
349
+ import pickle
350
+
351
+ self.generate_json_file(premise, hypo)
352
+ subprocess.call(f"sh {self.script_path}", shell=True)
353
+ print("Finishing FactCC")
354
+ results = pickle.load(open(self.result_path, 'rb'))
355
+ results = [-each+1 for each in results]
356
+
357
+ return torch.tensor(results), torch.tensor(results), torch.tensor(results)
358
+
359
+ def generate_json_file(self, premise, hypo):
360
+ output = []
361
+ assert len(premise) == len(hypo)
362
+ i = 0
363
+ for one_premise, one_hypo in zip(premise, hypo):
364
+ example = dict()
365
+ example['id'] = i
366
+ example['text'] = one_premise
367
+ example['claim'] = one_hypo
368
+ example['label'] = 'CORRECT'
369
+
370
+ i += 1
371
+ output.append(example)
372
+ with open(self.test_data_path, 'w', encoding='utf8') as f:
373
+ for each in output:
374
+ json.dump(each, f, ensure_ascii=False)
375
+ f.write('\n')
376
+
377
+ class BLANCScorer():
378
+ def __init__(self, device='cuda', batch_size=64) -> None:
379
+ from blanc import BlancHelp, BlancTune
380
+ self.blanc_help = BlancHelp(device=device, inference_batch_size=batch_size)
381
+
382
+
383
+ def scorer(self, premise, hypo):
384
+ score = self.blanc_help.eval_pairs(premise, hypo)
385
+
386
+ return_score = torch.tensor(score)
387
+
388
+ return return_score, return_score, return_score
389
+
390
+
391
+ class BLEUScorer():
392
+ def __init__(self, n_grams=1) -> None:
393
+ self.n_grams = n_grams
394
+ self.n_gram_map = {
395
+ 1: (1,0,0,0),
396
+ 2: (0.5,0.5,0,0),
397
+ 3: (1./3,1./3,1./3,0),
398
+ 4: (0.25,0.25,0.25,0.25)
399
+ }
400
+
401
+ def scorer(self, premise, hypo):
402
+ from nltk.translate.bleu_score import sentence_bleu
403
+ assert len(premise) == len(hypo), "premise and hypothesis should be the same length!"
404
+
405
+ output_score = []
406
+
407
+ for one_pre, one_hypo in tqdm(zip(premise, hypo), desc=f"Evaluating BLEU-{self.n_grams}", total=len(premise)):
408
+ scores = []
409
+ pre_sents = sent_tokenize(one_pre)
410
+ references = [[each for each in sent.split()] for sent in pre_sents]
411
+ for hypo_sent in sent_tokenize(one_hypo):
412
+ hypothesis = [each for each in hypo_sent.split()]
413
+ scores.append(sentence_bleu(references=references, hypothesis=hypothesis, weights=self.n_gram_map[self.n_grams]))
414
+ output_score.append(sum(scores)/len(scores) if len(scores)>0 else 0.)
415
+
416
+ return torch.tensor(output_score), torch.tensor(output_score), torch.tensor(output_score)
417
+
418
+ class ROUGEScorer():
419
+ def __init__(self, rouge_type='1') -> None:
420
+ from rouge import Rouge
421
+ self.rouge = Rouge()
422
+ self.rouge_type = rouge_type
423
+
424
+ def scorer(self, premise, hypo):
425
+
426
+ assert len(premise) == len(hypo), "premise and hypothesis should be the same length!"
427
+
428
+ output_score = []
429
+
430
+ for one_pre, one_hypo in tqdm(zip(premise, hypo), desc=f"Evaluating ROUGE-{self.rouge_type}", total=len(premise)):
431
+ scores = []
432
+ for pre_sent in sent_tokenize(one_pre):
433
+ for hypo_sent in sent_tokenize(one_hypo):
434
+ try:
435
+ scores.append(self.rouge.get_scores(pre_sent, hypo_sent)[0][f"rouge-{self.rouge_type}"]['f'])
436
+ except:
437
+ if len(pre_sent.strip()) == 0:
438
+ print('premise sent is empty')
439
+ elif len(hypo_sent.strip()) == 0:
440
+ print('hypo sent is empty')
441
+ scores.append(0.0)
442
+ scores = np.array(scores)
443
+ scores = scores.reshape((len(sent_tokenize(one_pre)), len(sent_tokenize(one_hypo))))
444
+ scores = scores.max(axis=0).mean()
445
+ output_score.append(scores.item())
446
+
447
+ return torch.tensor(output_score), torch.tensor(output_score), torch.tensor(output_score)
448
+
449
+
450
+ class GPTScoreScorer():
451
+ def __init__(self, api_key, gpt_model='davinci003') -> None:
452
+ import os, sys
453
+ sys.path.append('../BaselineForNLGEval/GPTScore')
454
+ from gpt3_score import gpt3score
455
+
456
+ self.gpt3score = gpt3score
457
+ self.api_key = api_key
458
+ self.gpt_model = gpt_model
459
+
460
+ self.consistency_prefix = "Generate factually consistent summary for the following text: "
461
+ self.consistency_suffix = " \n\nTl;dr "
462
+
463
+
464
+ def scorer(self, premise: list, hypothesis: list):
465
+ assert len(premise) == len(hypothesis)
466
+ output_score = []
467
+ for p, h in tqdm(zip(premise, hypothesis), total=len(premise), desc="Evaluating GPTScore"):
468
+ score = self.gpt3score(input=self.consistency_prefix + p + self.consistency_suffix, output=h, gpt3model=self.gpt_model, api_key=self.api_key)
469
+ output_score.append(score)
470
+
471
+ output_score = torch.tensor(output_score)
472
+
473
+ return None, output_score, None
474
+
475
+ class ChatGPTLuo2023Scorer():
476
+ def __init__(self, task, api_key, chat_model='gpt-3.5-turbo') -> None:
477
+ openai.api_key = api_key
478
+ assert isinstance(task, list) and len(task) == 1
479
+
480
+ self.task = task[0]
481
+ self.chat_model = chat_model
482
+ self.instruct = """Score the following summary given the corresponding article with respect to consistency from 1 to 10. Note that consistency measures how much information included in the summary is present in the source article. 10 points indicate the summary contains only statements that are entailed by the source document."""
483
+
484
+ def scorer(self, premise: list, hypothesis: list):
485
+ import time
486
+ assert len(premise) == len(hypothesis)
487
+ output_score = []
488
+ i = -1
489
+
490
+ for p, h in tqdm(zip(premise, hypothesis), total=len(premise), desc="Evaluating ChatGPTLuo2023"):
491
+ i += 1
492
+ if i <= -1: continue
493
+
494
+ attempt = 0
495
+ max_attempt = 5
496
+ while attempt < max_attempt:
497
+ try:
498
+ response = openai.ChatCompletion.create(
499
+ model=self.chat_model,
500
+ messages=[
501
+ # {"role": "system", "content": "You are a helpful assistant."},
502
+ {"role": "user", "content": f"""Score the following summary given the corresponding article with respect to consistency from 1 to 10. Note that consistency measures how much information included in the summary is present in the source article. 10 points indicate the summary contains only statements that are entailed by the source document.
503
+
504
+ Summary: {h}
505
+
506
+ Article: {p} """},
507
+ ],
508
+ temperature=0,
509
+ max_tokens=10
510
+ )
511
+ res_content = response['choices'][0]['message']['content']
512
+ break
513
+ except:
514
+ attempt += 1
515
+ print("openai api failed")
516
+ if max_attempt == attempt:
517
+ print("maximum failed attempts reached. exiting...")
518
+ exit()
519
+ json.dump({i: res_content}, open(f'exp_results/nlg_eval_fact/baselines/ChatGPTLuo2023-output/{self.task}.json', 'a'))
520
+ with open(f'exp_results/nlg_eval_fact/baselines/ChatGPTLuo2023-output/{self.task}.json', 'a') as f:
521
+ f.write('\n')
522
+
523
+ try:
524
+ score = int(res_content)
525
+ except:
526
+ print("unknown score")
527
+ score = 0.0
528
+ output_score.append(score)
529
+ # time.sleep(1)
530
+
531
+ output_score = torch.tensor(output_score)
532
+
533
+ return None, output_score, None
534
+
535
+ class ChatGPTGao2023Scorer():
536
+ def __init__(self, task, api_key, chat_model='gpt-3.5-turbo') -> None:
537
+ openai.api_key = api_key
538
+ assert isinstance(task, list) and len(task) == 1
539
+
540
+ self.task = task[0]
541
+ self.chat_model = chat_model
542
+
543
+ def scorer(self, premise: list, hypothesis: list):
544
+ import time
545
+ assert len(premise) == len(hypothesis)
546
+ output_score = []
547
+ i = -1
548
+
549
+ for p, h in tqdm(zip(premise, hypothesis), total=len(premise), desc="Evaluating ChatGPTGao2023"):
550
+ i += 1
551
+ if i <= -1: continue
552
+
553
+ attempt = 0
554
+ max_attempt = 5
555
+ while attempt < max_attempt:
556
+ try:
557
+ response = openai.ChatCompletion.create(
558
+ model=self.chat_model,
559
+ messages=[
560
+ # {"role": "system", "content": "You are a human annotator that rates the quality of summaries"},
561
+ # {"role": "user", "content": f"""Imagine you are a human annotator now. You will evaluate the quality of summaries written for a news article. Please follow these steps:\n\n 1. Carefully read the news article, and be aware of the information it contains.\n 2. Read the proposed summary.\n 3. Rate the summary on four dimensions: relevance, consistency, fluency, and coherence. You should rate on a scale from 1 (worst) to 5 (best).\n\n Definitions are as follows:\n Relevance: The rating measures how well the summary captures the key points of the article. Consider whether all and only the important aspects are contained in the summary.\n Consistency: The rating measures whether the facts in the summary are consistent with the facts in the original article. Consider whether the summary does reproduce all facts accurately and does not make up untrue information.\n Fluency: This rating measures the quality of individual sentences, whether they are well-written and grammatically correct. Consider the quality of individual sentences.\n Coherence: The rating measures the quality of all sentences collectively, to fit together and sound natural. Consider the quality of the summary as a whole.\n\n The article and the summary are given below:\n Article: {p}\n Summary: {h}"""},
562
+ {"role": "user", "content": f"""Evaluate the quality of summaries written for a news article. Rate each summary on four dimensions: relevance, faithfulness, fluency, and coherence. You should rate on a scale from 1 (worst) to 5 (best).\n\n Article: {p}\n Summary: {h}"""},
563
+ ],
564
+ temperature=0,
565
+ # max_tokens=10
566
+ )
567
+ res_content = response['choices'][0]['message']['content']
568
+ break
569
+ except:
570
+ attempt += 1
571
+ print("openai api failed")
572
+ if max_attempt == attempt:
573
+ print("maximum failed attempts reached. exiting...")
574
+ exit()
575
+ json.dump({i: res_content}, open(f'exp_results/nlg_eval_fact/baselines/ChatGPTGao2023-output/{self.task}.json', 'a'))
576
+ with open(f'exp_results/nlg_eval_fact/baselines/ChatGPTGao2023-output/{self.task}.json', 'a') as f:
577
+ f.write('\n')
578
+
579
+ try:
580
+ score = int(res_content)
581
+ except:
582
+ print("unknown score")
583
+ score = 0.0
584
+ output_score.append(score)
585
+ # time.sleep(1)
586
+
587
+ output_score = torch.tensor(output_score)
588
+
589
+ return None, output_score, None
590
+
591
+ class ChatGPTYiChen2023Scorer():
592
+ def __init__(self, task, api_key, chat_model='gpt-3.5-turbo') -> None:
593
+ ### Explicit score by ChatGPT
594
+ openai.api_key = api_key
595
+ assert isinstance(task, list) and len(task) == 1
596
+
597
+ self.task = task[0]
598
+ self.chat_model = chat_model
599
+
600
+ def scorer(self, premise: list, hypothesis: list):
601
+ import time
602
+ assert len(premise) == len(hypothesis)
603
+ output_score = []
604
+ i = -1
605
+
606
+ for p, h in tqdm(zip(premise, hypothesis), total=len(premise), desc="Evaluating ChatGPTYiChen2023"):
607
+ i += 1
608
+ if i <= -1: continue
609
+
610
+ attempt = 0
611
+ max_attempt = 5
612
+ while attempt < max_attempt:
613
+ try:
614
+ response = openai.ChatCompletion.create(
615
+ model=self.chat_model,
616
+ messages=[
617
+ # {"role": "system", "content": "You are a human annotator that rates the quality of summaries"},
618
+ # {"role": "user", "content": f"""Imagine you are a human annotator now. You will evaluate the quality of summaries written for a news article. Please follow these steps:\n\n 1. Carefully read the news article, and be aware of the information it contains.\n 2. Read the proposed summary.\n 3. Rate the summary on four dimensions: relevance, consistency, fluency, and coherence. You should rate on a scale from 1 (worst) to 5 (best).\n\n Definitions are as follows:\n Relevance: The rating measures how well the summary captures the key points of the article. Consider whether all and only the important aspects are contained in the summary.\n Consistency: The rating measures whether the facts in the summary are consistent with the facts in the original article. Consider whether the summary does reproduce all facts accurately and does not make up untrue information.\n Fluency: This rating measures the quality of individual sentences, whether they are well-written and grammatically correct. Consider the quality of individual sentences.\n Coherence: The rating measures the quality of all sentences collectively, to fit together and sound natural. Consider the quality of the summary as a whole.\n\n The article and the summary are given below:\n Article: {p}\n Summary: {h}"""},
619
+ {"role": "user", "content": f"""Score the following storyline given the beginning of the story on a continual scale from 0 (worst) to 100 (best), where score of 0 means "The storyline makes no sense and is totally not understandable" and score of 100 means "The storyline is perfect-written and highly consistent with the given beginning of the story". \n\n The beginning of the story: {p} \n\n Storyline: {h} \n\n Score: """},
620
+ ],
621
+ temperature=0,
622
+ # max_tokens=10
623
+ )
624
+ res_content = response['choices'][0]['message']['content']
625
+ break
626
+ except:
627
+ attempt += 1
628
+ print("openai api failed")
629
+ if max_attempt == attempt:
630
+ print("maximum failed attempts reached. exiting...")
631
+ exit()
632
+ json.dump({i: res_content}, open(f'exp_results/nlg_eval_fact/baselines/ChatGPTYiChen2023-output/{self.task}.json', 'a'))
633
+ with open(f'exp_results/nlg_eval_fact/baselines/ChatGPTYiChen2023-output/{self.task}.json', 'a') as f:
634
+ f.write('\n')
635
+
636
+ try:
637
+ score = int(res_content)
638
+ except:
639
+ print("unknown score")
640
+ score = 0.0
641
+ output_score.append(score)
642
+ # time.sleep(1)
643
+
644
+ output_score = torch.tensor(output_score)
645
+
646
+ return None, output_score, None
647
+
648
+ class ChatGPTShiqiChen2023Scorer():
649
+ def __init__(self, task, api_key, chat_model='gpt-3.5-turbo') -> None:
650
+ ### Explicit score by ChatGPT
651
+ openai.api_key = api_key
652
+ assert isinstance(task, list) and len(task) == 1
653
+
654
+ self.task = task[0]
655
+ self.chat_model = chat_model
656
+
657
+ def scorer(self, premise: list, hypothesis: list):
658
+ import time
659
+ assert len(premise) == len(hypothesis)
660
+ output_score = []
661
+ i = -1
662
+
663
+ for p, h in tqdm(zip(premise, hypothesis), total=len(premise), desc="Evaluating ChatGPTShiqiChen2023"):
664
+ i += 1
665
+ if i <= -1: continue
666
+ hypo_sents = sent_tokenize(h)
667
+ hypo_sents = ' \n '.join([f"{i+1}. "+each for i, each in enumerate(hypo_sents)])
668
+ attempt = 0
669
+ max_attempt = 5
670
+ while attempt < max_attempt:
671
+ try:
672
+ response = openai.ChatCompletion.create(
673
+ model=self.chat_model,
674
+ messages=[
675
+ # {"role": "system", "content": "You are a human annotator that rates the quality of summaries"},
676
+ # {"role": "user", "content": f"""Imagine you are a human annotator now. You will evaluate the quality of summaries written for a news article. Please follow these steps:\n\n 1. Carefully read the news article, and be aware of the information it contains.\n 2. Read the proposed summary.\n 3. Rate the summary on four dimensions: relevance, consistency, fluency, and coherence. You should rate on a scale from 1 (worst) to 5 (best).\n\n Definitions are as follows:\n Relevance: The rating measures how well the summary captures the key points of the article. Consider whether all and only the important aspects are contained in the summary.\n Consistency: The rating measures whether the facts in the summary are consistent with the facts in the original article. Consider whether the summary does reproduce all facts accurately and does not make up untrue information.\n Fluency: This rating measures the quality of individual sentences, whether they are well-written and grammatically correct. Consider the quality of individual sentences.\n Coherence: The rating measures the quality of all sentences collectively, to fit together and sound natural. Consider the quality of the summary as a whole.\n\n The article and the summary are given below:\n Article: {p}\n Summary: {h}"""},
677
+ {"role": "user", "content": f"""Source Document: \n {p} \n\n Q: Can the following statement be inferred from the above document? Yes or No?\n {hypo_sents} \n A: 1. """},
678
+ ],
679
+ temperature=0,
680
+ # max_tokens=10
681
+ )
682
+ res_content = response['choices'][0]['message']['content']
683
+ break
684
+ except:
685
+ attempt += 1
686
+ print("openai api failed")
687
+ if max_attempt == attempt:
688
+ print("maximum failed attempts reached. exiting...")
689
+ exit()
690
+ json.dump({i: res_content}, open(f'exp_results/nlg_eval_fact/baselines/ChatGPTShiqiChen2023-output/{self.task}.json', 'a'))
691
+ with open(f'exp_results/nlg_eval_fact/baselines/ChatGPTShiqiChen2023-output/{self.task}.json', 'a') as f:
692
+ f.write('\n')
693
+
694
+ try:
695
+ score = int(res_content)
696
+ except:
697
+ print("unknown score")
698
+ score = 0.0
699
+ output_score.append(score)
700
+ # time.sleep(1)
701
+
702
+ output_score = torch.tensor(output_score)
703
+
704
+ return None, output_score, None
alignscore/benchmark.py ADDED
@@ -0,0 +1,494 @@
1
+ from evaluate import Evaluator, ALL_TASKS
2
+ from baselines import *
3
+ from alignscore.inference import Inferencer
4
+ import time
5
+ import json
6
+ import os
7
+ from argparse import ArgumentParser
8
+
9
+ SAVE_ALL_TABLES = True
10
+ SAVE_AND_PRINT_TIMER = False
11
+
12
+ class Timer():
13
+ def __init__(self) -> None:
14
+ self.t0 = time.time()
15
+ self.save_path = 'exp_results/time.json'
16
+
17
+ def finish(self, display_name):
18
+ t1 = time.time()
19
+ time_pass = t1 - self.t0
20
+ if SAVE_AND_PRINT_TIMER:
21
+ print(f"Evalautor {display_name} finished in {time_pass} secs.")
22
+ with open(self.save_path, 'a', encoding='utf8') as f:
23
+ json.dump({display_name: time_pass}, f)
24
+ f.write('\n')
25
+
26
+
27
+ def eval_ctc(model_type, tasks=ALL_TASKS):
28
+ ctc_scorer = CTCScorer(model_type)
29
+ evaluator = Evaluator(eval_tasks=tasks, align_func=ctc_scorer.score, save_all_tables=SAVE_ALL_TABLES)
30
+ evaluator.result_save_name = f"baselines/CTC-{model_type}"
31
+
32
+ timer = Timer()
33
+ evaluator.evaluate()
34
+ timer.finish(f"CTC-{model_type}")
35
+
36
+ def eval_simcse(model_type, device, tasks=ALL_TASKS):
37
+ simcse_scorer = SimCSEScorer(model_type, device)
38
+ evaluator = Evaluator(eval_tasks=tasks, align_func=simcse_scorer.score, save_all_tables=SAVE_ALL_TABLES)
39
+ evaluator.result_save_name = f"baselines/{model_type.split('/')[-1]}_f"
40
+
41
+ timer = Timer()
42
+ evaluator.evaluate()
43
+ timer.finish(f"{model_type.split('/')[-1]}_f")
44
+
45
+ def eval_bleurt(checkpoint, tasks=ALL_TASKS):
46
+ bleurt_scorer = BleurtScorer(checkpoint)
47
+ evaluator = Evaluator(eval_tasks=tasks, align_func=bleurt_scorer.scorer, save_all_tables=SAVE_ALL_TABLES)
48
+ evaluator.result_save_name = f"baselines/BLEURT"
49
+
50
+ timer = Timer()
51
+ evaluator.evaluate()
52
+ timer.finish(f"BLEURT")
53
+
54
+ def eval_bertscore(model_type, device, batch_size, tasks=ALL_TASKS):
55
+ bertscore_scorer = BertScoreScorer(model_type=model_type, metric='f1', device=device, batch_size=batch_size)
56
+ evaluator = Evaluator(eval_tasks=tasks, align_func=bertscore_scorer.scorer, save_all_tables=SAVE_ALL_TABLES)
57
+ evaluator.result_save_name = f"baselines/bertscore_{model_type.replace('/', '-')}_f"
58
+
59
+ timer = Timer()
60
+ evaluator.evaluate()
61
+ timer.finish(f"bertscore_{model_type.replace('/', '-')}_f")
62
+
63
+ def eval_bartscore(checkpoint, device, tasks=ALL_TASKS):
64
+ bartscore_scorer = BartScoreScorer(checkpoint, device)
65
+ evaluator = Evaluator(eval_tasks=tasks, align_func=bartscore_scorer.scorer, save_all_tables=SAVE_ALL_TABLES)
66
+ evaluator.result_save_name = f"baselines/bartscore-{checkpoint.replace('/','-')}"
67
+
68
+ timer = Timer()
69
+ evaluator.evaluate()
70
+ timer.finish(f"bartscore-{checkpoint.replace('/','-')}")
71
+
72
+ ### Below are Baselines for SummaC
73
+ def eval_mnli(model="roberta-large-mnli", device='cuda:0', tasks=ALL_TASKS):
74
+ mnli_scorer = MNLIScorer(model=model, device=device)
75
+ evaluator = Evaluator(eval_tasks=tasks, align_func=mnli_scorer.scorer, save_all_tables=SAVE_ALL_TABLES)
76
+ evaluator.result_save_name = f"baselines/mnli-{model}"
77
+
78
+ timer = Timer()
79
+ evaluator.evaluate()
80
+ timer.finish(f"mnli-{model}")
81
+
82
+ def eval_ner(tasks=ALL_TASKS):
83
+ ner_scorer = NERScorer()
84
+ evaluator = Evaluator(eval_tasks=tasks, align_func=ner_scorer.scorer, save_all_tables=SAVE_ALL_TABLES)
85
+ evaluator.result_save_name = f"baselines/NER"
86
+
87
+ timer = Timer()
88
+ evaluator.evaluate()
89
+ timer.finish(f"NER")
90
+
91
+ def eval_unieval(tasks=ALL_TASKS, device='cuda:0'):
92
+ unieval = UniEvalScorer(task='fact', device=device)
93
+ evaluator = Evaluator(eval_tasks=tasks, align_func=unieval.scorer, save_all_tables=SAVE_ALL_TABLES)
94
+ evaluator.result_save_name = f"baselines/UniEval"
95
+
96
+ timer = Timer()
97
+ evaluator.evaluate()
98
+ timer.finish(f"UniEval")
99
+
100
+ def eval_feqa(tasks=ALL_TASKS):
101
+ feqa = FEQAScorer()
102
+ evaluator = Evaluator(eval_tasks=tasks, align_func=feqa.scorer, save_all_tables=SAVE_ALL_TABLES)
103
+ evaluator.result_save_name = f"baselines/FEQA"
104
+
105
+ timer = Timer()
106
+ evaluator.evaluate()
107
+ timer.finish(f"FEQA")
108
+
109
+ def eval_questeval(tasks=ALL_TASKS):
110
+ questeval = QuestEvalScorer()
111
+ evaluator = Evaluator(eval_tasks=tasks, align_func=questeval.scorer, save_all_tables=SAVE_ALL_TABLES)
112
+ evaluator.result_save_name = f"baselines/QuestEval"
113
+
114
+ timer = Timer()
115
+ evaluator.evaluate()
116
+ timer.finish(f"QuestEval")
117
+
118
+ def eval_qafacteval(tasks=ALL_TASKS, device='cuda:0'):
119
+ import os, sys
120
+ warning("QAFactEval must be run inside the 'qaeval' conda environment")
121
+ qafacteval = QAFactEvalScorer(device=device, model_folder=os.path.abspath('../BaselineForNLGEval/QAFactEval/models'))
122
+ evaluator = Evaluator(eval_tasks=tasks, align_func=qafacteval.scorer, save_all_tables=SAVE_ALL_TABLES)
123
+ evaluator.result_save_name = f"baselines/QAFactEval"
124
+ evaluator.evaluate()
125
+
126
+ def eval_dae(tasks=ALL_TASKS, model_dir=None, device=0):
127
+ dae = DAEScorer(model_dir=model_dir, device=device)
128
+ evaluator = Evaluator(eval_tasks=tasks, align_func=dae.scorer, save_all_tables=SAVE_ALL_TABLES)
129
+ evaluator.result_save_name = f"baselines/DAE"
130
+
131
+ timer = Timer()
132
+ evaluator.evaluate()
133
+ timer.finish(f"DAE")
134
+
135
+ def eval_bleu(tasks=ALL_TASKS, n_grams=1):
136
+ bleu = BLEUScorer(n_grams=n_grams)
137
+ evaluator = Evaluator(eval_tasks=tasks, align_func=bleu.scorer, save_all_tables=SAVE_ALL_TABLES)
138
+ evaluator.result_save_name = f"baselines/BLEU-{n_grams}"
139
+
140
+ timer = Timer()
141
+ evaluator.evaluate()
142
+ timer.finish(f"BLEU-{n_grams}")
143
+
144
+ def eval_rouge(tasks=ALL_TASKS, rouge_type='1'):
145
+ rouge = ROUGEScorer(rouge_type=rouge_type)
146
+ evaluator = Evaluator(eval_tasks=tasks, align_func=rouge.scorer, save_all_tables=SAVE_ALL_TABLES)
147
+ evaluator.result_save_name = f"baselines/ROUGE-{rouge_type}"
148
+
149
+ timer = Timer()
150
+ evaluator.evaluate()
151
+ timer.finish(f"ROUGE-{rouge_type}")
152
+
153
+ def eval_factcc(script_path, test_data_path,result_path, tasks=ALL_TASKS):
154
+ factcc = FactCCScorer(script_path=script_path, test_data_path=test_data_path, result_path=result_path)
155
+ evaluator = Evaluator(eval_tasks=tasks, align_func=factcc.scorer, save_all_tables=SAVE_ALL_TABLES)
156
+ evaluator.result_save_name = f"baselines/FactCC"
157
+
158
+ timer = Timer()
159
+ evaluator.evaluate()
160
+ timer.finish(f"FactCC")
161
+
162
+ def eval_blanc(tasks=ALL_TASKS, device='cuda:0', batch_size=64):
163
+ blanc = BLANCScorer(device=device, batch_size=batch_size)
164
+ evaluator = Evaluator(eval_tasks=tasks, align_func=blanc.scorer, save_all_tables=SAVE_ALL_TABLES)
165
+ evaluator.result_save_name = f"baselines/BLANC"
166
+
167
+ timer = Timer()
168
+ evaluator.evaluate()
169
+ timer.finish(f"BLANC")
170
+
171
+ def eval_summac(tasks=ALL_TASKS, summac_type='conv', device='cuda:0'):
172
+ summac = SummaCScorer(summac_type=summac_type, device=device)
173
+ evaluator = Evaluator(eval_tasks=tasks, align_func=summac.scorer, save_all_tables=SAVE_ALL_TABLES)
174
+ evaluator.result_save_name = f"baselines/SummaC-{summac_type}"
175
+
176
+ timer = Timer()
177
+ evaluator.evaluate()
178
+ timer.finish(f"SummaC-{summac_type}")
179
+
180
+ def eval_align_nlg(ckpt_path, comment='', base_model='roberta-large', batch_size=32, device='cuda:0', tasks=ALL_TASKS, nlg_eval_mode='nli_sp'):
181
+ align = Inferencer(ckpt_path=ckpt_path, model=base_model, batch_size=batch_size, device=device)
182
+ if 'smart' in nlg_eval_mode:
183
+ align.smart_type = nlg_eval_mode
184
+ else:
185
+ align.nlg_eval_mode = nlg_eval_mode
186
+
187
+ evaluator = Evaluator(eval_tasks=tasks, align_func=align.nlg_eval, save_all_tables=SAVE_ALL_TABLES)
188
+ name = f'AlignScore-{nlg_eval_mode}-{base_model}'
189
+ if comment:
190
+ name += '_' + comment
191
+ evaluator.result_save_name = f"align_eval/{name}"
192
+
193
+ timer = Timer()
194
+ evaluator.evaluate()
195
+ timer.finish(name)
196
+
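+ # LLM-based baselines: the functions below call the OpenAI API and therefore require an api_key.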
197
+ def eval_gptscore(api_key, gpt_model='davinci003', tasks=ALL_TASKS):
198
+ gptscore = GPTScoreScorer(api_key=api_key, gpt_model=gpt_model)
199
+ evaluator = Evaluator(eval_tasks=tasks, align_func=gptscore.scorer, save_all_tables=SAVE_ALL_TABLES)
200
+ evaluator.result_save_name = f"nlg_eval_fact/baselines/GPTScore-{gpt_model}"
201
+ evaluator.evaluate()
202
+
203
+ def eval_chatgptluo2023(api_key, chat_model='gpt-3.5-turbo', tasks=['qags_cnndm']):
204
+ chatgpt = ChatGPTLuo2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
205
+ evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
206
+ evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTLuo2023-{chat_model}"
207
+ evaluator.evaluate()
208
+
209
+ def eval_chatgptgao2023(api_key, chat_model='gpt-3.5-turbo', tasks=['qags_cnndm']):
210
+ chatgpt = ChatGPTGao2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
211
+ evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
212
+ evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTGao2023-{chat_model}"
213
+ evaluator.evaluate()
214
+
215
+ def eval_chatgptyichen2023(api_key, chat_model='gpt-3.5-turbo', tasks=['qags_cnndm']):
216
+ chatgpt = ChatGPTYiChen2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
217
+ evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
218
+ evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTYiChen2023-{chat_model}"
219
+ evaluator.evaluate()
220
+
221
+ def eval_chatgptshiqichen2023(api_key, chat_model='gpt-3.5-turbo', tasks=['qags_cnndm']):
222
+ chatgpt = ChatGPTShiqiChen2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
223
+ evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
224
+ evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTShiqiChen2023-{chat_model}"
225
+ evaluator.evaluate()
226
+
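+ # Maps command-line flags to the corresponding eval_* helpers; argugment_error reports missing required options via the parser.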
227
+ def run_benchmarks(args, argugment_error):
228
+ os.makedirs('exp_results/baselines', exist_ok=True)
229
+ os.makedirs('exp_results/align_eval', exist_ok=True)
230
+
231
+ if args.alignscore:
232
+ if not all((args.alignscore_model, args.alignscore_ckpt, args.alignscore_eval_mode)):
233
+ argugment_error('--alignscore-model, --alignscore-ckpt, and --alignscore-eval-mode must be specified to run AlignScore')
234
+ eval_align_nlg(
235
+ nlg_eval_mode=args.alignscore_eval_mode,
236
+ ckpt_path=args.alignscore_ckpt,
237
+ base_model=args.alignscore_model,
238
+ device=args.device, tasks=args.tasks,
239
+ comment=args.alignscore_comment
240
+ )
241
+
242
+ if args.ctc:
243
+ if not args.ctc_type:
244
+ argugment_error('--ctc-type must be specified to run CTC baseline')
245
+ for type in args.ctc_type:
246
+ eval_ctc(type, tasks=args.tasks)
247
+
248
+ if args.simcse:
249
+ if not args.simcse_ckpt:
250
+ argugment_error('--simcse-ckpt must be specified to run SimCSE baseline')
251
+ for ckpt in args.simcse_ckpt:
252
+ eval_simcse(ckpt, device=args.device, tasks=args.tasks)
253
+
254
+ if args.bleurt:
255
+ if not args.bleurt_ckpt:
256
+ argugment_error('--bleurt-ckpt must be specified to run BLEURT baseline')
257
+ eval_bleurt(args.bleurt_ckpt, tasks=args.tasks)
258
+
259
+ if args.bertscore:
260
+ if not args.bertscore_ckpt or not args.bertscore_batch_size:
261
+ argugment_error('--bertscore-ckpt and --bertscore-batch-size must be specified to run BERTScore baseline')
262
+ for ckpt in args.bertscore_ckpt:
263
+ eval_bertscore(ckpt, device=args.device, tasks=args.tasks, batch_size=args.bertscore_batch_size)
264
+
265
+ if args.bartscore:
266
+ if not args.bartscore_ckpt:
267
+ argugment_error('--bartscore-ckpt must be specified to run BARTScore baseline')
268
+ for ckpt in args.bartscore_ckpt:
269
+ eval_bartscore(ckpt, device=args.device, tasks=args.tasks)
270
+
271
+ if args.mnli:
272
+ if not args.mnli_ckpt:
273
+ argugment_error('--mnli-ckpt must be specified to run MNLI baseline')
274
+ for ckpt in args.mnli_ckpt:
275
+ eval_mnli(model=ckpt, device=args.device, tasks=args.tasks)
276
+
277
+ if args.ner:
278
+ eval_ner(tasks=args.tasks)
279
+
280
+ if args.unieval:
281
+ eval_unieval(tasks=args.tasks, device=args.device)
282
+
283
+ if args.feqa:
284
+ eval_feqa(tasks=args.tasks)
285
+
286
+ if args.questeval:
287
+ eval_questeval(tasks=args.tasks)
288
+
289
+ if args.qafacteval:
290
+ eval_qafacteval(tasks=args.tasks)
291
+
292
+ if args.bleu:
293
+ if not args.bleu_ngram:
294
+ argugment_error('--bleu-ngram must be specified to run BLEU baseline')
295
+ for n in args.bleu_ngram:
296
+ eval_bleu(tasks=args.tasks, n_grams=n)
297
+
298
+ if args.rouge:
299
+ if not args.rouge_type:
300
+ argugment_error('--rouge-type must be specified to run ROUGE baseline')
301
+ for type in args.rouge_type:
302
+ eval_rouge(tasks=args.tasks, rouge_type=type)
303
+
304
+ if args.dae:
305
+ if not args.dae_ckpt:
306
+ argugment_error('--dae-ckpt must be specified to run DAE baseline')
307
+ eval_dae(tasks=args.tasks, model_dir=os.path.abspath(args.dae_ckpt))
308
+
309
+ if args.factcc:
310
+ if not all((args.factcc_script, args.factcc_test_data, args.factcc_result_path)):
311
+ argugment_error('--factcc-script, --factcc-test-data, and --factcc-result-path must be specified to run FactCC baseline')
312
+ eval_factcc(
313
+ tasks=args.tasks,
314
+ script_path=os.path.abspath(args.factcc_script),
315
+ test_data_path=os.path.abspath(args.factcc_test_data),
316
+ result_path=os.path.abspath(args.factcc_result_path)
317
+ )
318
+
319
+ if args.blanc:
320
+ if not args.blanc_batch_size:
321
+ argugment_error('--blanc-batch-size must be specified to run BLANC baseline')
322
+ eval_blanc(tasks=args.tasks, device=args.device, batch_size=args.blanc_batch_size)
323
+
324
+ if args.summac:
325
+ if not args.summac_type:
326
+ argugment_error('--summac-type must be specified to run SummaC baseline')
327
+ for type in args.summac_type:
328
+ eval_summac(tasks=args.tasks, device=args.device, summac_type=type)
329
+
330
+
331
+ if __name__ == "__main__":
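+ # Factual consistency benchmarks selectable via --tasks (SummaC, TRUE, and the other evaluation datasets).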
332
+ FACT_EVAL_TASKS = ['summac', 'true','xsumfaith', 'summeval', 'qags_xsum', 'qags_cnndm', 'newsroom', 'rank19', 'frank', 'samsum']
333
+
334
+ parser = ArgumentParser()
335
+ parser.add_argument('--tasks', nargs='+', type=str, default=FACT_EVAL_TASKS, choices=FACT_EVAL_TASKS)
336
+ parser.add_argument('--device', type=str, default='cuda:0')
337
+ parser.add_argument('--timer', action='store_true', help='Time all metric runs')
338
+
339
+ alignscore_parser = parser.add_argument_group('AlignScore')
340
+ alignscore_parser.add_argument('--alignscore', action='store_true', help='Run AlignScore benchmark')
341
+ alignscore_parser.add_argument('--alignscore-model', type=str, choices=['roberta-base', 'roberta-large'])
342
+ alignscore_parser.add_argument('--alignscore-ckpt', type=str)
343
+ alignscore_parser.add_argument(
344
+ '--alignscore-eval-mode',
345
+ type=str,
346
+ choices=['bin', 'bin_sp', 'nli', 'nli_sp', 'reg', 'reg_sp', 'smart-n', 'smart-l'],
347
+ default='nli_sp'
348
+ )
349
+ alignscore_parser.add_argument('--alignscore-comment', type=str, default='')
350
+
351
+ ctc_parser = parser.add_argument_group('Baseline - CTC')
352
+ ctc_parser.add_argument('--ctc', action='store_true', help='Run CTC baseline')
353
+ ctc_parser.add_argument(
354
+ '--ctc-type',
355
+ nargs='*',
356
+ type=str,
357
+ choices=['D-cnndm', 'E-roberta', 'R-cnndm'],
358
+ default=['D-cnndm']
359
+ )
360
+
361
+ simcse_parser = parser.add_argument_group('Baseline - SimCSE')
362
+ simcse_models = [
363
+ 'princeton-nlp/unsup-simcse-bert-base-uncased',
364
+ 'princeton-nlp/unsup-simcse-bert-large-uncased',
365
+ 'princeton-nlp/unsup-simcse-roberta-base',
366
+ 'princeton-nlp/unsup-simcse-roberta-large',
367
+ 'princeton-nlp/sup-simcse-bert-base-uncased',
368
+ 'princeton-nlp/sup-simcse-bert-large-uncased',
369
+ 'princeton-nlp/sup-simcse-roberta-base',
370
+ 'princeton-nlp/sup-simcse-roberta-large'
371
+ ]
372
+ simcse_parser.add_argument('--simcse', action='store_true', help='Run SimCSE baseline')
373
+ simcse_parser.add_argument(
374
+ '--simcse-ckpt',
375
+ nargs='*',
376
+ type=str,
377
+ choices=simcse_models,
378
+ default=['princeton-nlp/sup-simcse-roberta-large']
379
+ )
380
+
381
+ bleurt_parser = parser.add_argument_group('Baseline - BLEURT')
382
+ bleurt_parser.add_argument('--bleurt', action='store_true', help='Run BLEURT baseline')
383
+ bleurt_parser.add_argument('--bleurt-ckpt', type=str)
384
+
385
+ bertscore_parser = parser.add_argument_group('Baseline - BERTScore')
386
+ bertscore_parser.add_argument('--bertscore', action='store_true', help='Run BERTScore baseline')
387
+ bertscore_parser.add_argument(
388
+ '--bertscore-ckpt',
389
+ nargs='*',
390
+ type=str,
391
+ default=['microsoft/deberta-xlarge-mnli']
392
+ )
393
+ bertscore_parser.add_argument('--bertscore-batch-size', type=int, default=16)
394
+
395
+ bartscore_parser = parser.add_argument_group(
396
+ 'Baseline - BARTScore',
397
+ description='Please clone https://github.com/neulab/BARTScore to baselines/BARTScore.'
398
+ )
399
+ bartscore_parser.add_argument('--bartscore', action='store_true', help='Run BARTScore baseline')
400
+ bartscore_parser.add_argument(
401
+ '--bartscore-ckpt',
402
+ type=str,
403
+ nargs='*',
404
+ default=['facebook/bart-large-cnn']
405
+ )
406
+
407
+ mnli_parser = parser.add_argument_group('Baseline - MNLI')
408
+ mnli_parser.add_argument('--mnli', action='store_true', help='Run MNLI baseline')
409
+ mnli_parser.add_argument(
410
+ '--mnli-ckpt',
411
+ nargs='*',
412
+ type=str,
413
+ default=['roberta-large-mnli']
414
+ )
415
+
416
+ ner_parser = parser.add_argument_group(
417
+ 'Baseline - NER overlap',
418
+ description='Please clone https://github.com/tingofurro/summac to baselines/summac.'
419
+ )
420
+ ner_parser.add_argument('--ner', action='store_true', help='Run NER overlap baseline')
421
+
422
+ unieval_parser = parser.add_argument_group(
423
+ 'Baseline - UniEval',
424
+ description='Please clone https://github.com/maszhongming/UniEval to baselines/UniEval.'
425
+ )
426
+ unieval_parser.add_argument('--unieval', action='store_true', help='Run UniEval baseline')
427
+
428
+ feqa_parser = parser.add_argument_group(
429
+ 'Baseline - FEQA',
430
+ description='Please clone https://github.com/esdurmus/feqa to baselines/feqa'
431
+ )
432
+ feqa_parser.add_argument('--feqa', action='store_true', help='Run FEQA baseline')
433
+
434
+ questeval_parser = parser.add_argument_group(
435
+ 'Baseline - QuestEval',
436
+ description='Please clone https://github.com/ThomasScialom/QuestEval to baselines/QuestEval.'
437
+ )
438
+ questeval_parser.add_argument('--questeval', action='store_true', help='Run QuestEval baseline')
439
+
440
+ qafacteval_parser = parser.add_argument_group(
441
+ 'Baseline - QAFactEval',
442
+ description='Please clone https://github.com/salesforce/QAFactEval to baselines/QAFactEval.'
443
+ )
444
+ qafacteval_parser.add_argument('--qafacteval', action='store_true', help='Run QAFactEval baseline')
445
+
446
+ bleu_parser = parser.add_argument_group('Baseline - BLEU')
447
+ bleu_parser.add_argument('--bleu', action='store_true', help='Run BLEU baseline')
448
+ bleu_parser.add_argument(
449
+ '--bleu-ngram',
450
+ nargs='*',
451
+ type=int,
452
+ choices=[1, 2, 3, 4],
453
+ default=[1, 2, 3, 4]
454
+ )
455
+
456
+ rouge_parser = parser.add_argument_group('Baseline - ROUGE')
457
+ rouge_parser.add_argument('--rouge', action='store_true', help='Run ROUGE baseline')
458
+ rouge_parser.add_argument(
459
+ '--rouge-type',
460
+ nargs='*',
461
+ type=str,
462
+ choices=['1', '2', 'l'],
463
+ default=['1', '2', 'l']
464
+ )
465
+
466
+ dae_parser = parser.add_argument_group('Baseline - DAE')
467
+ dae_parser.add_argument('--dae', action='store_true', help='Run DAE baseline')
468
+ dae_parser.add_argument('--dae-ckpt', type=str)
469
+
470
+ factcc_parser = parser.add_argument_group('Baseline - FactCC')
471
+ factcc_parser.add_argument('--factcc', action='store_true', help='Run FactCC baseline')
472
+ factcc_parser.add_argument('--factcc-script', type=str)
473
+ factcc_parser.add_argument('--factcc-test-data', type=str)
474
+ factcc_parser.add_argument('--factcc-result-path', type=str)
475
+
476
+ blanc_parser = parser.add_argument_group('Baseline - BLANC')
477
+ blanc_parser.add_argument('--blanc', action='store_true', help='Run BLANC baseline')
478
+ blanc_parser.add_argument('--blanc-batch-size', type=int, default=64)
479
+
480
+ summac_parser = parser.add_argument_group(
481
+ 'Baseline - SummaC',
482
+ description='Please clone https://github.com/tingofurro/summac to baselines/summac.'
483
+ )
484
+ summac_parser.add_argument('--summac', action='store_true', help='Run SummaC baseline')
485
+ summac_parser.add_argument('--summac-type', nargs='*', type=str, choices=['conv', 'zs'], default=['conv', 'zs'])
486
+
487
+ args = parser.parse_args()
488
+ if args.timer:
489
+ SAVE_AND_PRINT_TIMER = True
490
+
491
+ def argugment_error(msg):
492
+ parser.error(msg)
493
+
494
+ run_benchmarks(args, argugment_error)
alignscore/evaluate.py ADDED
@@ -0,0 +1,1793 @@
1
+ from logging import warning
2
+ from datasets import load_dataset
3
+ from alignscore.inference import Inferencer
4
+ import numpy as np
5
+ from scipy.stats import pearsonr, kendalltau, spearmanr
6
+ from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, balanced_accuracy_score, matthews_corrcoef
7
+ import pandas as pd
8
+ import torch
9
+ import json
10
+ import pickle
11
+ import os
12
+
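+ # Task name -> Hugging Face datasets spec: [path, (optional) config, split].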
13
+ HUGGINGFACE_DATASETS = {
14
+ 'stsb': ['glue', 'stsb', 'validation'],
15
+ 'mrpc': ['glue', 'mrpc', 'test'],
16
+ 'axb': ['super_glue', 'axb', 'test'],
17
+ 'axg': ['super_glue', 'axg', 'test'],
18
+ 'cb': ['super_glue', 'cb', 'validation'],
19
+ 'rte': ['super_glue', 'rte', 'validation'],
20
+ 'wnli': ['SetFit/wnli', 'validation'],
21
+ 'paws': ['paws', 'labeled_final', 'test'],
22
+ 'mnli_matched': ['multi_nli', 'validation_matched'],
23
+ 'mnli_mismatched': ['multi_nli', 'validation_mismatched'],
24
+ 'nli_fever': ['pietrolesci/nli_fever', 'dev'],
25
+ 'doc_nli': ['saattrupdan/doc-nli', 'test'],
26
+ 'sem_eval': ['sem_eval_2014_task_1', 'test'],
27
+ 'sick': ['sick', 'default', 'test'],
28
+ 'race_m': ['race', 'middle', 'test'],
29
+ 'race_h': ['race', 'high', 'test'],
30
+ 'boolq': ['boolq', 'validation'],
31
+ 'anli_1': ['anli', 'test_r1'],
32
+ 'anli_2': ['anli', 'test_r2'],
33
+ 'anli_3': ['anli', 'test_r3'],
34
+ 'snli': ['snli', 'test'],
35
+ 'vitaminc': ['tals/vitaminc', 'test'],
36
+ 'qqp': ['glue', 'qqp', 'validation'],
37
+ # below are tasks from https://arxiv.org/pdf/2104.14690.pdf
38
+ 'sst2': ['SetFit/sst2', 'test'],
39
+ # can't find MR
40
+ 'cr': ['SetFit/SentEval-CR', 'test'],
41
+ # can't find MPQA
42
+ 'subj': ['SetFit/subj', 'test'],
43
+ # can't find OS
44
+ 'imdb': ['SetFit/imdb', 'test'], # note: I can't confirm if this is the same dataset used in that paper
45
+ # The original dataset is no longer accessible
46
+ 'cola': ['glue', 'cola', 'validation'],
47
+ 'yelp_efl': ['SetFit/yelp_review_full', 'test'],
48
+ 'ag_news': ['SetFit/ag_news', 'test'],
49
+ 'trec': ['SetFit/TREC-QC', 'test',],
50
+ 'dream': ['dream', 'test'],
51
+ 'quartz': ['quartz', 'test'],
52
+ 'eraser_multi_rc': ['eraser_multi_rc', 'test'],
53
+ 'quail': ['quail', 'challenge'],
54
+ 'sciq': ['sciq', 'test'],
55
+ 'gap': ['gap', 'test'],
56
+ 'qnli': ['glue', 'qnli', 'validation']
57
+ }
58
+
59
+ PICKLE_DATASETS = [
60
+ 'newsroom',
61
+ 'rank19',
62
+ 'bagel',
63
+ 'sfhot',
64
+ 'sfres'
65
+ ]
66
+
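+ # Task name -> index of the model output head to use (0 = regression, 1 = binary, 2 = 3-way; see the inline note below).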
67
+ ALL_TASKS = { # enumerate all possible tasks
68
+ 'stsb': 0, ### using which output: regression, binary, tri-label
69
+ 'sick': 0,
70
+ 'race_m': 1,
71
+ 'race_h': 1,
72
+ 'boolq': 1,
73
+ 'anli_1': 2,
74
+ 'anli_2': 2,
75
+ 'anli_3': 2,
76
+ 'snli': 2,
77
+ 'vitaminc': 2,
78
+ 'mrpc': 1,
79
+ 'paws': 1,
80
+ 'mnli_matched': 2,
81
+ 'mnli_mismatched': 2,
82
+ 'sem_eval': 1,
83
+ 'summeval': 1,
84
+ 'qags_xsum': 1,
85
+ 'qags_cnndm': 1,
86
+ 'frank': 1,
87
+ 'xsumfaith': 1,
88
+ 'samsum': 1,
89
+ 'yelp': 1,
90
+ 'persona_chat': 1,
91
+ 'topical_chat': 1,
92
+ 'paws_qqp': 1,
93
+ 'qqp': 1,
94
+ 'newsroom': 1,
95
+ 'rank19': 1,
96
+ 'bagel': 1,
97
+ 'sfhot': 1,
98
+ 'sfres': 1,
99
+ 'wmt17': 0,
100
+ 'wmt18': 0,
101
+ 'wmt19': 0,
102
+ 'sst2': 1,
103
+ 'cr': 1,
104
+ 'subj': 1,
105
+ 'imdb': 1,
106
+ 'cola': 1,
107
+ 'yelp_efl': 1,
108
+ 'ag_news': 1,
109
+ 'trec': 1,
110
+ 'axb': 1,
111
+ 'axg': 1,
112
+ 'cb': 2,
113
+ 'rte': 2,
114
+ 'wnli': 2,
115
+ 'dream': 1,
116
+ 'quartz': 1,
117
+ 'nli_fever': 2,
118
+ 'doc_nli': 1,
119
+ 'eraser_multi_rc': 1,
120
+ 'quail': 1,
121
+ 'sciq': 1,
122
+ 'gap': 1,
123
+ 'qnli': 1
124
+ }
125
+
126
+ FEW_SHOT_N = 8
127
+ FEW_SHOT_SEEDS = [30247, 38252, 29050, 1091, 35554, 25309, 79319, 35079, 35256, 46744]
128
+
129
+ class Evaluator():
130
+ def __init__(self, eval_tasks, align_func, save_all_tables=False, clean_data=True) -> None:
131
+ self.align_func = align_func
132
+ self.eval_tasks = eval_tasks # ['stsb', 'paws', ...]
133
+ self.result_save_name = "Default_result_name"
134
+ self.result_tables = []
135
+ self.result_dicts = []
136
+ self.clean_data = clean_data
137
+ self.init_eval_dataset()
138
+
139
+ self.should_save_all_tables = save_all_tables
140
+ warning(f"Saving result tables: {self.should_save_all_tables}")
141
+
142
+ def init_eval_dataset(self):
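+ # Load each requested task: Hugging Face datasets, pickled files, WMT jsonl files, the TRUE csv files,
+ # the SummaC benchmark, or a local json file.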
143
+ self.dataset = dict()
144
+ for eval_task in self.eval_tasks:
145
+ if eval_task in HUGGINGFACE_DATASETS:
146
+ if len(HUGGINGFACE_DATASETS[eval_task]) == 3:
147
+ self.dataset[eval_task] = load_dataset(HUGGINGFACE_DATASETS[eval_task][0], HUGGINGFACE_DATASETS[eval_task][1])[HUGGINGFACE_DATASETS[eval_task][2]]
148
+ elif len(HUGGINGFACE_DATASETS[eval_task]) == 2:
149
+ if isinstance(HUGGINGFACE_DATASETS[eval_task][1], tuple):
150
+ dataset = load_dataset(HUGGINGFACE_DATASETS[eval_task][0])
151
+ self.dataset[eval_task] = {split:dataset[split] for split in HUGGINGFACE_DATASETS[eval_task][1]}
152
+ else:
153
+ self.dataset[eval_task] = load_dataset(HUGGINGFACE_DATASETS[eval_task][0])[HUGGINGFACE_DATASETS[eval_task][1]]
154
+
155
+ elif eval_task == 'paws_qqp':
156
+ self.dataset[eval_task] = pd.read_csv('data/paws_qqp/output/dev_and_test.tsv', sep='\t')
157
+ elif eval_task == 'beir':
158
+ print("BEIR loads its datasets by itself")
159
+ self.dataset[eval_task] = "BEIR Benchmark"
160
+ elif eval_task in PICKLE_DATASETS:
161
+ with open(f'data/eval/{eval_task}.pkl', 'rb') as f:
162
+ self.dataset[eval_task] = pickle.load(f)
163
+ elif 'wmt' in eval_task:
164
+ self.dataset[eval_task] = []
165
+ with open(f'data/eval/{eval_task}_eval.jsonl', 'r', encoding='utf8') as f:
166
+ for example in f:
167
+ self.dataset[eval_task].append(json.loads(example))
168
+ elif 'true' == eval_task:
169
+ for each_true_sub in os.listdir('data/eval/true'):
170
+ if 'qags' in each_true_sub:
171
+ each_true_sub_name = 'true_' + '_'.join(each_true_sub.split('_')[:2])
172
+ else:
173
+ each_true_sub_name = 'true_' + '_'.join(each_true_sub.split('_')[:1])
174
+
175
+ self.dataset[each_true_sub_name] = pd.read_csv(os.path.join('data/eval/true', each_true_sub))
176
+ elif 'summac' == eval_task:
177
+ from summac.benchmark import SummaCBenchmark
178
+ self.summac_validation_set = dict()
179
+ summac_benchmark = SummaCBenchmark(benchmark_folder="./data/eval/summac/benchmark", cut='test')
180
+ for each in summac_benchmark.datasets:
181
+ summac_dt_name = each['name']
182
+ self.dataset['summac_'+summac_dt_name] = each['dataset']
183
+
184
+ summac_benchmark_valid = SummaCBenchmark(benchmark_folder="./data/eval/summac/benchmark", cut='val')
185
+ for each in summac_benchmark_valid.datasets:
186
+ summac_dt_name = each['name']
187
+ self.summac_validation_set['summac_'+summac_dt_name] = each['dataset']
188
+ else:
189
+ f = open(f'data/eval/{eval_task}.json')
190
+ self.dataset[eval_task] = json.load(f)
191
+ f.close()
192
+
193
+ def print_result_table(self, table):
194
+ self.result_tables.append(pd.DataFrame(table).to_markdown())
195
+ self.result_dicts.append(table)
196
+ print(self.result_tables[-1])
197
+
198
+ def print_all_tables(self):
199
+ print("\n All Evaluation Results:")
200
+ for each in self.result_tables:
201
+ print(each)
202
+ print('='*100)
203
+
204
+ def save_all_tables(self):
205
+ with open(f'exp_results/{self.result_save_name}.pkl', 'wb') as f:
206
+ pickle.dump(self.result_dicts, f, protocol=pickle.HIGHEST_PROTOCOL)
207
+
208
+ def evaluate(self):
209
+ for each_task in self.dataset:
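+ # dispatch by name to the matching evaluate_<task>() method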
210
+ eval(f'self.evaluate_{each_task}()')
211
+
212
+ if self.should_save_all_tables:
213
+ self.save_all_tables()
214
+
215
+ def get_accuracy(self, true_score, pred_score):
216
+ return [accuracy_score(true_score, [m>0.5 for m in pred_score])]
217
+
218
+ def get_balanced_accuracy(self, true_score, pred_score, thres=0.5):
219
+ return [balanced_accuracy_score(true_score, [m>thres for m in pred_score])]
220
+
221
+ def get_f1(self, true_score, pred_score):
222
+ return [f1_score(true_score, [m>0.5 for m in pred_score])]
223
+
224
+ def get_3label_f1(self, true_score, pred_score):
225
+ return [f1_score(true_score, pred_score, average='micro')]
226
+
227
+ def get_pearson(self, true_score, pred_score):
228
+ return pearsonr(pred_score, true_score)
229
+
230
+ def get_kendalltau(self, true_score, pred_score):
231
+ return kendalltau(pred_score, true_score)
232
+
233
+ def get_spearman(self, true_score, pred_score):
234
+ return spearmanr(pred_score, true_score)
235
+
236
+ def get_matthews_corr(self, true_score, pred_score):
237
+ return [matthews_corrcoef(true_score, [s>0.5 for s in pred_score])]
238
+
239
+
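+ # Restore source-document casing in generated claims and normalize quotes/brackets before scoring
+ # (skipped when clean_data is False).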
240
+ def clean_text(self, context, claims):
241
+ from nltk.tokenize import sent_tokenize
242
+
243
+ if not self.clean_data:
244
+ return claims
245
+
246
+ word_cases = {token.lower():token for token in context.strip().split()}
247
+
248
+ def clean(text):
249
+ text = ' '.join(word_cases.get(token.lower(), token) for token in text.strip().split())
250
+ text = text.replace('“', '"').replace('”', '"').replace('’', '\'').replace('‘', '\'').replace('`', '\'').replace('-lrb-', '(').replace('-rrb-', ')')
251
+ text = ' '.join(each.strip()[0].capitalize() + each.strip()[1:] for each in sent_tokenize(text))
252
+ return text
253
+
254
+ if isinstance(claims, str):
255
+ return clean(claims)
256
+
257
+ return [clean(text) for text in claims]
258
+
259
+
260
+ def evaluate_newsroom(self):
261
+ true_score = []
262
+ true_score_rel = []
263
+ true_score_binary = []
264
+ sent1 = []
265
+ sent2 = []
266
+
267
+ for sample in self.dataset['newsroom'].values():
268
+ summaries, informativeness, relevance = zip(*(
269
+ (s['sys_summ'], s['scores']['informativeness'], s['scores']['relevance'])
270
+ for s in sample['sys_summs'].values()
271
+ ))
272
+ cleaned_summaries = self.clean_text(sample['src'], summaries)
273
+ for summary, inf_score, rel_score in zip(cleaned_summaries, informativeness, relevance):
274
+ sent1.append(sample['src'])
275
+ sent2.append(summary)
276
+ true_score.append(inf_score)
277
+ true_score_rel.append(rel_score)
278
+ true_score_binary.append(int(inf_score >= 4))
279
+
280
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['newsroom']].tolist()
281
+
282
+ self.print_result_table({
283
+ 'Dataset_name': 'newsroom',
284
+ 'Pearson': self.get_pearson(true_score, pred_score),
285
+ 'Spearman': self.get_spearman(true_score, pred_score),
286
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
287
+ 'AUC': roc_auc_score(true_score_binary, pred_score),
288
+ 'Pearson_rel': self.get_pearson(true_score_rel, pred_score),
289
+ 'Spearman_rel': self.get_spearman(true_score_rel, pred_score),
290
+ 'Kendall_rel': self.get_kendalltau(true_score_rel, pred_score),
291
+ })
292
+
293
+ def evaluate_rank19(self):
294
+ def chunks(lst, n):
295
+ """Yield successive n-sized chunks from lst."""
296
+ for i in range(0, len(lst), n):
297
+ yield lst[i:i + n]
298
+ true_score = []
299
+ sent1 = []
300
+ sent2 = []
301
+
302
+ for example in self.dataset['rank19']:
303
+ for example_summs in self.dataset['rank19'][example]['sys_summs']:
304
+ sent1.append(self.dataset['rank19'][example]['src'])
305
+ sent2.append(self.dataset['rank19'][example]['sys_summs'][example_summs]['sys_summ'])
306
+ true_score.append(self.dataset['rank19'][example]['sys_summs'][example_summs]['scores']['fact'])
307
+
308
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['rank19']].tolist()
309
+ pred_score_bin = []
310
+ assert len(pred_score) % 2 == 0
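+ # rank19 summaries come in pairs; mark the higher-scoring summary of each pair as the predicted correct one ([0, 1] or [1, 0])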
311
+ for i, pair in enumerate(chunks(pred_score, 2)):
312
+ pred_score_bin.extend([0, 1] if pair[1] > pair[0] else [1, 0])
313
+
314
+ self.print_result_table({
315
+ 'Dataset_name': 'rank19',
316
+ 'Pearson': self.get_pearson(true_score, pred_score),
317
+ 'Spearman': self.get_spearman(true_score, pred_score),
318
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
319
+ 'Accuracy': self.get_accuracy(true_score, pred_score_bin)[0],
320
+ 'AUC': roc_auc_score(true_score, pred_score_bin)
321
+ })
322
+
323
+ def evaluate_bagel(self):
324
+ true_score = []
325
+ true_score_binary = []
326
+ sent1 = []
327
+ sent2 = []
328
+ pred_score = []
329
+
330
+ for example in self.dataset['bagel']:
331
+ sent1.append(' '.join(self.dataset['bagel'][example]['ref_summs']))
332
+ sent2.append(self.dataset['bagel'][example]['sys_summ'])
333
+ true_score.append(self.dataset['bagel'][example]['scores']['informativeness'])
334
+
335
+ if(self.dataset['bagel'][example]['scores']['informativeness'] >= 4.0):
336
+ true_score_binary.append(1)
337
+ else:
338
+ true_score_binary.append(0)
339
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['bagel']].tolist()
340
+
341
+ self.print_result_table({
342
+ 'Dataset_name': 'bagel',
343
+ 'Pearson': self.get_pearson(true_score, pred_score),
344
+ 'Spearman': self.get_spearman(true_score, pred_score),
345
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
346
+ 'AUC': roc_auc_score(true_score_binary, pred_score)
347
+ })
348
+
349
+ def evaluate_sfhot(self):
350
+ true_score = []
351
+ sent1 = []
352
+ sent2 = []
353
+ pred_score = []
354
+
355
+ for example in self.dataset['sfhot']:
356
+ for ref in self.dataset['sfhot'][example]['ref_summs']:
357
+ sent1.append(self.dataset['sfhot'][example]['sys_summ'])
358
+ sent2.append(ref)
359
+ pred_score.append(max(self.align_func(sent1, sent2)[ALL_TASKS['sfhot']].tolist()))
360
+ sent1 = []
361
+ sent2 = []
362
+ if(self.dataset['sfhot'][example]['scores']['quality'] >= 4.0):
363
+ true_score.append(1)
364
+ else:
365
+ true_score.append(0)
366
+
367
+ self.print_result_table({
368
+ 'Dataset_name': 'sfhot',
369
+ 'Pearson': self.get_pearson(true_score, pred_score),
370
+ 'Spearman': self.get_spearman(true_score, pred_score),
371
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
372
+ 'AUC': roc_auc_score(true_score, pred_score)
373
+ })
374
+
375
+ def evaluate_sfres(self):
376
+ true_score = []
377
+ sent1 = []
378
+ sent2 = []
379
+ pred_score = []
380
+
381
+ for example in self.dataset['sfres']:
382
+ for ref in self.dataset['sfres'][example]['ref_summs']:
383
+ sent1.append(self.dataset['sfres'][example]['sys_summ'])
384
+ sent2.append(ref)
385
+ pred_score.append(max(self.align_func(sent1, sent2)[ALL_TASKS['sfres']].tolist()))
386
+ sent1 = []
387
+ sent2 = []
388
+ if(self.dataset['sfres'][example]['scores']['quality'] >= 4.0):
389
+ true_score.append(1)
390
+ else:
391
+ true_score.append(0)
392
+
393
+ self.print_result_table({
394
+ 'Dataset_name': 'sfres',
395
+ 'Pearson': self.get_pearson(true_score, pred_score),
396
+ 'Spearman': self.get_spearman(true_score, pred_score),
397
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
398
+ 'AUC': roc_auc_score(true_score, pred_score)
399
+ })
400
+
401
+
402
+ def evaluate_stsb(self):
403
+ true_score = []
404
+ sent1 = []
405
+ sent2 = []
406
+ for example in self.dataset['stsb']:
407
+ sent1.append(example['sentence1'])
408
+ sent2.append(example['sentence2'])
409
+ true_score.append(example['label'])
410
+
411
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['stsb']].tolist()
412
+
413
+ self.print_result_table({
414
+ 'Dataset_name': 'stsb',
415
+ 'Pearson': self.get_pearson(true_score, pred_score),
416
+ 'Spearman': self.get_spearman(true_score, pred_score),
417
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
418
+ })
419
+
420
+ def evaluate_sick(self):
421
+ true_score = []
422
+ sent1 = []
423
+ sent2 = []
424
+ for example in self.dataset['sick']:
425
+ sent1.append(example['sentence_A'])
426
+ sent2.append(example['sentence_B'])
427
+ true_score.append(example['relatedness_score'])
428
+
429
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['sick']].tolist()
430
+
431
+ self.print_result_table({
432
+ 'Dataset_name': 'sick-r',
433
+ 'Pearson': self.get_pearson(true_score, pred_score),
434
+ 'Spearman': self.get_spearman(true_score, pred_score),
435
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
436
+ })
437
+
438
+ def evaluate_race_m(self):
439
+ true_score = []
440
+ article = []
441
+ qa = []
442
+
443
+ for example in self.dataset['race_m']:
444
+ for i, option in enumerate(example['options']):
445
+ article.append(example['article'])
446
+ qa.append(example['question']+" "+option+" " if "_" not in example['question'] else ' '.join(example['question'].replace("_", " "+option+" ").split()))
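+ # example['answer'] is a letter 'A'-'D'; ord(answer) - 65 converts it to the option index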
447
+ if i == ord(example['answer'])-65:
448
+ true_score.append(i) # 0,1,2,3
449
+
450
+ pred_score = []
451
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['race_m']].tolist()
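+ # scores are returned flat; group every 4 consecutive scores (one per option) and take the argmax as the predicted answer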
452
+ for a, b, c, d in zip(*[iter(pred_score_temp)]*4):
454
+ pred_score.append(np.argmax([a,b,c,d]))
455
+
456
+ assert len(pred_score) == len(true_score)
457
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
458
+ acc = sum(acc) / len(acc)
459
+
460
+ self.print_result_table({
461
+ 'Dataset_name': 'race-m',
462
+ 'Accuracy': [acc],
463
+ })
464
+
465
+ def evaluate_race_h(self):
466
+ true_score = []
467
+ article = []
468
+ qa = []
469
+
470
+ for example in self.dataset['race_h']:
471
+ for i, option in enumerate(example['options']):
472
+ article.append(example['article'])
473
+ qa.append(example['question']+" "+option+" " if "_" not in example['question'] else ' '.join(example['question'].replace("_", " "+option+" ").split()))
474
+ if i == ord(example['answer'])-65:
475
+ true_score.append(i) # 0,1,2,3
476
+
477
+ pred_score = []
478
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['race_h']].tolist()
479
+ for a, b, c, d in zip(*[iter(pred_score_temp)]*4):
480
+ pred_score.append(np.argmax([a,b,c,d]))
481
+
482
+ assert len(pred_score) == len(true_score)
483
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
484
+ acc = sum(acc) / len(acc)
485
+
486
+ self.print_result_table({
487
+ 'Dataset_name': 'race-h',
488
+ 'Accuracy': [acc]
489
+ })
490
+
491
+ # How to combine passage, question, and single answer for boolq
492
+ def evaluate_boolq(self):
493
+ true_score = []
494
+ article = []
495
+ qa = []
496
+ for example in self.dataset['boolq']:
497
+ for i in range(2):
498
+ article.append(example['passage'])
499
+ if i == 0:
500
+ qa.append(example['question']+" "+"No.") # 0
501
+ else:
502
+ qa.append(example['question']+" "+"Yes.") # 1
503
+ true_score.append(int(example['answer']))
504
+
505
+ pred_score = []
506
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['boolq']].tolist()
507
+ for a, b in zip(*[iter(pred_score_temp)]*2):
508
+ pred_score.append(np.argmax([a,b]))
509
+
510
+ assert len(pred_score) == len(true_score)
511
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
512
+ acc = sum(acc) / len(acc)
513
+ self.print_result_table({
514
+ 'Dataset_name': 'boolq',
515
+ 'Accuracy': [acc]
516
+ })
517
+
518
+ def evaluate_anli_1(self):
519
+ true_score = []
520
+ sent1 = []
521
+ sent2 = []
522
+ for example in self.dataset['anli_1']:
523
+ sent1.append(example['premise'])
524
+ sent2.append(example['hypothesis'])
525
+ true_score.append(example['label'] if example['label']!=-1 else 1)
526
+
527
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['anli_1']], dim=-1).tolist()
528
+
529
+ self.print_result_table({
530
+ 'Dataset_name': 'anli-1',
531
+ 'Accuracy': [accuracy_score(true_score, pred_score)]
532
+ })
533
+
534
+ def evaluate_anli_2(self):
535
+ true_score = []
536
+ sent1 = []
537
+ sent2 = []
538
+ for example in self.dataset['anli_2']:
539
+ sent1.append(example['premise'])
540
+ sent2.append(example['hypothesis'])
541
+ true_score.append(example['label'] if example['label']!=-1 else 1)
542
+
543
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['anli_2']], dim=-1).tolist()
544
+
545
+ self.print_result_table({
546
+ 'Dataset_name': 'anli-2',
547
+ 'Accuracy': [accuracy_score(true_score, pred_score)]
548
+ })
549
+
550
+ def evaluate_anli_3(self):
551
+ true_score = []
552
+ sent1 = []
553
+ sent2 = []
554
+ for example in self.dataset['anli_3']:
555
+ sent1.append(example['premise'])
556
+ sent2.append(example['hypothesis'])
557
+ true_score.append(example['label'] if example['label']!=-1 else 1)
558
+
559
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['anli_3']], dim=-1).tolist()
560
+
561
+ self.print_result_table({
562
+ 'Dataset_name': 'anli-3',
563
+ 'Accuracy': [accuracy_score(true_score, pred_score)]
564
+ })
565
+
566
+ def evaluate_nli_fever(self):
567
+ true_score = []
568
+ sent1 = []
569
+ sent2 = []
570
+ for example in self.dataset['nli_fever']:
571
+ sent1.append(example['hypothesis']) # premise and hypothesis are swapped in the original dataset
572
+ sent2.append(example['premise'])
573
+ true_score.append(example['label'] if example['label']!=-1 else 1)
574
+
575
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['nli_fever']], dim=-1).tolist()
576
+
577
+ self.print_result_table({
578
+ 'Dataset_name': 'nli_fever',
579
+ 'Accuracy': [accuracy_score(true_score, pred_score)]
580
+ })
581
+
582
+ def evaluate_snli(self):
583
+ true_score = []
584
+ sent1 = []
585
+ sent2 = []
586
+ for example in self.dataset['snli']:
587
+ sent1.append(example['premise'])
588
+ sent2.append(example['hypothesis'])
589
+ true_score.append(example['label'] if example['label']!=-1 else 1)
590
+
591
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['snli']], dim=-1).tolist()
592
+
593
+ self.print_result_table({
594
+ 'Dataset_name': 'snli',
595
+ 'Accuracy': [accuracy_score(true_score, pred_score)]
596
+ })
597
+
598
+ def evaluate_axb(self):
599
+ true_score = []
600
+ sent1 = []
601
+ sent2 = []
602
+ for example in self.dataset['axb']:
603
+ sent1.append(example['sentence1'])
604
+ sent2.append(example['sentence2'])
605
+
606
+ true_score.append(1 if example['label']==0 else 0)
607
+
608
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['axb']].tolist()
609
+
610
+ self.print_result_table({
611
+ 'Dataset_name': 'axb',
612
+ 'F1': self.get_f1(true_score, pred_score),
613
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
614
+ 'AUC': [roc_auc_score(true_score, pred_score)],
615
+ 'Matthews': self.get_matthews_corr(true_score, pred_score)
616
+ })
617
+
618
+ def evaluate_axg(self):
619
+ true_score = []
620
+ sent1 = []
621
+ sent2 = []
622
+ for example in self.dataset['axg']:
623
+ sent1.append(example['premise'])
624
+ sent2.append(example['hypothesis'])
625
+
626
+ true_score.append(1 if example['label']==0 else 0)
627
+
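+ # use column 0 of the 3-way output head as the alignment score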
628
+ pred_score = self.align_func(sent1, sent2)[2][:,0].tolist()
629
+
630
+ self.print_result_table({
631
+ 'Dataset_name': 'axg',
632
+ 'F1': self.get_f1(true_score, pred_score),
633
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
634
+ 'AUC': [roc_auc_score(true_score, pred_score)],
635
+ })
636
+
637
+ def evaluate_cb(self):
638
+ true_score = []
639
+ sent1 = []
640
+ sent2 = []
641
+
642
+ for example in self.dataset['cb']:
643
+ sent1.append(example['premise'])
644
+ sent2.append(example['hypothesis'])
645
+
646
+ if example['label'] == 0:
647
+ label = 0
648
+ elif example['label'] == 1:
649
+ label = 2
650
+ elif example['label'] == 2:
651
+ label = 1
652
+
653
+ true_score.append(label)
654
+
655
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['cb']], dim=-1).tolist()
656
+
657
+ self.print_result_table({
658
+ 'Dataset_name': 'cb',
659
+ 'Accuracy': [accuracy_score(true_score, pred_score)],
660
+ })
661
+
662
+ def evaluate_rte(self):
663
+ true_score = []
664
+ sent1 = []
665
+ sent2 = []
666
+ for example in self.dataset['rte']:
667
+ sent1.append(example['premise'])
668
+ sent2.append(example['hypothesis'])
669
+
670
+ true_score.append(1 if example['label']==0 else 0)
671
+
672
+ pred_score = self.align_func(sent1, sent2)[2][:,0].tolist()
673
+
674
+ self.print_result_table({
675
+ 'Dataset_name': 'rte',
676
+ 'F1': self.get_f1(true_score, pred_score),
677
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
678
+ 'AUC': [roc_auc_score(true_score, pred_score)],
679
+ })
680
+
681
+ def evaluate_wnli(self):
682
+ true_score = []
683
+ sent1 = []
684
+ sent2 = []
685
+ for example in self.dataset['wnli']:
686
+ sent1.append(example['text1'])
687
+ sent2.append(example['text2'])
688
+
689
+ true_score.append(example['label'])
690
+
691
+ pred_score = self.align_func(sent1, sent2)[2][:,0].tolist()
692
+
693
+ self.print_result_table({
694
+ 'Dataset_name': 'wnli',
695
+ 'F1': self.get_f1(true_score, pred_score),
696
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
697
+ 'AUC': [roc_auc_score(true_score, pred_score)],
698
+ })
699
+
700
+ def evaluate_doc_nli(self):
701
+ true_score = []
702
+ sent1 = []
703
+ sent2 = []
704
+ for example in self.dataset['doc_nli']:
705
+ sent1.append(example['premise'])
706
+ sent2.append(example['hypothesis'])
707
+
708
+ true_score.append(1 if example['label'] == 'entailment' else 0)
709
+
710
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['doc_nli']].tolist()
711
+
712
+ self.print_result_table({
713
+ 'Dataset_name': 'doc_nli',
714
+ 'F1': self.get_f1(true_score, pred_score),
715
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
716
+ 'AUC': [roc_auc_score(true_score, pred_score)],
717
+ })
718
+
719
+ def evaluate_qnli(self):
720
+ true_score = []
721
+ sent1 = []
722
+ sent2 = []
723
+ for example in self.dataset['qnli']:
724
+ sent1.append(example['sentence'])
725
+ sent2.append(example['question'])
726
+
727
+ true_score.append(1 if example['label'] == 0 else 0)
728
+
729
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['qnli']].tolist()
730
+
731
+ self.print_result_table({
732
+ 'Dataset_name': 'qnli',
733
+ 'F1': self.get_f1(true_score, pred_score),
734
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
735
+ 'AUC': [roc_auc_score(true_score, pred_score)],
736
+ })
737
+
738
+ def evaluate_dream(self):
739
+ true_score = []
740
+ article = []
741
+ qa = []
742
+
743
+ for example in self.dataset['dream']:
744
+ for i, option in enumerate(example['choice']):
745
+ article.append(' '.join(example['dialogue']))
746
+ qa.append(example['question']+" "+option+" ")
747
+ if option == example['answer']:
748
+ true_score.append(i) # 0,1,2,3
749
+
750
+ pred_score = []
751
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['dream']].tolist()
752
+ for a, b, c in zip(*[iter(pred_score_temp)]*3):
754
+ pred_score.append(np.argmax([a,b,c]))
755
+
756
+ assert len(pred_score) == len(true_score)
757
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
758
+ acc = sum(acc) / len(acc)
759
+
760
+ self.print_result_table({
761
+ 'Dataset_name': 'dream',
762
+ 'Accuracy': [acc],
763
+ })
764
+
765
+ def evaluate_quartz(self):
766
+ true_score = []
767
+ article = []
768
+ qa = []
769
+
770
+ for example in self.dataset['quartz']:
771
+ for i, option in enumerate(example['choices']['text']):
772
+ article.append(example['para'])
773
+ qa.append(example['question']+" "+option+" ")
774
+ if i == ord(example['answerKey'])-65:
775
+ true_score.append(i) # 0,1,2,3
776
+
777
+ pred_score = []
778
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['quartz']].tolist()
779
+ for a, b in zip(*[iter(pred_score_temp)]*2):
781
+ pred_score.append(np.argmax([a,b]))
782
+
783
+ assert len(pred_score) == len(true_score)
784
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
785
+ acc = sum(acc) / len(acc)
786
+
787
+ self.print_result_table({
788
+ 'Dataset_name': 'quartz',
789
+ 'Accuracy': [acc],
790
+ })
791
+ def evaluate_eraser_multi_rc(self):
792
+ true_score = []
793
+ sent1 = []
794
+ sent2 = []
795
+ for example in self.dataset['eraser_multi_rc']:
796
+ sent1.append(example['passage'])
797
+ sent2.append(example['query_and_answer'].replace("|", ""))
798
+ true_score.append(example['label'])
799
+
800
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['eraser_multi_rc']].tolist()
801
+
802
+ self.print_result_table({
803
+ 'Dataset_name': 'eraser_multi_rc',
804
+ 'F1': self.get_f1(true_score, pred_score),
805
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
806
+ 'AUC': [roc_auc_score(true_score, pred_score)]
807
+ })
808
+
809
+ def evaluate_quail(self):
810
+ true_score = []
811
+ article = []
812
+ qa = []
813
+
814
+ for example in self.dataset['quail']:
815
+ for i, option in enumerate(example['answers']):
816
+ article.append(example['context'])
817
+ qa.append(example['question']+" "+option+" ")
818
+ if i == example['correct_answer_id']:
819
+ true_score.append(i) # 0,1,2,3
820
+
821
+ pred_score = []
822
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['quail']].tolist()
823
+ for a, b, c, d in zip(*[iter(pred_score_temp)]*4):
825
+ pred_score.append(np.argmax([a,b,c,d]))
826
+
827
+ assert len(pred_score) == len(true_score)
828
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
829
+ acc = sum(acc) / len(acc)
830
+
831
+ self.print_result_table({
832
+ 'Dataset_name': 'quail',
833
+ 'Accuracy': [acc],
834
+ })
835
+
836
+ def evaluate_sciq(self):
837
+ true_score = []
838
+ article = []
839
+ qa = []
840
+
841
+ for example in self.dataset['sciq']:
842
+ options = [example['correct_answer'], example['distractor1'], example['distractor2'], example['distractor3']]
843
+ for i, option in enumerate(options):
844
+ article.append(example['support'])
845
+ qa.append(example['question']+" "+option+" ")
846
+ if i == 0:
847
+ true_score.append(i) # 0,1,2,3, always 0
848
+
849
+ pred_score = []
850
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['sciq']].tolist()
851
+ for a, b, c, d in zip(*[iter(pred_score_temp)]*4):
853
+ pred_score.append(np.argmax([a,b,c,d]))
854
+
855
+ assert len(pred_score) == len(true_score)
856
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
857
+ acc = sum(acc) / len(acc)
858
+
859
+ self.print_result_table({
860
+ 'Dataset_name': 'sciq',
861
+ 'Accuracy': [acc],
862
+ })
863
+
864
+ def evaluate_gap(self):
865
+ true_score = []
866
+ article = []
867
+ qa = []
868
+
869
+ for example in self.dataset['gap']:
870
+ options = [example['Text'][:example['Pronoun-offset']]+example['A']+example['Text'][(example['Pronoun-offset']+len(example['Pronoun'])):],
871
+ example['Text'][:example['Pronoun-offset']]+example['B']+example['Text'][(example['Pronoun-offset']+len(example['Pronoun'])):]]
872
+ for i, option in enumerate(options):
873
+ article.append(example['Text'])
874
+ qa.append(option)
875
+
876
+ true_score.append(1 if example['B-coref'] else 0) # 1 if option B is the correct coreference, else 0
877
+
878
+ pred_score = []
879
+ pred_score_temp = self.align_func(article, qa)[ALL_TASKS['gap']].tolist()
880
+ for a, b in zip(*[iter(pred_score_temp)]*2):
881
+ pred_score.append(np.argmax([a,b]))
882
+
883
+ assert len(pred_score) == len(true_score)
884
+ acc = [int(p==t) for p, t in zip(pred_score, true_score)]
885
+ acc = sum(acc) / len(acc)
886
+
887
+ self.print_result_table({
888
+ 'Dataset_name': 'gap',
889
+ 'Accuracy': [acc],
890
+ })
891
+
892
+ # How to group fact checking
893
+ def evaluate_vitaminc(self):
894
+ true_score = []
895
+ sent1 = []
896
+ sent2 = []
897
+ for example in self.dataset['vitaminc']:
898
+ sent1.append(example['evidence'])
899
+ sent2.append(example['claim'])
900
+ if example['label'] == 'SUPPORTS':
901
+ true_score.append(0)
902
+ elif example['label'] == 'REFUTES':
903
+ true_score.append(2)
904
+ else:
905
+ true_score.append(1)
906
+
907
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['vitaminc']], dim=-1).tolist()
908
+
909
+ self.print_result_table({
910
+ 'Dataset_name': 'vitaminc',
911
+ 'F1': self.get_3label_f1(true_score, pred_score),
912
+ 'Accuracy': [accuracy_score(true_score, pred_score)],
913
+ })
914
+
915
+ def evaluate_mrpc(self):
916
+ true_score = []
917
+ sent1 = []
918
+ sent2 = []
919
+ for example in self.dataset['mrpc']:
920
+ sent1.append(example['sentence1'])
921
+ sent2.append(example['sentence2'])
922
+ true_score.append(example['label'])
923
+
924
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['mrpc']].tolist()
925
+
926
+ self.print_result_table({
927
+ 'Dataset_name': 'mrpc',
928
+ 'F1': self.get_f1(true_score, pred_score),
929
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
930
+ 'AUC': [roc_auc_score(true_score, pred_score)]
931
+ })
932
+
933
+ def evaluate_paws(self):
934
+ true_score = []
935
+ sent1 = []
936
+ sent2 = []
937
+ for example in self.dataset['paws']:
938
+ sent1.append(example['sentence1'])
939
+ sent2.append(example['sentence2'])
940
+ true_score.append(example['label'])
941
+
942
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['paws']].tolist()
943
+
944
+ self.print_result_table({
945
+ 'Dataset_name': 'paws',
946
+ 'F1': self.get_f1(true_score, pred_score),
947
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
948
+ 'AUC': [roc_auc_score(true_score, pred_score)]
949
+ })
950
+
951
+ def evaluate_mnli_matched(self):
952
+ true_score = []
953
+ sent1 = []
954
+ sent2 = []
955
+ for example in self.dataset['mnli_matched']:
956
+ sent1.append(example['premise'])
957
+ sent2.append(example['hypothesis'])
958
+ true_score.append(example['label'] if example['label']!=-1 else 1)
959
+
960
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['mnli_matched']], dim=-1).tolist()
961
+
962
+ self.print_result_table({
963
+ 'Dataset_name': 'mnli_matched',
964
+ 'Accuracy': [accuracy_score(true_score, pred_score)]
965
+ })
966
+
967
+ def evaluate_mnli_mismatched(self):
968
+ true_score = []
969
+ sent1 = []
970
+ sent2 = []
971
+ for example in self.dataset['mnli_mismatched']:
972
+ sent1.append(example['premise'])
973
+ sent2.append(example['hypothesis'])
974
+ true_score.append(example['label'] if example['label']!=-1 else 1)
975
+
976
+ pred_score = torch.argmax(self.align_func(sent1, sent2)[ALL_TASKS['mnli_mismatched']], dim=-1).tolist()
977
+
978
+ self.print_result_table({
979
+ 'Dataset_name': 'mnli_mismatched',
980
+ 'Accuracy': [accuracy_score(true_score, pred_score)]
981
+ })
982
+
983
+ def evaluate_sem_eval(self):
985
+ true_score = []
986
+ sent1 = []
987
+ sent2 = []
988
+ for example in self.dataset['sem_eval']:
989
+ sent1.append(example['premise'])
990
+ sent2.append(example['hypothesis'])
991
+ if example['entailment_judgment'] == 1:
992
+ true_score.append(1)
993
+ else:
994
+ true_score.append(0)
995
+
996
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['sem_eval']].tolist()
997
+
998
+ self.print_result_table({
999
+ 'Dataset_name': 'sem_eval',
1000
+ 'Accuracy': self.get_accuracy(true_score, pred_score)
1001
+ })
1002
+
1003
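+ # SummEval: correlate predicted scores with human consistency and relevance ratings; AUC treats consistency == 5.0 as the positive class.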
+ def evaluate_summeval(self):
1004
+ true_score = []
1005
+ true_score_rel = []
1006
+ true_score_binary = []
1007
+ pred_score = []
1008
+ sent1 = []
1009
+ sent2 = []
1010
+ for example in self.dataset['summeval']:
1011
+ cleaned_summary = self.clean_text(example['document'], example['summary'])
1012
+ sent1.append(example['document'])
1013
+ sent2.append(cleaned_summary)
1014
+ true_score.append(example['consistency'])
1015
+ true_score_rel.append(example['relevance'])
1016
+ true_score_binary.append(1 if example['consistency'] == 5.0 else 0)
1017
+
1018
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['summeval']].tolist()
1019
+
1020
+ self.print_result_table({
1021
+ 'Dataset_name': 'summeval',
1022
+ 'Pearson': self.get_pearson(true_score, pred_score),
1023
+ 'Spearman': self.get_spearman(true_score, pred_score),
1024
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
1025
+ 'AUC': roc_auc_score(true_score_binary, pred_score),
1026
+ 'Pearson_rel': self.get_pearson(true_score_rel, pred_score),
1027
+ 'Spearman_rel': self.get_spearman(true_score_rel, pred_score),
1028
+ 'Kendall_rel': self.get_kendalltau(true_score_rel, pred_score),
1029
+ })
1030
+
1031
+ def evaluate_qags_xsum(self):
1032
+ true_score = []
1033
+ pred_score = []
1034
+ sent1 = []
1035
+ sent2 = []
1036
+ for example in self.dataset['qags_xsum']:
1037
+ sent1.append(example['document'])
1038
+ sent2.append(example['summary'])
1039
+ true_score.append(example['consistency'])
1040
+
1041
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['qags_xsum']].tolist()
1042
+
1043
+ self.print_result_table({
1044
+ 'Dataset_name': 'qags_xsum',
1045
+ 'Pearson': self.get_pearson(true_score, pred_score),
1046
+ 'Spearman': self.get_spearman(true_score, pred_score),
1047
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
1048
+ 'AUC': roc_auc_score(true_score, pred_score)
1049
+ })
1050
+
1051
+ def evaluate_qags_cnndm(self):
1052
+ true_score = []
1053
+ pred_score = []
1054
+ sent1 = []
1055
+ sent2 = []
1056
+ true_score_binary = []
1057
+ for example in self.dataset['qags_cnndm']:
1058
+ sent1.append(example['document'])
1059
+ sent2.append(example['summary'])
1060
+ true_score.append(example['consistency'])
1061
+ true_score_binary.append(1 if example['consistency'] == 1.0 else 0)
1062
+
1063
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['qags_cnndm']].tolist()
1064
+
1065
+ self.print_result_table({
1066
+ 'Dataset_name': 'qags_cnndm',
1067
+ 'Pearson': self.get_pearson(true_score, pred_score),
1068
+ 'Spearman': self.get_spearman(true_score, pred_score),
1069
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
1070
+ 'AUC': roc_auc_score(true_score_binary, pred_score)
1071
+ })
1072
+
1073
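+ # FRANK: XSum documents are re-split into sentences with spaCy, then the XSum and CNN/DM subsets are evaluated separately and combined.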
+ def evaluate_frank(self):
1074
+ from spacy.lang.en import English
1075
+ nlp = English()
1076
+ nlp.add_pipe("sentencizer")
1077
+ for d in self.dataset['frank']:
1078
+ if d['dataset'] == 'cnndm':
1079
+ continue
1080
+ d['document'] = ' '.join([each.text for each in nlp(d['document']).sents])
1081
+
1082
+ true_score_xsum = []
1083
+ true_score_cnndm = []
1084
+ pred_score_xsum = []
1085
+ pred_score_cnndm = []
1086
+ sent1_xsum = []
1087
+ sent1_cnndm = []
1088
+ sent2_xsum = []
1089
+ sent2_cnndm = []
1090
+ true_score_binary_cnndm = []
1091
+ true_score_binary_xsum = []
1092
+ for example in self.dataset['frank']:
1093
+ if example['dataset'] == 'cnndm':
1094
+ sent1_cnndm.append(example['document'])
1095
+ sent2_cnndm.append(self.clean_text(example['document'], example['summary']))
1096
+ true_score_cnndm.append(example['score'])
1097
+ true_score_binary_cnndm.append(1 if example['score'] == 1.0 else 0)
1098
+ elif example['dataset'] == 'xsum':
1099
+ sent1_xsum.append(example['document'])
1100
+ sent2_xsum.append(self.clean_text(example['document'], example['summary']))
1101
+ true_score_xsum.append(example['score'])
1102
+ true_score_binary_xsum.append(1 if example['score'] == 1.0 else 0)
1103
+
1104
+ pred_score_xsum = self.align_func(sent1_xsum, sent2_xsum)[ALL_TASKS['frank']].tolist()
1105
+ pred_score_cnndm = self.align_func(sent1_cnndm, sent2_cnndm)[ALL_TASKS['frank']].tolist()
1106
+
1107
+ self.print_result_table({
1108
+ 'Dataset_name': 'frank-xsum',
1109
+ 'Pearson': self.get_pearson(true_score_xsum, pred_score_xsum),
1110
+ 'Spearman': self.get_spearman(true_score_xsum, pred_score_xsum),
1111
+ 'Kendall': self.get_kendalltau(true_score_xsum, pred_score_xsum),
1112
+ 'AUC': roc_auc_score(true_score_binary_xsum, pred_score_xsum)
1113
+ })
1114
+
1115
+ self.print_result_table({
1116
+ 'Dataset_name': 'frank-cnndm',
1117
+ 'Pearson': self.get_pearson(true_score_cnndm, pred_score_cnndm),
1118
+ 'Spearman': self.get_spearman(true_score_cnndm, pred_score_cnndm),
1119
+ 'Kendall': self.get_kendalltau(true_score_cnndm, pred_score_cnndm),
1120
+ 'AUC': roc_auc_score(true_score_binary_cnndm, pred_score_cnndm)
1121
+ })
1122
+
1123
+ self.print_result_table({
1124
+ 'Dataset_name': 'frank-all',
1125
+ 'Pearson': self.get_pearson(true_score_xsum+true_score_cnndm, pred_score_xsum+pred_score_cnndm),
1126
+ 'Spearman': self.get_spearman(true_score_xsum+true_score_cnndm, pred_score_xsum+pred_score_cnndm),
1127
+ 'Kendall': self.get_kendalltau(true_score_xsum+true_score_cnndm, pred_score_xsum+pred_score_cnndm),
1128
+ 'AUC': roc_auc_score(true_score_binary_xsum+true_score_binary_cnndm, pred_score_xsum+pred_score_cnndm)
1129
+ })
1130
+
1131
+ def evaluate_xsumfaith(self):
1132
+ dataset_name = 'xsumfaith'
1133
+
1134
+ true_score = []
1135
+ pred_score = []
1136
+ sent1 = []
1137
+ sent2 = []
1138
+ for example in self.dataset[dataset_name]:
1139
+ sent1.append(example['document'])
1140
+ sent2.append(self.clean_text(example['document'], example['claim']))
1141
+ true_score.append(example['label'])
1142
+
1143
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS[dataset_name]].tolist()
1144
+
1145
+ self.print_result_table({
1146
+ 'Dataset_name': dataset_name,
1147
+ 'Pearson': self.get_pearson(true_score, pred_score),
1148
+ 'Spearman': self.get_spearman(true_score, pred_score),
1149
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
1150
+ })
1151
+
1152
+ def evaluate_samsum(self):
1153
+ dataset_name = 'samsum'
1154
+
1155
+ label_mapping = {
1156
+ 'factual': 1,
1157
+ 'factually incorrect': 0,
1158
+ 'too incoherent': 0
1159
+ }
1160
+ import string
1161
+ printable = set(string.printable)
1162
+
1163
+
1164
+ true_score = []
1165
+ pred_score = []
1166
+ sent1 = []
1167
+ sent2 = []
1168
+ for example in self.dataset[dataset_name]:
1169
+ cleaned_doc = ''.join(filter(lambda x: x in printable, example['article']))
1170
+ sent1.append(cleaned_doc)
1171
+ sent2.append(example['summary'])
1172
+ true_score.append(label_mapping[example['label']])
1173
+
1174
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS[dataset_name]].tolist()
1175
+
1176
+ self.print_result_table({
1177
+ 'Dataset_name': dataset_name,
1178
+ 'Pearson': self.get_pearson(true_score, pred_score),
1179
+ 'Spearman': self.get_spearman(true_score, pred_score),
1180
+ 'Kendall': self.get_kendalltau(true_score, pred_score),
1181
+ 'AUC': roc_auc_score(true_score, pred_score)
1182
+ })
1183
+ def evaluate_yelp(self):
1184
+ true_score = []
1185
+ sent1 = []
1186
+ sent2 = []
1187
+ for example in self.dataset['yelp']:
1188
+ sent1.append(example['input_sent'])
1189
+ sent2.append(example['output_sent'])
1190
+ true_score.append(example['preservation'])
1191
+
1192
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['yelp']].tolist()
1193
+
1194
+ self.print_result_table({
1195
+ 'Dataset_name': 'yelp',
1196
+ 'Pearson': self.get_pearson(true_score, pred_score),
1197
+ 'Spearman': self.get_spearman(true_score, pred_score),
1198
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1199
+ })
1200
+
1201
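+ # Dialogue evaluation (PersonaChat / TopicalChat): engagingness uses dialog history + fact as the premise, groundedness uses the fact alone.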
+ def evaluate_persona_chat(self):
1202
+ true_score = []
1203
+ pred_score = []
1204
+ premise = []
1205
+ hypothesis = []
1206
+ for example in self.dataset['persona_chat']:
1207
+ premise.append(example['dialog_history']+example['fact'])
1208
+ hypothesis.append(example['response'])
1209
+ true_score.append(example['engaging'])
1210
+ pred_score = self.align_func(premise, hypothesis)[ALL_TASKS['persona_chat']].tolist()
1211
+
1212
+ self.print_result_table({
1213
+ 'Dataset_name': 'persona_chat_eng',
1214
+ 'Pearson': self.get_pearson(true_score, pred_score),
1215
+ 'Spearman': self.get_spearman(true_score, pred_score),
1216
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1217
+ })
1218
+
1219
+ true_score = []
1220
+ pred_score = []
1221
+ premise = []
1222
+ hypothesis = []
1223
+ for example in self.dataset['persona_chat']:
1224
+ premise.append(example['fact'])
1225
+ hypothesis.append(example['response'])
1226
+ true_score.append(example['uses_knowledge'])
1227
+ pred_score = self.align_func(premise, hypothesis)[ALL_TASKS['persona_chat']].tolist()
1228
+
1229
+ self.print_result_table({
1230
+ 'Dataset_name': 'persona_chat_grd',
1231
+ 'Pearson': self.get_pearson(true_score, pred_score),
1232
+ 'Spearman': self.get_spearman(true_score, pred_score),
1233
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1234
+ })
1235
+
1236
+ def evaluate_topical_chat(self):
1237
+ true_score = []
1238
+ pred_score = []
1239
+ premise = []
1240
+ hypothesis = []
1241
+ for example in self.dataset['topical_chat']:
1242
+ premise.append(example['dialog_history']+example['fact'])
1243
+ hypothesis.append(example['response'])
1244
+ true_score.append(example['engaging'])
1245
+ pred_score = self.align_func(premise, hypothesis)[ALL_TASKS['topical_chat']].tolist()
1246
+
1247
+ self.print_result_table({
1248
+ 'Dataset_name': 'topical_chat_eng',
1249
+ 'Pearson': self.get_pearson(true_score, pred_score),
1250
+ 'Spearman': self.get_spearman(true_score, pred_score),
1251
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1252
+ })
1253
+
1254
+ true_score = []
1255
+ pred_score = []
1256
+ premise = []
1257
+ hypothesis = []
1258
+ for example in self.dataset['topical_chat']:
1259
+ premise.append(example['fact'])
1260
+ hypothesis.append(example['response'])
1261
+ true_score.append(example['uses_knowledge'])
1262
+ pred_score = self.align_func(premise, hypothesis)[ALL_TASKS['topical_chat']].tolist()
1263
+
1264
+ self.print_result_table({
1265
+ 'Dataset_name': 'topical_chat_grd',
1266
+ 'Pearson': self.get_pearson(true_score, pred_score),
1267
+ 'Spearman': self.get_spearman(true_score, pred_score),
1268
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1269
+ })
1270
+
1271
+ def evaluate_paws_qqp(self):
1272
+ sent1 = []
1273
+ sent2 = []
1274
+ true_score = []
1275
+ for i in range(self.dataset['paws_qqp']['label'].size):
1276
+ sent1.append(self.dataset['paws_qqp']['sentence1'][i][2:-1])
1277
+ sent2.append(self.dataset['paws_qqp']['sentence2'][i][2:-1])
1278
+ true_score.append(self.dataset['paws_qqp']['label'][i])
1279
+
1280
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['paws_qqp']].tolist()
1281
+ roc_auc = roc_auc_score(true_score, pred_score)
1282
+
1283
+ self.print_result_table({
1284
+ 'Dataset_name': 'paws_qqp',
1285
+ 'F1': self.get_f1(true_score, pred_score),
1286
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1287
+ 'AUC': [roc_auc]
1288
+ })
1289
+
1290
+ def evaluate_qqp(self):
1291
+ true_score = []
1292
+ sent1 = []
1293
+ sent2 = []
1294
+ for example in self.dataset['qqp']:
1295
+ sent1.append(example['question1'])
1296
+ sent2.append(example['question2'])
1297
+ true_score.append(example['label'])
1298
+
1299
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['qqp']].tolist()
1300
+
1301
+ self.print_result_table({
1302
+ 'Dataset_name': 'qqp',
1303
+ 'F1': self.get_f1(true_score, pred_score),
1304
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1305
+ 'AUC': [roc_auc_score(true_score, pred_score)]
1306
+ })
1307
+
1308
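+ # WMT metrics tasks: score (reference, candidate) pairs and report Pearson/Spearman/Kendall correlations with human scores for each language pair.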
+ def evaluate_wmt17(self):
1309
+ lang_pair = list(set([each['lang'] for each in self.dataset['wmt17']]))
1310
+
1311
+ for each_lang_pair in lang_pair:
1312
+ true_score = []
1313
+ premise = []
1314
+ hypothesis = []
1315
+ for example in self.dataset['wmt17']:
1316
+ if example['lang'] != each_lang_pair:
1317
+ continue
1318
+ premise.append(example['reference'])
1319
+ hypothesis.append(example['candidate'])
1320
+ true_score.append(example['score'])
1321
+ pred_score = self.align_func(premise, hypothesis)[ALL_TASKS['wmt17']].tolist()
1322
+
1323
+ self.print_result_table({
1324
+ 'Dataset_name': f'wmt17-{each_lang_pair}',
1325
+ 'Pearson': self.get_pearson(true_score, pred_score),
1326
+ 'Spearman': self.get_spearman(true_score, pred_score),
1327
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1328
+ })
1329
+
1330
+ def evaluate_wmt18(self):
1331
+ lang_pair = list(set([each['lang'] for each in self.dataset['wmt18']]))
1332
+
1333
+ for each_lang_pair in lang_pair:
1334
+ true_score = []
1335
+ premise = []
1336
+ hypothesis = []
1337
+ for example in self.dataset['wmt18']:
1338
+ if example['lang'] != each_lang_pair:
1339
+ continue
1340
+ premise.append(example['reference'])
1341
+ hypothesis.append(example['candidate'])
1342
+ true_score.append(example['score'])
1343
+ pred_score = self.align_func(premise, hypothesis)[ALL_TASKS['wmt18']].tolist()
1344
+
1345
+ self.print_result_table({
1346
+ 'Dataset_name': f'wmt18-{each_lang_pair}',
1347
+ 'Pearson': self.get_pearson(true_score, pred_score),
1348
+ 'Spearman': self.get_spearman(true_score, pred_score),
1349
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1350
+ })
1351
+ def evaluate_wmt19(self):
1352
+ lang_pair = list(set([each['lang'] for each in self.dataset['wmt19']]))
1353
+
1354
+ for each_lang_pair in lang_pair:
1355
+ true_score = []
1356
+ premise = []
1357
+ hypothesis = []
1358
+ for example in self.dataset['wmt19']:
1359
+ if example['lang'] != each_lang_pair:
1360
+ continue
1361
+ premise.append(example['reference'])
1362
+ hypothesis.append(example['candidate'])
1363
+ true_score.append(example['score'])
1364
+ pred_score = self.align_func(premise, hypothesis)[ALL_TASKS['wmt19']].tolist()
1365
+
1366
+ self.print_result_table({
1367
+ 'Dataset_name': f'wmt19-{each_lang_pair}',
1368
+ 'Pearson': self.get_pearson(true_score, pred_score),
1369
+ 'Spearman': self.get_spearman(true_score, pred_score),
1370
+ 'Kendall': self.get_kendalltau(true_score, pred_score)
1371
+ })
1372
+
1373
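+ # Zero-shot classification (SST-2, CR, Subj, IMDB, CoLA): pair each text with a fixed hypothesis template (e.g. 'It was great.') and use the alignment score as the prediction.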
+ def evaluate_sst2(self):
1374
+ true_score = []
1375
+ sent1 = []
1376
+ sent2 = []
1377
+ for example in self.dataset['sst2']:
1378
+ sent1.append(example['text'])
1379
+ sent2.append('It was great.')
1380
+ true_score.append(int(example['label_text'] == 'positive'))
1381
+
1382
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['sst2']].tolist()
1383
+
1384
+ self.print_result_table({
1385
+ 'Dataset_name': 'sst2',
1386
+ 'F1': self.get_f1(true_score, pred_score),
1387
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1388
+ 'AUC': roc_auc_score(true_score, pred_score)
1389
+ })
1390
+
1391
+ def evaluate_cr(self):
1392
+ true_score = []
1393
+ sent1 = []
1394
+ sent2 = []
1395
+ for example in self.dataset['cr']:
1396
+ sent1.append(example['text'])
1397
+ sent2.append('It was great.')
1398
+ true_score.append(int(example['label_text'] == 'positive'))
1399
+
1400
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['cr']].tolist()
1401
+
1402
+ self.print_result_table({
1403
+ 'Dataset_name': 'cr',
1404
+ 'F1': self.get_f1(true_score, pred_score),
1405
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1406
+ 'AUC': roc_auc_score(true_score, pred_score)
1407
+ })
1408
+
1409
+ def evaluate_subj(self):
1410
+ true_score = []
1411
+ sent1 = []
1412
+ sent2 = []
1413
+ for example in self.dataset['subj']:
1414
+ sent1.append(example['text'])
1415
+ sent2.append('It was objective.')
1416
+ true_score.append(int(example['label_text'] == 'objective'))
1417
+
1418
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['subj']].tolist()
1419
+
1420
+ self.print_result_table({
1421
+ 'Dataset_name': 'subj',
1422
+ 'F1': self.get_f1(true_score, pred_score),
1423
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1424
+ 'AUC': roc_auc_score(true_score, pred_score)
1425
+ })
1426
+
1427
+ def evaluate_imdb(self):
1428
+ true_score = []
1429
+ sent1 = []
1430
+ sent2 = []
1431
+ for example in self.dataset['imdb']:
1432
+ sent1.append(example['text'])
1433
+ sent2.append('It was great.')
1434
+ true_score.append(int(example['label_text'] == 'positive'))
1435
+
1436
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['imdb']].tolist()
1437
+
1438
+ self.print_result_table({
1439
+ 'Dataset_name': 'imdb',
1440
+ 'F1': self.get_f1(true_score, pred_score),
1441
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1442
+ 'AUC': roc_auc_score(true_score, pred_score)
1443
+ })
1444
+
1445
+ def evaluate_imdb_knn(self):
1446
+ true_score = []
1447
+ sent1 = []
1448
+ sent2 = []
1449
+ for example in self.dataset['imdb']:
1450
+ sent1.append(example['text'])
1451
+ sent2.append('It was great.')
1452
+ true_score.append(int(example['label_text'] == 'positive'))
1453
+
1454
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['imdb']].tolist()
1455
+
1456
+ self.print_result_table({
1457
+ 'Dataset_name': 'imdb_knn',
1458
+ 'F1': self.get_f1(true_score, pred_score),
1459
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1460
+ 'AUC': roc_auc_score(true_score, pred_score)
1461
+ })
1462
+
1463
+ def evaluate_cola(self):
1464
+ true_score = []
1465
+ sent1 = []
1466
+ sent2 = []
1467
+ for example in self.dataset['cola']:
1468
+ sent1.append(example['sentence'])
1469
+ sent2.append('It was correct.')
1470
+ true_score.append(example['label'])
1471
+
1472
+ pred_score = self.align_func(sent1, sent2)[ALL_TASKS['cola']].tolist()
1473
+
1474
+ self.print_result_table({
1475
+ 'Dataset_name': 'cola',
1476
+ 'F1': self.get_f1(true_score, pred_score),
1477
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1478
+ 'AUC': roc_auc_score(true_score, pred_score)
1479
+ })
1480
+
1481
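+ # Multi-class zero-shot (Yelp, AG News, TREC): score the text against one template per class and predict the class whose template aligns best (argmax over templates).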
+ def evaluate_yelp_efl(self):
1482
+ sent = []
1483
+ label = []
1484
+ for example in self.dataset['yelp_efl']:
1485
+ sent.append(example['text'])
1486
+ label.append(example['label'])
1487
+ templates = [
1488
+ 'It was terrible.',
1489
+ 'It was bad.',
1490
+ 'It was ok.',
1491
+ 'It was good.',
1492
+ 'It was great.',
1493
+ ]
1494
+ template_lists = [[template] * len(sent) for template in templates]
1495
+ predictions = [
1496
+ self.align_func(sent, template_list)[ALL_TASKS['yelp_efl']]
1497
+ for template_list in template_lists
1498
+ ]
1499
+
1500
+ pred_label = torch.argmax(torch.stack(predictions), dim=0).tolist()
1501
+
1502
+ self.print_result_table({
1503
+ 'Dataset_name': 'yelp_efl',
1504
+ 'Accuracy': [accuracy_score(label, pred_label)]
1505
+ })
1506
+
1507
+ def evaluate_ag_news(self):
1508
+ sent = []
1509
+ label = []
1510
+ for example in self.dataset['ag_news']:
1511
+ sent.append(example['text'])
1512
+ label.append(example['label'])
1513
+ templates = [
1514
+ 'It is world news.',
1515
+ 'It is sports news.',
1516
+ 'It is business news.',
1517
+ 'It is science news.',
1518
+ ]
1519
+ template_lists = [[template] * len(sent) for template in templates]
1520
+ predictions = [
1521
+ self.align_func(sent, template_list)[ALL_TASKS['ag_news']]
1522
+ for template_list in template_lists
1523
+ ]
1524
+
1525
+ pred_label = torch.argmax(torch.stack(predictions), dim=0).tolist()
1526
+
1527
+ self.print_result_table({
1528
+ 'Dataset_name': 'ag_news',
1529
+ 'Accuracy': [accuracy_score(label, pred_label)]
1530
+ })
1531
+
1532
+ def evaluate_trec(self):
1533
+ sent = []
1534
+ label = []
1535
+ for example in self.dataset['trec']:
1536
+ sent.append(example['text'])
1537
+ label.append(example['label_coarse'])
1538
+ templates = [
1539
+ 'It is description.',
1540
+ 'It is entity.',
1541
+ 'It is expression.',
1542
+ 'It is human.',
1543
+ 'It is number.',
1544
+ 'It is location.',
1545
+ ]
1546
+ template_lists = [[template] * len(sent) for template in templates]
1547
+ predictions = [
1548
+ self.align_func(sent, template_list)[ALL_TASKS['trec']]
1549
+ for template_list in template_lists
1550
+ ]
1551
+
1552
+ pred_label = torch.argmax(torch.stack(predictions), dim=0).tolist()
1553
+
1554
+ self.print_result_table({
1555
+ 'Dataset_name': 'trec',
1556
+ 'Accuracy': [accuracy_score(label, pred_label)]
1557
+ })
1558
+
1559
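+ # TRUE benchmark helper: score (grounding, cleaned claim) pairs, take index 1 of align_func's output as the consistency score, and report F1, accuracy, and AUC.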
+ def true_task_helper(self, dataset_name):
1560
+ sent1 = []
1561
+ sent2 = []
1562
+ true_score = []
1563
+ for i in range(len(self.dataset[dataset_name])):
1564
+ context = self.dataset[dataset_name].iloc[i]['grounding']
1565
+ claim = self.dataset[dataset_name].iloc[i]['generated_text']
1566
+ sent1.append(context)
1567
+ sent2.append(self.clean_text(context, claim))
1568
+ true_score.append(self.dataset[dataset_name].iloc[i]['label'])
1569
+
1570
+ pred_score = self.align_func(sent1, sent2)[1].tolist()
1571
+ roc_auc = roc_auc_score(true_score, pred_score)
1572
+
1573
+ self.print_result_table({
1574
+ 'Dataset_name': dataset_name,
1575
+ 'F1': self.get_f1(true_score, pred_score),
1576
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1577
+ 'AUC': [roc_auc]
1578
+ })
1579
+
1580
+ def evaluate_true_begin(self):
1581
+ dataset_name = 'true_begin'
1582
+ self.true_task_helper(dataset_name)
1583
+
1584
+
1585
+ def evaluate_true_dialfact(self):
1586
+ dataset_name = 'true_dialfact'
1587
+ self.true_task_helper(dataset_name)
1588
+
1589
+ def evaluate_true_fever(self):
1590
+ dataset_name = 'true_fever'
1591
+ self.true_task_helper(dataset_name)
1592
+
1593
+ def evaluate_true_frank(self):
1594
+ dataset_name = 'true_frank'
1595
+ self.true_task_helper(dataset_name)
1596
+
1597
+ def evaluate_true_mnbm(self):
1598
+ dataset_name = 'true_mnbm'
1599
+ self.true_task_helper(dataset_name)
1600
+
1601
+ def evaluate_true_paws(self):
1602
+ dataset_name = 'true_paws'
1603
+ self.true_task_helper(dataset_name)
1604
+
1605
+ def evaluate_true_q2(self):
1606
+ dataset_name = 'true_q2'
1607
+ self.true_task_helper(dataset_name)
1608
+
1609
+ def evaluate_true_qags_cnndm(self):
1610
+ dataset_name = 'true_qags_cnndm'
1611
+ self.true_task_helper(dataset_name)
1612
+
1613
+ def evaluate_true_qags_xsum(self):
1614
+ dataset_name = 'true_qags_xsum'
1615
+ self.true_task_helper(dataset_name)
1616
+
1617
+ def evaluate_true_summeval(self):
1618
+ dataset_name = 'true_summeval'
1619
+ self.true_task_helper(dataset_name)
1620
+
1621
+ def evaluate_true_vitc(self):
1622
+ dataset_name = 'true_vitc'
1623
+ self.true_task_helper(dataset_name)
1624
+
1625
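+ # SummaC: sweep thresholds from 0 to 1 in steps of 0.001 on the validation split and keep the one that maximizes balanced accuracy.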
+ def get_summac_thres(self, dataset_name):
1626
+ sent1 = []
1627
+ sent2 = []
1628
+ true_score = []
1629
+ for example in self.summac_validation_set[dataset_name]:
1630
+ sent1.append(example['document'])
1631
+ sent2.append(self.clean_text(example['document'], example['claim']))
1632
+ true_score.append(example['label'])
1633
+
1634
+ pred_score = self.align_func(sent1, sent2)[1].tolist()
1635
+
1636
+ thres_result = []
1637
+ for i in range(1001):
1638
+ thres = i / 1000
1639
+ thres_result.append((thres, balanced_accuracy_score(true_score, [p>thres for p in pred_score])))
1640
+
1641
+ best_thres = sorted(thres_result, key=lambda x: x[1], reverse=True)[0]
1642
+ print(f"best thres for {dataset_name} is {best_thres[0]} @ {best_thres[1]}")
1643
+
1644
+ return best_thres[0]
1645
+
1646
+ def summac_task_helper(self, dataset_name):
1647
+ sent1 = []
1648
+ sent2 = []
1649
+ true_score = []
1650
+ for example in self.dataset[dataset_name]:
1651
+ sent1.append(example['document'])
1652
+ sent2.append(self.clean_text(example['document'], example['claim']))
1653
+ true_score.append(example['label'])
1654
+
1655
+ pred_score = self.align_func(sent1, sent2)[1].tolist()
1656
+ roc_auc = roc_auc_score(true_score, pred_score)
1657
+
1658
+ balanced_acc_thres = self.get_summac_thres(dataset_name)
1659
+
1660
+ self.print_result_table({
1661
+ 'Dataset_name': dataset_name,
1662
+ 'F1': self.get_f1(true_score, pred_score),
1663
+ 'Accuracy': self.get_accuracy(true_score, pred_score),
1664
+ 'BalancedAcc': self.get_balanced_accuracy(true_score, pred_score, thres=balanced_acc_thres),
1665
+ 'threshold': balanced_acc_thres,
1666
+ 'AUC': [roc_auc]
1667
+ })
1668
+
1669
+ def evaluate_summac_cogensumm(self):
1670
+ dataset_name = 'summac_cogensumm'
1671
+ self.summac_task_helper(dataset_name)
1672
+
1673
+ def evaluate_summac_xsumfaith(self):
1674
+ dataset_name = 'summac_xsumfaith'
1675
+ self.summac_task_helper(dataset_name)
1676
+
1677
+ def evaluate_summac_polytope(self):
1678
+ dataset_name = 'summac_polytope'
1679
+ self.summac_task_helper(dataset_name)
1680
+
1681
+ def evaluate_summac_factcc(self):
1682
+ dataset_name = 'summac_factcc'
1683
+ self.summac_task_helper(dataset_name)
1684
+
1685
+ def evaluate_summac_summeval(self):
1686
+ dataset_name = 'summac_summeval'
1687
+ self.summac_task_helper(dataset_name)
1688
+
1689
+ def evaluate_summac_frank(self):
1690
+ dataset_name = 'summac_frank'
1691
+ self.summac_task_helper(dataset_name)
1692
+
1693
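+ # BEIR: retrieve candidates with BM25 (Elasticsearch), rerank the top 100 with align_func, and report NDCG, MAP, recall, and precision.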
+ def evaluate_beir(self):
1694
+ from beir import util, LoggingHandler
1695
+ from beir.datasets.data_loader import GenericDataLoader
1696
+ from beir.retrieval.evaluation import EvaluateRetrieval
1697
+ from beir.retrieval.search.lexical import BM25Search as BM25
1698
+ from beir.reranking.models import CrossEncoder
1699
+ from beir.reranking import Rerank
1700
+
1701
+ import pathlib, os
1702
+ import logging
1703
+ import random
1704
+
1705
+ #### Just some code to print debug information to stdout
1706
+ logging.basicConfig(format='%(asctime)s - %(message)s',
1707
+ datefmt='%Y-%m-%d %H:%M:%S',
1708
+ level=logging.INFO,
1709
+ handlers=[LoggingHandler()])
1710
+ #### /print debug information to stdout
1711
+
1712
+ #### Download trec-covid.zip dataset and unzip the dataset
1713
+ for beir_dataset_name in ['msmarco', 'trec-covid', 'nfcorpus', 'nq', 'hotpotqa', 'fiqa',
1714
+ 'arguana', 'webis-touche2020', 'cqadupstack', 'quora',
1715
+ 'dbpedia-entity', 'scidocs', 'fever', 'climate-fever', 'scifact']:
1716
+ # for beir_dataset_name in ['fever']:
1717
+ url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(beir_dataset_name)
1718
+ # out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
1719
+ out_dir = f"./data/eval/beir/{beir_dataset_name}/"
1720
+ data_path = util.download_and_unzip(url, out_dir)
1721
+
1722
+ #### Provide the data path where trec-covid has been downloaded and unzipped to the data loader
1723
+ # data folder would contain these files:
1724
+ # (1) trec-covid/corpus.jsonl (format: jsonlines)
1725
+ # (2) trec-covid/queries.jsonl (format: jsonlines)
1726
+ # (3) trec-covid/qrels/test.tsv (format: tsv ("\t"))
1727
+
1728
+ corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
1729
+
1730
+ #########################################
1731
+ #### (1) RETRIEVE Top-100 docs using BM25
1732
+ #########################################
1733
+
1734
+ #### Provide parameters for Elasticsearch
1735
+ # print(corpus)
1736
+ hostname = "localhost" #localhost
1737
+ index_name = beir_dataset_name # trec-covid
1738
+ initialize = True # False
1739
+
1740
+ model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
1741
+ retriever = EvaluateRetrieval(model, k_values=[1,3,5,10,100,1000])
1742
+
1743
+ #### Retrieve dense results (format of results is identical to qrels)
1744
+ results = retriever.retrieve(corpus, queries)
1745
+
1746
+ # Rerank top-100 results using the reranker provided
1747
+ reranker = Rerank(self.align_func)
1748
+ rerank_results = reranker.rerank(corpus, queries, results, top_k=100)
1749
+
1750
+ #### Evaluate your retrieval using NDCG@k, MAP@K ...
1751
+ ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, rerank_results, retriever.k_values)
1752
+
1753
+ self.print_result_table({
1754
+ 'Dataset_name': beir_dataset_name,
1755
+ 'ndcg': ndcg,
1756
+ 'map': _map,
1757
+ 'recall': recall,
1758
+ 'precision': precision
1759
+ })
1760
+ def evaluate_xxx(self):
1761
+ pass
1762
+
1763
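+ # Runs the full Evaluator over a list of checkpoint configs; results for each checkpoint are saved under its result_save_path.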
+ class evaluateMultiCheckpoints:
1764
+ def __init__(self, config, device='cuda:0') -> None:
1765
+ sample_checkpoint = {
1766
+ 'backbone': 'roberta-base',
1767
+ 'task_name': 'align-wo-finetune | align-finetune | roberta-finetune-baseline | nli-wo-finetune | nli-finetune',
1768
+ 'path': 'some path',
1769
+ 'result_save_path': 'some path'
1770
+ }
1771
+ self.config = config  # a list of checkpoint config dicts (see sample_checkpoint above)
1772
+ self.device = device
1773
+
1774
+ self.tasks = [
1775
+ 'summeval', 'qags_xsum', 'qags_cnndm', 'persona_chat', 'topical_chat',
1776
+ 'mnli_mismatched', 'mnli_matched',
1777
+ 'sick', 'yelp', 'stsb',
1778
+ 'anli_1','anli_2', 'anli_3', 'snli', 'vitaminc',
1779
+ 'mrpc', 'paws', 'sem_eval', 'paws_qqp', 'qqp',
1780
+ 'newsroom', 'rank19', 'bagel', 'race_m', 'race_h'
1781
+ ]
1782
+
1783
+ def experimentForSlide1216(self):
1784
+ for ckpt in self.config:
1785
+ self.evaluateOneCheckpoint(ckpt)
1786
+ def evaluateOneCheckpoint(self, ckpt):
1787
+ model_name = ckpt['path'].split('/')[-1].split('.ckpt')[0]
1788
+ infer = Inferencer(ckpt_path=ckpt['path'],
1789
+ model=ckpt['backbone'], batch_size=32, device=self.device)
1790
+ evaluator = Evaluator(eval_tasks=self.tasks, align_func=infer.inference, save_all_tables=True)
1791
+
1792
+ evaluator.result_save_name = f"{ckpt['result_save_path']}{model_name}"
1793
+ evaluator.evaluate()
alignscore/generate_training_data.py ADDED
@@ -0,0 +1,1519 @@
1
+ from logging import error
2
+ from datasets import load_dataset
3
+ import transformers
4
+ from random import sample
5
+ import random
6
+ import torch
7
+ import json
8
+ from tqdm import tqdm
9
+ from nltk.translate.bleu_score import sentence_bleu
10
+ import pandas as pd
11
+ import re
12
+
13
+
14
+ '''
15
+ data format
16
+ {text_a, text_b, label: None or 0/1}
17
+ '''
18
+ DATASET_HUGGINGFACE = {
19
+ 'cnndm': ['cnn_dailymail', '3.0.0', 'train'],
20
+ 'mnli': ['multi_nli', 'default', 'train'],
21
+ 'squad': ['squad', 'plain_text', 'train'],
22
+ 'squad_v2': ['squad_v2', 'squad_v2', 'train'],
23
+ 'paws': ['paws', 'labeled_final', 'train'],
24
+ 'vitaminc': ['tals/vitaminc', 'v1.0', 'train'],
25
+ 'xsum': ['xsum', 'default', 'train'],
26
+ 'stsb': ['glue', 'stsb', 'train'],
27
+ 'sick': ['sick', 'default', 'train'],
28
+ 'race': ['race', 'all', 'train'],
29
+ 'race_val': ['race', 'all', 'validation'],
30
+ 'anli_r1': ['anli', 'plain_text', 'train_r1'],
31
+ 'anli_r2': ['anli', 'plain_text', 'train_r2'],
32
+ 'anli_r3': ['anli', 'plain_text', 'train_r3'],
33
+ 'snli': ['snli', 'plain_text', 'train'],
34
+ 'wikihow': ['wikihow', 'all', 'train'],
35
+ 'mrpc': ['glue', 'mrpc', 'train'],
36
+ 'msmarco': ['ms_marco', 'v2.1', 'train'],
37
+ 'mrpc_val': ['glue', 'mrpc', 'validation'],
38
+ 'paws_val': ['paws', 'labeled_final', 'validation'],
39
+ 'paws_unlabeled': ['paws', 'unlabeled_final', 'train'],
40
+ 'qqp': ['glue', 'qqp', 'train'],
41
+ 'qqp_val': ['glue', 'qqp', 'validation'],
42
+ 'squad_v2_new': ['squad_v2', 'squad_v2', 'train'],
43
+ 'adversarial_qa': ['adversarial_qa', 'adversarialQA', 'train'],
44
+ 'drop': ['drop', 'train'],
45
+ 'duorc_self': ['duorc', 'SelfRC', 'train'],
46
+ 'duorc_paraphrase': ['duorc', 'ParaphraseRC', 'train'],
47
+ 'quoref': ['quoref', 'train'],
48
+ 'hotpot_qa_distractor': ['hotpot_qa', 'distractor', 'train'],
49
+ 'hotpot_qa_fullwiki': ['hotpot_qa', 'fullwiki', 'train'],
50
+ 'ropes': ['ropes', 'train'],
51
+ 'boolq': ['boolq', 'train'],
52
+ 'eraser_multi_rc': ['eraser_multi_rc', 'train'],
53
+ 'quail': ['quail', 'train'],
54
+ 'sciq': ['sciq', 'train'],
55
+ 'strategy_qa': ['metaeval/strategy-qa', 'train'],
56
+ 'gap': ['gap', 'train'],
57
+ }
58
+
59
+ DATASET_CONFIG = {
60
+ 'cnndm': {'task': 'summarization', 'text_a': 'article', 'text_b': 'highlights', 'label': None, 'huggingface': True},
61
+ 'mnli': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
62
+ 'nli_fever': {'task': 'fact_checking', 'text_a': 'context', 'text_b': 'query', 'label': 'label','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/nli_fever/train_fitems.jsonl' },
63
+ 'doc_nli': {'task': 'bin_nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/DocNLI_dataset/train.json' },
64
+ 'squad': {'task': 'extractive_qa', 'text_a': 'context', 'text_b': ['question', 'answers'], 'label': None, 'huggingface': True},
65
+ 'squad_v2': {'task': 'qa', 'text_a': 'context', 'text_b': ['question', 'answers'], 'label': None, 'huggingface': True},
66
+ 'paws': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
67
+ 'vitaminc': {'task': 'fact_checking', 'text_a': 'evidence', 'text_b': 'claim', 'label': 'label', 'huggingface': True},
68
+ 'xsum': {'task': 'summarization', 'text_a': 'document', 'text_b': 'summary', 'label': None, 'huggingface': True, 'cliff_path': 'data/model_generated_data/cliff_summ/xsum_train.jsonl'},
69
+ 'stsb': {'task': 'sts', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
70
+ 'sick': {'task': 'sts', 'text_a': 'sentence_A', 'text_b': 'sentence_B', 'label': 'relatedness_score', 'huggingface': True},
71
+ 'race': {'task': 'qa', 'text_a': 'article', 'text_b': ['question', 'options'], 'label': 'answer', 'huggingface': True},
72
+ 'race_val': {'task': 'qa', 'text_a': 'article', 'text_b': ['question', 'options'], 'label': 'answer', 'huggingface': True},
73
+ 'anli_r1': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
74
+ 'anli_r2': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
75
+ 'anli_r3': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
76
+ 'snli': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
77
+ 'wikihow': {'task': 'summarization', 'text_a': 'text', 'text_b': 'headline', 'label': None, 'huggingface': False, 'using_hf_api': True, 'data_dir': 'data/wikihow_raw'},
78
+ 'mrpc': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label','huggingface': True},
79
+ 'mrpc_val': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label','huggingface': True},
80
+ 'paws_val': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
81
+ 'paws_unlabeled': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
82
+ 'msmarco': {'task': 'ir', 'text_a': 'passages', 'text_b': ['query', 'answers'], 'label': None,'huggingface': True},
83
+ 'paws_qqp': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': None,'huggingface': False, 'using_hf_api': False, 'using_pandas': True, 'data_path':'paws_qqp/output/train.tsv' },
84
+ 'wiki103': {'task': 'paraphrase', 'text_a': 'original_sent', 'text_b': 'paraphrase', 'label': None,'huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json': True, 'data_path':'data/model_generated_data/backtranslation/wiki103_single_sent_backtranslation.json'},
85
+ 'qqp': {'task': 'paraphrase', 'text_a':'question1', 'text_b':'question2', 'label': 'label', 'huggingface': True},
86
+ 'qqp_val': {'task': 'paraphrase', 'text_a':'question1', 'text_b':'question2', 'label': 'label', 'huggingface': True},
87
+ 'wmt17xxx': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': True, 'data_path':'data/wmt/wmt17/2017-da.csv' },
88
+ 'wmt15': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt15_eval.jsonl' },
89
+ 'wmt16': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt16_eval.jsonl' },
90
+ 'wmt17': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt17_eval.jsonl' },
91
+ 'wmt18': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt18_eval.jsonl' },
92
+ 'wmt19': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt19_eval.jsonl' },
93
+ 'squad_v2_new': {'task': 'qa', 'huggingface': True},
94
+ 'adversarial_qa': {'task': 'qa', 'huggingface': True},
95
+ 'drop': {'task': 'qa', 'huggingface': True},
96
+ 'duorc_self': {'task': 'qa', 'huggingface': True},
97
+ 'duorc_paraphrase': {'task': 'qa', 'huggingface': True},
98
+ 'quoref': {'task': 'qa', 'huggingface': True},
99
+ 'hotpot_qa_distractor': {'task': 'qa', 'huggingface': True},
100
+ 'hotpot_qa_fullwiki': {'task': 'qa', 'huggingface': True},
101
+ 'newsqa': {'task': 'qa', 'using_json': True, 'raw_json': True, 'data_path': 'data/newsqa_raw/combined-newsqa-data-v1.json'},
102
+ 'ropes': {'task': 'qa', 'huggingface': True},
103
+ 'boolq': {'task': 'qa', 'huggingface': True},
104
+ 'eraser_multi_rc': {'task': 'qa', 'huggingface': True},
105
+ 'quail': {'task': 'qa', 'huggingface': True},
106
+ 'sciq': {'task': 'qa', 'huggingface': True},
107
+ 'strategy_qa': {'task': 'qa', 'huggingface': True},
108
+ 'gap': {'task': 'coreference', 'huggingface': True},
109
+ }
110
+
111
+
112
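+ # Converts (question, answer) pairs into declarative sentences with a BART-based QA2D model.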
+ class QA2D():
113
+ def __init__(self, batch_size=32, device='cuda', verbose=True) -> None:
114
+ from transformers import BartTokenizer, BartForConditionalGeneration
115
+ self.tokenizer = BartTokenizer.from_pretrained("MarkS/bart-base-qa2d")
116
+ self.model = BartForConditionalGeneration.from_pretrained("MarkS/bart-base-qa2d").to(device)
117
+ self.batch_size = batch_size
118
+ self.device=device
119
+ self.verbose = verbose
120
+
121
+ def generate(self, questions: list, answers: list):
122
+ assert len(questions) == len(answers)
123
+ qa_list = []
124
+ for q, a in zip(questions, answers):
125
+ qa_list.append(f"question: {q} answer: {a}")
126
+ output = []
127
+ for qa_pairs in tqdm(
128
+ self.chunks(qa_list, self.batch_size),
129
+ desc="QA to Declarative",
130
+ total=int(len(qa_list)/self.batch_size),
131
+ disable=(not self.verbose)
132
+ ):
133
+ input_text = qa_pairs
134
+ input_token = self.tokenizer(
135
+ input_text, return_tensors='pt', padding=True, truncation=True).to(self.device)
136
+ dec_sents = self.model.generate(
137
+ input_token.input_ids, max_length=512)
138
+ result = self.tokenizer.batch_decode(
139
+ dec_sents, skip_special_tokens=True)
140
+ output.extend(result)
141
+
142
+ return output
143
+
144
+ def chunks(self, lst, n):
145
+ """Yield successive n-sized chunks from lst."""
146
+ for i in range(0, len(lst), n):
147
+ yield lst[i:i + n]
148
+
149
+
150
+ class QAnswering():
151
+ """
152
+ To answer not-answerable questions
153
+ """
154
+
155
+ def __init__(self, batch_size=32, device='cuda') -> None:
156
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
157
+ self.tokenizer = T5Tokenizer.from_pretrained(
158
+ "valhalla/t5-base-qa-qg-hl")
159
+ self.model = T5ForConditionalGeneration.from_pretrained(
160
+ "valhalla/t5-base-qa-qg-hl").to(device)
161
+ self.batch_size = batch_size
162
+ self.device = device
163
+
164
+ def generate(self, questions: list, contexts: list):
165
+ assert len(questions) == len(contexts)
166
+ answers = []
167
+ for qs, cs in tqdm(zip(self.chunks(questions, self.batch_size), self.chunks(contexts, self.batch_size)), desc="Generating Answers for not answerable", total=int(len(questions)/self.batch_size)):
168
+ qc_pairs = []
169
+ assert len(qs) == len(cs)
170
+ for one_q, one_c in zip(qs, cs):
171
+ qc_pairs.append(f"""question: {one_q} context: {one_c}""")
172
+ input_ids = self.tokenizer(
173
+ qc_pairs, padding=True, truncation=True, return_tensors='pt').to(self.device).input_ids
174
+ outputs = self.model.generate(input_ids, max_length=512)
175
+ answers.extend(self.tokenizer.batch_decode(
176
+ outputs, skip_special_tokens=True))
177
+
178
+ return answers
179
+
180
+ def chunks(self, lst, n):
181
+ """Yield successive n-sized chunks from lst."""
182
+ for i in range(0, len(lst), n):
183
+ yield lst[i:i + n]
184
+
185
+
186
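+ # Masks a fraction of tokens in each sentence and infills them with a DistilBERT masked LM to produce perturbed (hallucinated) variants.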
+ class MLMGeneratorWithPairedData():
187
+ def __init__(self, corpra: list, device='cuda', batch_size=8, mask_percent=0.25) -> None:
188
+ self.device = device
189
+ self.tokenizer = transformers.DistilBertTokenizer.from_pretrained(
190
+ "distilbert-base-uncased")
191
+ self.model = transformers.DistilBertForMaskedLM.from_pretrained(
192
+ "distilbert-base-uncased").to(self.device)
193
+ self.mask_percent = mask_percent
194
+ self.batch_size = batch_size
195
+
196
+ self.dataset = corpra # text needs to be noised
197
+
198
+ def chunks(self, lst, n):
199
+ """Yield successive n-sized chunks from lst."""
200
+ for i in range(0, len(lst), n):
201
+ yield lst[i:i + n]
202
+
203
+ def generate(self):
204
+ sents_output = []
205
+ for examples in tqdm(self.chunks(self.dataset, self.batch_size), total=int(len(self.dataset)/self.batch_size), desc="MLM Generating"):
206
+ sents_to_be_noised = [each for each in examples]
207
+ sents_noised = self.mlm_infiller(sents_to_be_noised)
208
+
209
+ sents_output.extend(sents_noised)
210
+
211
+ return sents_output
212
+
213
+ def mlm_infiller(self, batch):
214
+ """
215
+ input a batch of sentences, list
216
+ """
217
+ masked_batch = []
218
+ masked_batch_ids = []
219
+ for each_sent in batch:
220
+ sent_tokens = self.tokenizer.tokenize(each_sent)
221
+ sent_token_ids = self.tokenizer(each_sent)['input_ids']
222
+ mask_list = sample(list(range(len(sent_tokens))), int(
223
+ self.mask_percent * len(sent_tokens)))
224
+ sent_tokens = [
225
+ each if i not in mask_list else self.tokenizer.mask_token for i, each in enumerate(sent_tokens)]
226
+ masked_batch_ids.append(
227
+ [each if i-1 not in mask_list else self.tokenizer.mask_token_id for i, each in enumerate(sent_token_ids)])
228
+ masked_batch.append(' '.join(sent_tokens))
229
+
230
+ inputs = self.tokenizer(
231
+ masked_batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
232
+ with torch.no_grad():
233
+ logits = self.model(**inputs).logits
234
+
235
+ infill_tokens = []
236
+ for i in range(len(masked_batch)):
237
+ mask_token_index = (inputs.input_ids == self.tokenizer.mask_token_id)[
238
+ i].nonzero(as_tuple=True)[0]
239
+
240
+ predicted_token_id = logits[i, mask_token_index].argmax(axis=-1)
241
+ infill_tokens.append(predicted_token_id)
242
+
243
+ infilled_sent = []
244
+ for masked_sent_ids, infill_token in zip(masked_batch_ids, infill_tokens):
245
+ for infill_one_token in infill_token:
246
+ for i, each_id in enumerate(masked_sent_ids):
247
+ if each_id == self.tokenizer.mask_token_id:
248
+ masked_sent_ids[i] = infill_one_token
249
+ break
250
+ infilled_sent.append(self.tokenizer.decode(
251
+ masked_sent_ids, skip_special_tokens=True))
252
+
253
+ return infilled_sent
254
+
255
+
256
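+ # Extractive summarization via summa (TextRank); the ratio is increased until a non-empty summary is returned.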
+ class ExtractiveSummarizationGenerator():
257
+ def __init__(self) -> None:
258
+ pass
259
+
260
+ def generate(self, texts):
261
+ '''
262
+ texts: list of string
263
+ '''
264
+ from summa.summarizer import summarize
265
+
266
+ summaries = []
267
+ for text in tqdm(texts, desc="Extracting Summary"):
268
+ for prop in range(1, 20):
269
+ summ = summarize(text, ratio=prop/20.)
270
+ if len(summ) > 0:
271
+ break
272
+ summaries.append(summ)
273
+
274
+ return summaries
275
+
276
+
277
+ class DataGenerator():
278
+ def __init__(self, dataset_names) -> None:
279
+ self.dataset_names = dataset_names
280
+ self.datasets = dict()
281
+ self.t5_qa = None
282
+ self.t5_tokenizer = None
283
+
284
+ self.load_dataset_from_huggingface()
285
+
286
+ def load_dataset_from_huggingface(self):
287
+ for each_dataset in self.dataset_names:
288
+ if DATASET_CONFIG[each_dataset].get('huggingface'):
289
+ self.datasets[each_dataset] = load_dataset(
290
+ *DATASET_HUGGINGFACE[each_dataset][:-1])[DATASET_HUGGINGFACE[each_dataset][-1]]
291
+ elif DATASET_CONFIG[each_dataset].get('using_hf_api'):
292
+ self.datasets[each_dataset] = load_dataset(
293
+ *DATASET_HUGGINGFACE[each_dataset][:-1], data_dir=DATASET_CONFIG[each_dataset]['data_dir'])[DATASET_HUGGINGFACE[each_dataset][-1]]
294
+ elif DATASET_CONFIG[each_dataset].get('using_pandas'):
295
+ if DATASET_CONFIG[each_dataset]['data_path'].split('.')[-1] == 'tsv':
296
+ self.datasets[each_dataset] = pd.read_csv(
297
+ DATASET_CONFIG[each_dataset]['data_path'], sep='\t')
298
+ elif DATASET_CONFIG[each_dataset]['data_path'].split('.')[-1] == 'csv':
299
+ self.datasets[each_dataset] = pd.read_csv(
300
+ DATASET_CONFIG[each_dataset]['data_path'])
301
+ elif DATASET_CONFIG[each_dataset].get('using_json'):
302
+ self.datasets[each_dataset] = []
303
+ if DATASET_CONFIG[each_dataset].get('raw_json'):
304
+ with open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8') as f:
305
+ self.datasets[each_dataset] = json.load(f)
306
+ else:
307
+ try:
308
+ json_file = json.load(
309
+ open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8'))
310
+ for example in json_file:
311
+ self.datasets[each_dataset].append(example)
312
+ except:
313
+ with open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8') as f:
314
+ for example in f:
315
+ self.datasets[each_dataset].append(
316
+ json.loads(example))
317
+ else:
318
+ error('unable to locate raw dataset...')
319
+
320
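+ # SQuAD: positives pair the context with question + gold answer; negatives use RAKE keyphrases from the context whose BLEU against the gold answers falls below the threshold.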
+ def process_squad(self):
321
+ from rake_nltk import Rake
322
+ r = Rake()
323
+ topk = 5
324
+ threshold = 0.6
325
+
326
+ output = []
327
+ label = -1
328
+ for example in tqdm(self.datasets['squad'], desc=f'Constructing squad'):
329
+ text_a = example[DATASET_CONFIG['squad']['text_a']]
330
+ question = example[DATASET_CONFIG['squad']['text_b'][0]]
331
+ answer = example[DATASET_CONFIG['squad']
332
+ ['text_b'][1]]['text'] # a list
333
+ text_b = [question+' '+answer_ele for answer_ele in answer]
334
+ text_c = []
335
+
336
+ r.extract_keywords_from_text(text_a)
337
+ keywords_in_context = r.get_ranked_phrases()[:topk]
338
+ for each_keyword in keywords_in_context:
339
+ # then it is an incorrect answer
340
+ if sentence_bleu([answer_ele.lower().split() for answer_ele in answer], each_keyword.split(), weights=(0.33, 0.33, 0.33)) < threshold:
341
+ text_c.append(question+' '+each_keyword)
342
+
343
+ output.append({
344
+ 'text_a': text_a,
345
+ 'text_b': text_b,
346
+ 'text_c': text_c,
347
+ 'label': label
348
+ })
349
+
350
+ return output
351
+
352
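+ # SQuAD v2: answerable questions become positive declaratives (label 1); unanswerable questions get model-generated answers and become negative declaratives (label 0).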
+ def process_squad_v2(self):
353
+ # first collect answerable items
354
+ not_answerable_contexts = []
355
+ not_answerable_questions = []
356
+ not_answerable_answers = []
357
+
358
+ answerable_contexts = []
359
+ answerable_questions = []
360
+ answerable_answers = []
361
+
362
+ qa_generator = QAnswering(batch_size=32, device='cuda')
363
+ qa2d_generator = QA2D(batch_size=32, device='cuda')
364
+
365
+ for example in tqdm(self.datasets['squad_v2'], desc=f'Collecting (not)answerable examples'):
366
+ if len(example['answers']['text']) == 0:
367
+ not_answerable_contexts.append(example['context'])
368
+ not_answerable_questions.append(example['question'])
369
+ else:
370
+ answerable_contexts.append(example['context'])
371
+ answerable_questions.append(example['question'])
372
+ answerable_answers.append(example['answers']['text'][0])
373
+
374
+ not_answerable_answers = qa_generator.generate(
375
+ not_answerable_questions, not_answerable_contexts)
376
+ answerable_declarative_sents = qa2d_generator.generate(
377
+ answerable_questions, answerable_answers)
378
+ not_answerable_declarative_sents = qa2d_generator.generate(
379
+ not_answerable_questions, not_answerable_answers)
380
+
381
+ output = []
382
+ for i, dec_sent in enumerate(answerable_declarative_sents):
383
+ output.append({
384
+ 'text_a': answerable_contexts[i],
385
+ 'text_b': [dec_sent],
386
+ 'text_c': [],
387
+ 'label': 1
388
+ })
389
+
390
+ for i, dec_sent in enumerate(not_answerable_declarative_sents):
391
+ output.append({
392
+ 'text_a': not_answerable_contexts[i],
393
+ 'text_b': [dec_sent],
394
+ 'text_c': [],
395
+ 'label': 0
396
+ })
397
+
398
+ return output
399
+
400
+ def process_race(self):
401
+ qa2d_generator = QA2D(batch_size=32, device='cuda')
402
+ option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
403
+ output = []
404
+
405
+ correct_context = []
406
+ correct_question = []
407
+ correct_answer = []
408
+
409
+ wrong_context = []
410
+ wrong_question = []
411
+ wrong_answer = []
412
+
413
+ for example in tqdm(self.datasets['race'], desc=f'Constructing race'):
414
+ text_a = example[DATASET_CONFIG['race']['text_a']]
415
+ label = -1
416
+ question = example[DATASET_CONFIG['race']['text_b'][0]]
417
+ if "_" in question:
418
+ answer_id = option_dict[example[DATASET_CONFIG['race']['label']]]
419
+ for i, options in enumerate(example[DATASET_CONFIG['race']['text_b'][1]]):
420
+ if i == answer_id:
421
+ output.append({
422
+ 'text_a': text_a,
423
+ 'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
424
+ 'text_c': [],
425
+ 'label': 1
426
+ })
427
+ else:
428
+ output.append({
429
+ 'text_a': text_a,
430
+ 'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
431
+ 'text_c': [],
432
+ 'label': 0
433
+ })
434
+ else:
435
+ answer_id = option_dict[example[DATASET_CONFIG['race']['label']]]
436
+ for i, options in enumerate(example[DATASET_CONFIG['race']['text_b'][1]]):
437
+ if i == answer_id:
438
+ output.append({
439
+ 'text_a': text_a,
440
+ 'text_b': [question],
441
+ 'text_c': [options],
442
+ 'label': 1
443
+ })
444
+ else:
445
+ output.append({
446
+ 'text_a': text_a,
447
+ 'text_b': [question],
448
+ 'text_c': [options],
449
+ 'label': 0
450
+ })
451
+
452
+ return output
453
+
454
+ def process_race_val(self):
455
+ qa2d_generator = QA2D(batch_size=32, device='cuda')
456
+ option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
457
+ output = []
458
+
459
+ correct_context = []
460
+ correct_question = []
461
+ correct_answer = []
462
+
463
+ wrong_context = []
464
+ wrong_question = []
465
+ wrong_answer = []
466
+
467
+ for example in tqdm(self.datasets['race_val'], desc=f'Constructing race_val'):
468
+ text_a = example[DATASET_CONFIG['race_val']['text_a']]
469
+ label = -1
470
+ question = example[DATASET_CONFIG['race_val']['text_b'][0]]
471
+ if "_" in question:
472
+ answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]]
473
+ for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]):
474
+ if i == answer_id:
475
+ output.append({
476
+ 'text_a': text_a,
477
+ 'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
478
+ 'text_c': [],
479
+ 'label': 1
480
+ })
481
+ else:
482
+ output.append({
483
+ 'text_a': text_a,
484
+ 'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
485
+ 'text_c': [],
486
+ 'label': 0
487
+ })
488
+ else:
489
+ answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]]
490
+ for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]):
491
+ if i == answer_id:
492
+ correct_context.append(text_a)
493
+ correct_question.append(question)
494
+ correct_answer.append(options)
495
+ else:
496
+ wrong_context.append(text_a)
497
+ wrong_question.append(question)
498
+ wrong_answer.append(options)
499
+
500
+ correct_declarative = qa2d_generator.generate(
501
+ correct_question, correct_answer)
502
+ wrong_declarative = qa2d_generator.generate(
503
+ wrong_question, wrong_answer)
504
+ assert len(correct_context) == len(correct_declarative)
505
+ assert len(wrong_context) == len(wrong_declarative)
506
+ for context, dec in zip(correct_context, correct_declarative):
507
+ output.append({
508
+ 'text_a': context,
509
+ 'text_b': [dec],
510
+ 'text_c': [],
511
+ 'label': 1
512
+ })
513
+
514
+ for context, dec in zip(wrong_context, wrong_declarative):
515
+ output.append({
516
+ 'text_a': context,
517
+ 'text_b': [dec],
518
+ 'text_c': [],
519
+ 'label': 0
520
+ })
521
+
522
+ return output
523
+
524
+ def process_race_test(self):
525
+ option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
526
+ output = []
527
+ for example in tqdm(self.datasets['race_test'], desc=f'Constructing race_test'):
528
+ text_a = example[DATASET_CONFIG['race_test']['text_a']]
529
+ text_b = [] # pos
530
+ text_c = [] # neg
531
+ label = -1
532
+ question = example[DATASET_CONFIG['race_test']['text_b'][0]]
533
+ if "_" in question:
534
+ answer_id = option_dict[example[DATASET_CONFIG['race_test']['label']]]
535
+ for i, options in enumerate(example[DATASET_CONFIG['race_test']['text_b'][1]]):
536
+ if i == answer_id:
537
+ text_b.append(' '.join(question.replace(
538
+ "_", " "+options+" ").split()))
539
+ else:
540
+ text_c.append(' '.join(question.replace(
541
+ "_", " "+options+" ").split()))
542
+ else:
543
+ answer_id = option_dict[example[DATASET_CONFIG['race_test']['label']]]
544
+ for i, options in enumerate(example[DATASET_CONFIG['race_test']['text_b'][1]]):
545
+ if i == answer_id:
546
+ text_b.append(question+" "+options+" ")
547
+ else:
548
+ text_c.append(question+" "+options+" ")
549
+
550
+ output.append({
551
+ 'text_a': text_a,
552
+ 'text_b': text_b,
553
+ 'text_c': text_c,
554
+ 'label': label
555
+ })
556
+
557
+ return output
558
+
559
+ def process_xsum(self):
560
+ '''
561
+ text_a: raw_text
562
+ text_b: raw_summary + ***extractive summ*** removed
563
+ text_c: cliff xsum + DistilBERT from raw_text_b + ***DistilBERT from extractive summ text_b***
564
+ '''
565
+ output = []
566
+
567
+ gold_summary = [example[DATASET_CONFIG['xsum']['text_b']]
568
+ for example in self.datasets['xsum']]
569
+ ext_summarizer = ExtractiveSummarizationGenerator()
570
+ extracted_summ = ext_summarizer.generate(
571
+ [example[DATASET_CONFIG['xsum']['text_a']] for example in self.datasets['xsum']])
572
+
573
+ mlm_hallucinator = MLMGeneratorWithPairedData(
574
+ corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25)
575
+ gold_summary_hallucinated = mlm_hallucinator.generate()
576
+
577
+ mlm_hallucinator = MLMGeneratorWithPairedData(
578
+ corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25)
579
+ extracted_summ_hallucinated = mlm_hallucinator.generate()
580
+
581
+ assert len(self.datasets['xsum']) == len(gold_summary_hallucinated) and len(
582
+ self.datasets['xsum']) == len(extracted_summ_hallucinated)
583
+
584
+ for i, example in tqdm(enumerate(self.datasets['xsum']), desc="Constructing xsum", total=len(self.datasets['xsum'])):
585
+ text_a = example[DATASET_CONFIG['xsum']['text_a']]
586
+ text_b = [gold_summary[i], extracted_summ[i]]
587
+ text_c = [gold_summary_hallucinated[i],
588
+ extracted_summ_hallucinated[i]]
589
+ label = -1
590
+
591
+ output.append({
592
+ 'text_a': text_a,
593
+ 'text_b': text_b,
594
+ 'text_c': text_c,
595
+ 'label': label
596
+ })
597
+
598
+ return output
599
+
600
+ def process_cnndm(self):
601
+ '''
602
+ text_a: raw_text
603
+ text_b: raw_summary + ***extractive summ*** removed
604
+ text_c: DistilBERT from raw_text_b + ***DistilBERT from extractive summ text_b***
605
+ '''
606
+ # interpretation of fairseq-generate output: https://github.com/facebookresearch/fairseq/issues/3000
607
+ output = []
608
+
609
+ gold_summary = [example[DATASET_CONFIG['cnndm']['text_b']]
610
+ for example in self.datasets['cnndm']]
611
+ ext_summarizer = ExtractiveSummarizationGenerator()
612
+ extracted_summ = ext_summarizer.generate(
613
+ [example[DATASET_CONFIG['cnndm']['text_a']] for example in self.datasets['cnndm']])
614
+
615
+ mlm_hallucinator = MLMGeneratorWithPairedData(
616
+ corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25)
617
+ gold_summary_hallucinated = mlm_hallucinator.generate()
618
+
619
+ mlm_hallucinator = MLMGeneratorWithPairedData(
620
+ corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25)
621
+ extracted_summ_hallucinated = mlm_hallucinator.generate()
622
+
623
+ assert len(self.datasets['cnndm']) == len(gold_summary_hallucinated) and len(
624
+ self.datasets['cnndm']) == len(extracted_summ_hallucinated)
625
+
626
+ for i, example in tqdm(enumerate(self.datasets['cnndm']), desc="Constructing cnndm", total=len(self.datasets['cnndm'])):
627
+ text_a = example[DATASET_CONFIG['cnndm']['text_a']]
628
+ text_b = [gold_summary[i], extracted_summ[i]]
629
+ text_c = [gold_summary_hallucinated[i],
630
+ extracted_summ_hallucinated[i]]
631
+ label = -1
632
+
633
+ output.append({
634
+ 'text_a': text_a,
635
+ 'text_b': text_b,
636
+ 'text_c': text_c,
637
+ 'label': label
638
+ })
639
+
640
+ return output
641
+
642
+ def process_wikihow(self):
643
+ '''
644
+ text_a: raw_text
645
+ text_b: raw_summary + ***extractive summ*** removed
646
+ text_c: DistilBERT from raw_text_b + ***DistilBERT from extractive summ text_b***
647
+ '''
648
+ # interpretation of fairseq-generate output: https://github.com/facebookresearch/fairseq/issues/3000
649
+ output = []
650
+
651
+ gold_summary = [example[DATASET_CONFIG['wikihow']['text_b']]
652
+ for example in self.datasets['wikihow']]
653
+ ext_summarizer = ExtractiveSummarizationGenerator()
654
+ extracted_summ = ext_summarizer.generate(
655
+ [example[DATASET_CONFIG['wikihow']['text_a']] for example in self.datasets['wikihow']])
656
+
657
+ mlm_hallucinator = MLMGeneratorWithPairedData(
658
+ corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25)
659
+ gold_summary_hallucinated = mlm_hallucinator.generate()
660
+
661
+ mlm_hallucinator = MLMGeneratorWithPairedData(
662
+ corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25)
663
+ extracted_summ_hallucinated = mlm_hallucinator.generate()
664
+
665
+ assert len(self.datasets['wikihow']) == len(gold_summary_hallucinated) and len(
666
+ self.datasets['wikihow']) == len(extracted_summ_hallucinated)
667
+
668
+ for i, example in tqdm(enumerate(self.datasets['wikihow']), desc="Constructing wikihow", total=len(self.datasets['wikihow'])):
669
+ text_a = example[DATASET_CONFIG['wikihow']['text_a']]
670
+ text_b = [gold_summary[i], extracted_summ[i]]
671
+ text_c = [gold_summary_hallucinated[i],
672
+ extracted_summ_hallucinated[i]]
673
+ label = -1
674
+
675
+ output.append({
676
+ 'text_a': text_a,
677
+ 'text_b': text_b,
678
+ 'text_c': text_c,
679
+ 'label': label
680
+ })
681
+
682
+ return output
683
+
684
+ def process_wiki103(self):
685
+ output = []
686
+
687
+ paraphrases = [example[DATASET_CONFIG['wiki103']['text_b']]
688
+ for example in self.datasets['wiki103']]
689
+ mlm_hallucinator = MLMGeneratorWithPairedData(
690
+ corpra=paraphrases, device='cuda:3', batch_size=64, mask_percent=0.25)
691
+ paraphrase_hallucinated = mlm_hallucinator.generate()
692
+
693
+ assert len(self.datasets['wiki103']) == len(paraphrase_hallucinated)
694
+
695
+ for i, example in tqdm(enumerate(self.datasets['wiki103']), desc=f'Constructing wiki103'):
696
+ output.append({
697
+ 'text_a': example[DATASET_CONFIG['wiki103']['text_a']],
698
+ 'text_b': [example[DATASET_CONFIG['wiki103']['text_b']]],
699
+ 'text_c': [],
700
+ 'label': 1
701
+ })
702
+ output.append({
703
+ 'text_a': example[DATASET_CONFIG['wiki103']['text_a']],
704
+ 'text_b': [paraphrase_hallucinated[i]],
705
+ 'text_c': [],
706
+ 'label': 0
707
+ })
708
+
709
+ return output
710
+
711
+ def process_mnli(self):
712
+ output = []
713
+ for example in tqdm(self.datasets['mnli'], desc=f'Constructing mnli'):
714
+ text_a = example[DATASET_CONFIG['mnli']['text_a']]
715
+ text_b = [example[DATASET_CONFIG['mnli']['text_b']]]
716
+ text_c = []
717
+ label = example[DATASET_CONFIG['mnli']['label']]
718
+
719
+ output.append({
720
+ 'text_a': text_a,
721
+ 'text_b': text_b,
722
+ 'text_c': text_c,
723
+ 'label': label
724
+ })
725
+
726
+ return output
727
+
728
+ def process_nli_fever(self):
729
+ output = []
730
+ for example in tqdm(self.datasets['nli_fever'], desc=f'Constructing nli_fever'):
731
+ text_a = example[DATASET_CONFIG['nli_fever']['text_a']]
732
+ text_b = [example[DATASET_CONFIG['nli_fever']['text_b']]]
733
+ text_c = []
734
+ raw_label = example[DATASET_CONFIG['nli_fever']['label']]
735
+ if raw_label == 'SUPPORTS': # convert to nli style label
736
+ label = 0
737
+ elif raw_label == 'REFUTES':
738
+ label = 2
739
+ else:
740
+ label = 1
741
+
742
+ output.append({
743
+ 'text_a': text_a,
744
+ 'text_b': text_b,
745
+ 'text_c': text_c,
746
+ 'label': label
747
+ })
748
+
749
+ return output
750
+
751
+ def process_doc_nli(self):
752
+ output = []
753
+ for example in tqdm(self.datasets['doc_nli'], desc=f'Constructing doc_nli'):
754
+ text_a = example[DATASET_CONFIG['doc_nli']['text_a']]
755
+ text_b = [example[DATASET_CONFIG['doc_nli']['text_b']]]
756
+ text_c = []
757
+ raw_label = example[DATASET_CONFIG['doc_nli']['label']]
758
+ if raw_label == 'entailment': # convert to paraphrase style label
759
+ label = 1
760
+ else:
761
+ label = 0
762
+
763
+ output.append({
764
+ 'text_a': text_a,
765
+ 'text_b': text_b,
766
+ 'text_c': text_c,
767
+ 'label': label
768
+ })
769
+
770
+ return output
771
+
772
+ def process_anli_r1(self):
773
+ output = []
774
+ for example in tqdm(self.datasets['anli_r1'], desc=f'Constructing anli_r1'):
775
+ text_a = example[DATASET_CONFIG['anli_r1']['text_a']]
776
+ text_b = [example[DATASET_CONFIG['anli_r1']['text_b']]]
777
+ text_c = []
778
+ label = example[DATASET_CONFIG['anli_r1']['label']]
779
+
780
+ output.append({
781
+ 'text_a': text_a,
782
+ 'text_b': text_b,
783
+ 'text_c': text_c,
784
+ 'label': label
785
+ })
786
+
787
+ return output
788
+
789
+ def process_anli_r2(self):
790
+ output = []
791
+ for example in tqdm(self.datasets['anli_r2'], desc=f'Constructing anli_r2'):
792
+ text_a = example[DATASET_CONFIG['anli_r2']['text_a']]
793
+ text_b = [example[DATASET_CONFIG['anli_r2']['text_b']]]
794
+ text_c = []
795
+ label = example[DATASET_CONFIG['anli_r2']['label']]
796
+
797
+ output.append({
798
+ 'text_a': text_a,
799
+ 'text_b': text_b,
800
+ 'text_c': text_c,
801
+ 'label': label
802
+ })
803
+
804
+ return output
805
+
806
+ def process_anli_r3(self):
807
+ output = []
808
+ for example in tqdm(self.datasets['anli_r3'], desc=f'Constructing anli_r3'):
809
+ text_a = example[DATASET_CONFIG['anli_r3']['text_a']]
810
+ text_b = [example[DATASET_CONFIG['anli_r3']['text_b']]]
811
+ text_c = []
812
+ label = example[DATASET_CONFIG['anli_r3']['label']]
813
+
814
+ output.append({
815
+ 'text_a': text_a,
816
+ 'text_b': text_b,
817
+ 'text_c': text_c,
818
+ 'label': label
819
+ })
820
+
821
+ return output
822
+
823
+ def process_snli(self):
824
+ output = []
825
+ for example in tqdm(self.datasets['snli'], desc=f'Constructing snli'):
826
+ text_a = example[DATASET_CONFIG['snli']['text_a']]
827
+ text_b = [example[DATASET_CONFIG['snli']['text_b']]]
828
+ text_c = []
829
+ label = example[DATASET_CONFIG['snli']['label']]
830
+
831
+ output.append({
832
+ 'text_a': text_a,
833
+ 'text_b': text_b,
834
+ 'text_c': text_c,
835
+ 'label': label
836
+ })
837
+
838
+ return output
839
+
840
+ def process_paws(self):
841
+ output = []
842
+ for example in tqdm(self.datasets['paws'], desc=f'Constructing paws'):
843
+ text_a = example[DATASET_CONFIG['paws']['text_a']]
844
+ text_b = [example[DATASET_CONFIG['paws']['text_b']]]
845
+ text_c = []
846
+ label = example[DATASET_CONFIG['paws']['label']]
847
+
848
+ output.append({
849
+ 'text_a': text_a,
850
+ 'text_b': text_b,
851
+ 'text_c': text_c,
852
+ 'label': label
853
+ })
854
+
855
+ return output
856
+
857
+ def process_vitaminc(self):
858
+ output = []
859
+ for example in tqdm(self.datasets['vitaminc'], desc=f'Constructing vitaminc'):
860
+ text_a = example[DATASET_CONFIG['vitaminc']['text_a']]
861
+ text_b = [example[DATASET_CONFIG['vitaminc']['text_b']]]
862
+ text_c = []
863
+ raw_label = example[DATASET_CONFIG['vitaminc']['label']]
864
+ if raw_label == 'SUPPORTS': # convert to nli style label
865
+ label = 0
866
+ elif raw_label == 'REFUTES':
867
+ label = 2
868
+ else:
869
+ label = 1
870
+
871
+ output.append({
872
+ 'text_a': text_a,
873
+ 'text_b': text_b,
874
+ 'text_c': text_c,
875
+ 'label': label
876
+ })
877
+
878
+ return output
879
+
880
+ def process_stsb(self):
881
+ output = []
882
+ for example in tqdm(self.datasets['stsb'], desc=f'Constructing stsb'):
883
+ text_a = example[DATASET_CONFIG['stsb']['text_a']]
884
+ text_b = [example[DATASET_CONFIG['stsb']['text_b']]]
885
+ text_c = []
886
+ label = example[DATASET_CONFIG['stsb']['label']] / 5.0
887
+
888
+ output.append({
889
+ 'text_a': text_a,
890
+ 'text_b': text_b,
891
+ 'text_c': text_c,
892
+ 'label': label
893
+ })
894
+
895
+ return output
896
+
897
+ def process_sick(self):
898
+ output = []
899
+ for example in tqdm(self.datasets['sick'], desc=f'Constructing sick'):
900
+ text_a = example[DATASET_CONFIG['sick']['text_a']]
901
+ text_b = [example[DATASET_CONFIG['sick']['text_b']]]
902
+ text_c = []
903
+ label = example[DATASET_CONFIG['sick']['label']] / 5.0
904
+
905
+ output.append({
906
+ 'text_a': text_a,
907
+ 'text_b': text_b,
908
+ 'text_c': text_c,
909
+ 'label': label
910
+ })
911
+
912
+ return output
913
+
914
+ def process_mrpc(self):
915
+ output = []
916
+ for example in tqdm(self.datasets['mrpc'], desc=f'Constructing mrpc'):
917
+ text_a = example[DATASET_CONFIG['mrpc']['text_a']]
918
+ text_b = [example[DATASET_CONFIG['mrpc']['text_b']]]
919
+ text_c = []
920
+ label = example[DATASET_CONFIG['mrpc']['label']]
921
+
922
+ output.append({
923
+ 'text_a': text_a,
924
+ 'text_b': text_b,
925
+ 'text_c': text_c,
926
+ 'label': label
927
+ })
928
+
929
+ return output
930
+
931
+ def process_mrpc_val(self):
932
+ output = []
933
+ for example in tqdm(self.datasets['mrpc_val'], desc=f'Constructing mrpc_val'):
934
+ text_a = example[DATASET_CONFIG['mrpc_val']['text_a']]
935
+ text_b = [example[DATASET_CONFIG['mrpc_val']['text_b']]]
936
+ text_c = []
937
+ label = example[DATASET_CONFIG['mrpc_val']['label']]
938
+
939
+ output.append({
940
+ 'text_a': text_a,
941
+ 'text_b': text_b,
942
+ 'text_c': text_c,
943
+ 'label': label
944
+ })
945
+
946
+ return output
947
+
948
+ def process_paws_val(self):
949
+ output = []
950
+ for example in tqdm(self.datasets['paws_val'], desc=f'Constructing paws_val'):
951
+ text_a = example[DATASET_CONFIG['paws_val']['text_a']]
952
+ text_b = [example[DATASET_CONFIG['paws_val']['text_b']]]
953
+ text_c = []
954
+ label = example[DATASET_CONFIG['paws_val']['label']]
955
+
956
+ output.append({
957
+ 'text_a': text_a,
958
+ 'text_b': text_b,
959
+ 'text_c': text_c,
960
+ 'label': label
961
+ })
962
+
963
+ return output
964
+
965
+ def process_paws_unlabeled(self):
966
+ output = []
967
+ for example in tqdm(self.datasets['paws_unlabeled'], desc=f'Constructing paws_unlabeled'):
968
+ text_a = example[DATASET_CONFIG['paws_unlabeled']['text_a']]
969
+ text_b = [example[DATASET_CONFIG['paws_unlabeled']['text_b']]]
970
+ text_c = []
971
+ label = example[DATASET_CONFIG['paws_unlabeled']['label']]
972
+
973
+ output.append({
974
+ 'text_a': text_a,
975
+ 'text_b': text_b,
976
+ 'text_c': text_c,
977
+ 'label': label
978
+ })
979
+
980
+ return output
981
+
982
+ def process_qqp(self):
983
+ output = []
984
+ for example in tqdm(self.datasets['qqp'], desc=f'Constructing qqp'):
985
+ text_a = example[DATASET_CONFIG['qqp']['text_a']]
986
+ text_b = [example[DATASET_CONFIG['qqp']['text_b']]]
987
+ text_c = []
988
+ label = example[DATASET_CONFIG['qqp']['label']]
989
+
990
+ output.append({
991
+ 'text_a': text_a,
992
+ 'text_b': text_b,
993
+ 'text_c': text_c,
994
+ 'label': label
995
+ })
996
+
997
+ return output
998
+
999
+ def process_qqp_val(self):
1000
+ output = []
1001
+ for example in tqdm(self.datasets['qqp_val'], desc=f'Constructing qqp_val'):
1002
+ text_a = example[DATASET_CONFIG['qqp_val']['text_a']]
1003
+ text_b = [example[DATASET_CONFIG['qqp_val']['text_b']]]
1004
+ text_c = []
1005
+ label = example[DATASET_CONFIG['qqp_val']['label']]
1006
+
1007
+ output.append({
1008
+ 'text_a': text_a,
1009
+ 'text_b': text_b,
1010
+ 'text_c': text_c,
1011
+ 'label': label
1012
+ })
1013
+
1014
+ return output
1015
+
1016
+ def process_msmarco(self):
1017
+ qa2d_generator = QA2D(batch_size=32, device='cuda')
1018
+ output = []
1019
+ correct_contexts = []
1020
+ correct_questions = []
1021
+ correct_answers = []
1022
+
1023
+ wrong_contexts = []
1024
+ wrong_questions = []
1025
+ wrong_answers = []
1026
+
1027
+ filtered_examples = []
1028
+ questions = []
1029
+ answers = []
1030
+ declaratives = []
1031
+
1032
+ for example in tqdm(self.datasets['msmarco'], desc=f'Collecting msmarco'):
1033
+ if sum(example['passages']['is_selected']) > 0: # has answer
1034
+ questions.append(example['query'])
1035
+ answers.append(example['answers'][0] if len(
1036
+ example['wellFormedAnswers']) == 0 else example['wellFormedAnswers'][0])
1037
+ filtered_examples.append(example)
1038
+
1039
+ for example in filtered_examples:
1040
+ for i, is_selected in enumerate(example['passages']['is_selected']):
1041
+ if is_selected == 1:
1042
+ output.append({
1043
+ 'text_a': example['passages']['passage_text'][i],
1044
+ 'text_b': [example['query']],
1045
+ 'text_c': [],
1046
+ 'label': 1
1047
+ }
1048
+ )
1049
+ else:
1050
+ output.append({
1051
+ 'text_a': example['passages']['passage_text'][i],
1052
+ 'text_b': [example['query']],
1053
+ 'text_c': [],
1054
+ 'label': 0
1055
+ }
1056
+ )
1057
+ return output
1058
+
1059
+ def process_paws_qqp(self):
1060
+ output = []
1061
+
1062
+ for i in range(len(self.datasets['paws_qqp'])):
1063
+ text_a = self.datasets['paws_qqp'].iloc[i]['sentence1'][2:-1]
1064
+ text_b = [self.datasets['paws_qqp'].iloc[i]['sentence2'][2:-1]]
1065
+ text_c = []
1066
+ label = self.datasets['paws_qqp'].iloc[i]['label']
1067
+
1068
+ output.append({
1069
+ 'text_a': text_a,
1070
+ 'text_b': text_b,
1071
+ 'text_c': text_c,
1072
+ 'label': int(label)
1073
+ })
1074
+
1075
+ return output
1076
+
1077
+ def process_wmt15(self):
1078
+ output = []
1079
+
1080
+ for example in self.datasets['wmt15']:
1081
+ text_a = example['reference']
1082
+ text_b = [example['candidate']]
1083
+ text_c = []
1084
+ label = example['score']
1085
+
1086
+ output.append({
1087
+ 'text_a': text_a,
1088
+ 'text_b': text_b,
1089
+ 'text_c': text_c,
1090
+ 'label': label
1091
+ })
1092
+
1093
+ return output
1094
+
1095
+ def process_wmt16(self):
1096
+ output = []
1097
+
1098
+ for example in self.datasets['wmt16']:
1099
+ text_a = example['reference']
1100
+ text_b = [example['candidate']]
1101
+ text_c = []
1102
+ label = example['score']
1103
+
1104
+ output.append({
1105
+ 'text_a': text_a,
1106
+ 'text_b': text_b,
1107
+ 'text_c': text_c,
1108
+ 'label': label
1109
+ })
1110
+
1111
+ return output
1112
+
1113
+ def process_wmt17(self):
1114
+
1115
+ output = []
1116
+
1117
+ for example in self.datasets['wmt17']:
1118
+ text_a = example['reference']
1119
+ text_b = [example['candidate']]
1120
+ text_c = []
1121
+ label = example['score']
1122
+
1123
+ output.append({
1124
+ 'text_a': text_a,
1125
+ 'text_b': text_b,
1126
+ 'text_c': text_c,
1127
+ 'label': label
1128
+ })
1129
+
1130
+ return output
1131
+
1132
+ def process_wmt18(self):
1133
+ output = []
1134
+
1135
+ for example in self.datasets['wmt18']:
1136
+ text_a = example['reference']
1137
+ text_b = [example['candidate']]
1138
+ text_c = []
1139
+ label = example['score']
1140
+
1141
+ output.append({
1142
+ 'text_a': text_a,
1143
+ 'text_b': text_b,
1144
+ 'text_c': text_c,
1145
+ 'label': label
1146
+ })
1147
+
1148
+ return output
1149
+
1150
+ def process_wmt19(self):
1151
+ output = []
1152
+
1153
+ for example in self.datasets['wmt19']:
1154
+ text_a = example['reference']
1155
+ text_b = [example['candidate']]
1156
+ text_c = []
1157
+ label = example['score']
1158
+
1159
+ output.append({
1160
+ 'text_a': text_a,
1161
+ 'text_b': text_b,
1162
+ 'text_c': text_c,
1163
+ 'label': label
1164
+ })
1165
+
1166
+ return output
1167
+
1168
+ def process_boolq(self):
1169
+ output = []
1170
+
1171
+ for example in self.datasets['boolq']:
1172
+ text_a = example['passage']
1173
+ text_b = [example['question']]
1174
+ text_c = ["Yes." if example['answer'] else "No."]
1175
+ label = 1
1176
+
1177
+ output.append({
1178
+ 'text_a': text_a,
1179
+ 'text_b': text_b,
1180
+ 'text_c': text_c,
1181
+ 'label': label
1182
+ })
1183
+
1184
+ text_a = example['passage']
1185
+ text_b = [example['question']]
1186
+ text_c = ["Yes." if not example['answer'] else "No."]
1187
+ label = 0
1188
+
1189
+ output.append({
1190
+ 'text_a': text_a,
1191
+ 'text_b': text_b,
1192
+ 'text_c': text_c,
1193
+ 'label': label
1194
+ })
1195
+
1196
+ return output
1197
+
1198
+ def process_eraser_multi_rc(self):
1199
+ output = []
1200
+
1201
+ for example in self.datasets['eraser_multi_rc']:
1202
+ text_a = example['passage']
1203
+ text_b = [example['query_and_answer'].replace("|", "")]
1204
+ text_c = []
1205
+ label = int(example['label'])
1206
+
1207
+ output.append({
1208
+ 'text_a': text_a,
1209
+ 'text_b': text_b,
1210
+ 'text_c': text_c,
1211
+ 'label': label
1212
+ })
1213
+
1214
+ return output
1215
+
1216
+ def process_quail(self):
1217
+ output = []
1218
+
1219
+ for example in self.datasets['quail']:
1220
+ for i, ans in enumerate(example['answers']):
1221
+ text_a = example['context']
1222
+ text_b = [example['question']]
1223
+ text_c = [ans]
1224
+ label = 1 if i == example['correct_answer_id'] else 0
1225
+
1226
+ output.append({
1227
+ 'text_a': text_a,
1228
+ 'text_b': text_b,
1229
+ 'text_c': text_c,
1230
+ 'label': label
1231
+ })
1232
+
1233
+ return output
1234
+
1235
+ def process_sciq(self):
1236
+ output = []
1237
+
1238
+ for example in self.datasets['sciq']:
1239
+ text_a = example['support']
1240
+
1241
+ output.append({
1242
+ 'text_a': text_a,
1243
+ 'text_b': [example['question']],
1244
+ 'text_c': [example['distractor1']],
1245
+ 'label': 0
1246
+ })
1247
+ output.append({
1248
+ 'text_a': text_a,
1249
+ 'text_b': [example['question']],
1250
+ 'text_c': [example['distractor2']],
1251
+ 'label': 0
1252
+ })
1253
+ output.append({
1254
+ 'text_a': text_a,
1255
+ 'text_b': [example['question']],
1256
+ 'text_c': [example['distractor3']],
1257
+ 'label': 0
1258
+ })
1259
+ output.append({
1260
+ 'text_a': text_a,
1261
+ 'text_b': [example['question']],
1262
+ 'text_c': [example['correct_answer']],
1263
+ 'label': 1
1264
+ })
1265
+
1266
+ return output
1267
+
1268
+ def process_strategy_qa(self):
1269
+ output = []
1270
+
1271
+ for example in self.datasets['strategy_qa']:
1272
+ text_a = ' '.join(example['facts'])
1273
+ text_b = [example['question']]
1274
+ text_c = ["Yes." if example['answer'] else "No."]
1275
+ label = 1
1276
+
1277
+ output.append({
1278
+ 'text_a': text_a,
1279
+ 'text_b': text_b,
1280
+ 'text_c': text_c,
1281
+ 'label': label
1282
+ })
1283
+
1284
+ text_a = ' '.join(example['facts'])
1285
+ text_b = [example['question']]
1286
+ text_c = ["Yes." if not example['answer'] else "No."]
1287
+ label = 0
1288
+
1289
+ output.append({
1290
+ 'text_a': text_a,
1291
+ 'text_b': text_b,
1292
+ 'text_c': text_c,
1293
+ 'label': label
1294
+ })
1295
+
1296
+ return output
1297
+
1298
+ def process_gap(self):
1299
+ output = []
1300
+
1301
+ for example in self.datasets['gap']:
1302
+ text_a = example['Text']
1303
+ text_b = [example['Text'][:example['Pronoun-offset']]+example['A']+example['Text'][(example['Pronoun-offset']+len(example['Pronoun'])):]]
1304
+ text_c = []
1305
+ label = 1 if example['A-coref'] else 0
1306
+
1307
+ output.append({
1308
+ 'text_a': text_a,
1309
+ 'text_b': text_b,
1310
+ 'text_c': text_c,
1311
+ 'label': label
1312
+ })
1313
+
1314
+ text_a = example['Text']
1315
+ text_b = [example['Text'][:example['Pronoun-offset']]+example['B']+example['Text'][(example['Pronoun-offset']+len(example['Pronoun'])):]]
1316
+ text_c = []
1317
+ label = 1 if example['B-coref'] else 0
1318
+
1319
+ output.append({
1320
+ 'text_a': text_a,
1321
+ 'text_b': text_b,
1322
+ 'text_c': text_c,
1323
+ 'label': label
1324
+ })
1325
+
1326
+ return output
1327
+
1328
+ def init_qa_t5(self):
1329
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
1330
+ if self.t5_qa is None:
1331
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(
1332
+ "t5-base", model_max_length=800)
1333
+ self.t5_qa = T5ForConditionalGeneration.from_pretrained("t5-base")
1334
+ self.t5_qa.to('cuda:1')
1335
+ self.t5_qa.eval()
1336
+
1337
+ @staticmethod
1338
+ def mask_answer(context, answers):
1339
+ answers = sorted(answers, key=len, reverse=True)
1340
+ for answer in answers:
1341
+ pattern = f'(?<![\w\\-\u2013]){re.escape(answer)}(?![\w\\-\u2013])'
1342
+ context = re.sub(pattern, '', context, flags=re.IGNORECASE)
1343
+ return context
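For reference, the answer-masking step above can be exercised on its own. The sketch below (illustrative only, not part of the uploaded files) repeats the same whole-word, case-insensitive removal so the T5 question-answering model cannot simply copy a gold answer span; the example strings are made up.

```python
import re

def mask_answer(context, answers):
    # Mirror of the static method above: longest answers first, whole-word,
    # case-insensitive removal of every gold answer span from the context.
    answers = sorted(answers, key=len, reverse=True)
    for answer in answers:
        pattern = rf'(?<![\w\-\u2013]){re.escape(answer)}(?![\w\-\u2013])'
        context = re.sub(pattern, '', context, flags=re.IGNORECASE)
    return context

print(mask_answer("The Eiffel Tower is in Paris, France.", ["paris"]))
# prints: "The Eiffel Tower is in , France."
```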
1344
+
1345
+ def generate_fake_answer(self, context, question, answers):
1346
+ self.init_qa_t5()
1347
+
1348
+ context_no_answer = self.mask_answer(context, answers)
1349
+
1350
+ input_ids = self.t5_tokenizer(
1351
+ f'question: {question} context: {context_no_answer}',
1352
+ return_tensors="pt",
1353
+ truncation='only_first'
1354
+ ).input_ids.to(self.t5_qa.device)
1355
+
1356
+ outputs = self.t5_qa.generate(
1357
+ input_ids,
1358
+ max_new_tokens=40,
1359
+ remove_invalid_values=True
1360
+ )
1361
+
1362
+ return self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
1363
+
1364
+ def negative_sample_qa(self, samples, negative_sample_no_ans_only=True):
1365
+ outputs = []
1366
+ for context, question, answers in samples:
1367
+ if answers:
1368
+ outputs.append({
1369
+ 'text_a': context,
1370
+ 'text_b': [question],
1371
+ 'text_c': answers,
1372
+ 'label': 1
1373
+ })
1374
+ if not answers or not negative_sample_no_ans_only:
1375
+ fake_answer = self.generate_fake_answer(
1376
+ context, question, answers)
1377
+ outputs.append({
1378
+ 'text_a': context,
1379
+ 'text_b': [question],
1380
+ 'text_c': [fake_answer],
1381
+ 'label': 0
1382
+ })
1383
+
1384
+ return outputs
1385
+
1386
+ def process_squad_v2_new(self):
1387
+ samples = (
1388
+ (sample['context'], sample['question'], sample['answers']['text'])
1389
+ for sample in tqdm(self.datasets['squad_v2_new'], desc=f'squad_v2_new')
1390
+ )
1391
+ return self.negative_sample_qa(samples)
1392
+
1393
+ def process_adversarial_qa(self):
1394
+ samples = (
1395
+ (sample['context'], sample['question'], sample['answers']['text'])
1396
+ for sample in tqdm(self.datasets['adversarial_qa'], desc=f'adversarial_qa')
1397
+ )
1398
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1399
+
1400
+ def process_drop(self):
1401
+ samples = (
1402
+ (sample['passage'], sample['question'],
1403
+ sample['answers_spans']['spans'])
1404
+ for sample in tqdm(self.datasets['drop'], desc=f'drop')
1405
+ )
1406
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1407
+
1408
+ def process_duorc_self(self):
1409
+ samples = (
1410
+ (sample['plot'], sample['question'],
1411
+ sample['answers'])
1412
+ for sample in tqdm(self.datasets['duorc_self'], desc=f'duorc_self')
1413
+ )
1414
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1415
+
1416
+ def process_duorc_paraphrase(self):
1417
+ samples = (
1418
+ (sample['plot'], sample['question'],
1419
+ sample['answers'])
1420
+ for sample in tqdm(self.datasets['duorc_paraphrase'], desc=f'duorc_paraphrase')
1421
+ )
1422
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1423
+
1424
+ def process_quoref(self):
1425
+ samples = (
1426
+ (sample['context'], sample['question'], sample['answers']['text'])
1427
+ for sample in tqdm(self.datasets['quoref'], desc=f'quoref')
1428
+ )
1429
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1430
+
1431
+ @staticmethod
1432
+ def prepare_hotpot_qa_samples(dataset):
1433
+ for sample in dataset:
1434
+ question = sample['question']
1435
+ answer = sample['answer']
1436
+ supporting_docs = set(sample['supporting_facts']['title'])
1437
+ irrelevant_docs = []
1438
+ context_paragraphs = []
1439
+ for title, sentences in zip(sample['context']['title'], sample['context']['sentences']):
1440
+ doc = ''.join(sentences)
1441
+ if title in supporting_docs:
1442
+ context_paragraphs.append(doc)
1443
+ else:
1444
+ irrelevant_docs.append(doc)
1445
+ # Add some irrelevant documents
1446
+ if irrelevant_docs and len(context_paragraphs) < 4:
1447
+ context_paragraphs.append(random.choice(irrelevant_docs))
1448
+ random.shuffle(context_paragraphs)
1449
+ yield '\n'.join(context_paragraphs), question, [answer]
1450
+
1451
+ def process_hotpot_qa_distractor(self):
1452
+ samples = self.prepare_hotpot_qa_samples(
1453
+ tqdm(self.datasets['hotpot_qa_distractor'],
1454
+ desc=f'hotpot_qa_distractor')
1455
+ )
1456
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1457
+
1458
+ def process_hotpot_qa_fullwiki(self):
1459
+ samples = self.prepare_hotpot_qa_samples(
1460
+ tqdm(self.datasets['hotpot_qa_fullwiki'],
1461
+ desc=f'hotpot_qa_fullwiki')
1462
+ )
1463
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1464
+
1465
+ def process_newsqa(self):
1466
+ def get_samples(dataset):
1467
+ for story in tqdm(dataset['data'], desc='newsqa'):
1468
+ if story['type'] != 'train':
1469
+ continue
1470
+ context = story['text']
1471
+ for question in story['questions']:
1472
+ if question.get('isQuestionBad', 0.) > 0.2:
1473
+ continue
1474
+ answers = []
1475
+ if 's' in question['consensus']:
1476
+ start = question['consensus']['s']
1477
+ end = question['consensus']['e']
1478
+ answers.append(context[start:end].strip())
1479
+ yield context, question['q'], answers
1480
+ samples = get_samples(self.datasets['newsqa'])
1481
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1482
+
1483
+ def process_ropes(self):
1484
+ samples = (
1485
+ (
1486
+ sample['situation'] + ' ' + sample['background'],
1487
+ sample['question'], sample['answers']['text']
1488
+ )
1489
+ for sample in tqdm(self.datasets['ropes'], desc=f'ropes')
1490
+ )
1491
+ return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)
1492
+
1493
+ def generate(self):
1494
+ for each_dataset in self.datasets:
1495
+ with open(f'./data/training/{each_dataset}.json', 'w', encoding='utf8') as outfile:
1496
+ outfile.write("")
1497
+ for each_dataset in self.datasets:
1498
+ outputs = getattr(self, f'process_{each_dataset}')()
1499
+
1500
+ for each_output in outputs:
1501
+ dict_write_to_file = {
1502
+ 'task': DATASET_CONFIG[each_dataset]['task'],
1503
+ 'text_a': each_output['text_a'], # string
1504
+ # list of positive examples
1505
+ 'text_b': each_output['text_b'],
1506
+ # list of negative examples
1507
+ 'text_c': each_output['text_c'],
1508
+ # original label, if -1 only has positive pairs and negative pairs
1509
+ 'orig_label': each_output['label']
1510
+ }
1511
+ with open(f'./data/training/{each_dataset}.json', 'a', encoding='utf8') as outfile:
1512
+ json.dump(dict_write_to_file, outfile, ensure_ascii=False)
1513
+ outfile.write('\n')
1514
+
1515
+
1516
+ if __name__ == "__main__":
1517
+ random.seed(42)
1518
+ gen = DataGenerator(list(DATASET_CONFIG.keys()))
1519
+ gen.generate()
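For orientation (an illustrative sketch, not part of the uploaded files): `generate()` above writes one JSON object per line to `./data/training/<dataset>.json`, so the files can be read back as plain JSON Lines. The file path in the commented example is hypothetical.

```python
import json

def load_training_records(path):
    """Read the JSON Lines files written by DataGenerator.generate()."""
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            record = json.loads(line)
            # record['task']       -> e.g. 'nli', 'paraphrase', 'summarization'
            # record['text_a']     -> premise / document (string)
            # record['text_b']     -> list of positive texts
            # record['text_c']     -> list of negative texts
            # record['orig_label'] -> original label; -1 means only pos/neg pairs
            yield record

# for rec in load_training_records('./data/training/mnli.json'):  # hypothetical file
#     print(rec['task'], rec['orig_label'])
```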
alignscore/pyproject.toml ADDED
@@ -0,0 +1,41 @@
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "alignscore"
+ version = "0.1.3"
+ authors = [
+     { name = "Yuheng Zha", email = "[email protected]" },
+     { name = "Yichi Yang", email = "[email protected]" },
+     { name = "Ruichen Li", email = "[email protected]" },
+     { name = "Zhiting Hu", email = "[email protected]" },
+ ]
+ description = "An automatic factual consistency evaluation metric based on a unified alignment function"
+ readme = "README.md"
+ requires-python = ">=3.8"
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+ dependencies = [
+     "spacy>=3.4.0,<4",
+     "nltk>=3.7,<4",
+     "torch>=1.12.1,<2",
+     "transformers>=4.20.1,<5",
+     "tqdm>=4.64.0,<5",
+     "jsonlines>=2.0.0,<3",
+     "numpy>=1.23.1,<2",
+     "datasets>=2.3.2,<3",
+     "scikit-learn>=1.1.2,<2",
+     "pytorch_lightning>=1.7.7,<2",
+     "scipy>=1.8.1,<2",
+     "tensorboard>=2.12.0,<3",
+     "protobuf<=3.20"
+ ]
+
+ [project.urls]
+ "Homepage" = "https://github.com/yuh-zha/AlignScore"
+ "Bug Tracker" = "https://github.com/yuh-zha/AlignScore/issues"
alignscore/requirements.txt ADDED
@@ -0,0 +1,9 @@
+ alignscore>=0.1
+ ctc_score==0.1.3
+ BLEURT @ git+https://github.com/google-research/bleurt@cebe7e6f996b40910cfaa520a63db47807e3bf5c
+ bert_score==0.3.11
+ rake_nltk==1.0.6
+ summa==1.2.0
+ benepar==0.2.0
+ summac==0.0.3
+ tabulate>=0.9.0,<1
alignscore/src/alignscore/__init__.py ADDED
@@ -0,0 +1 @@
+ from .alignscore import AlignScore
alignscore/src/alignscore/alignscore.py ADDED
@@ -0,0 +1,16 @@
+ from .inference import Inferencer
+ from typing import List
+
+ class AlignScore:
+     def __init__(self, model: str, batch_size: int, device: int, ckpt_path: str, evaluation_mode='nli_sp', verbose=True) -> None:
+         self.model = Inferencer(
+             ckpt_path=ckpt_path,
+             model=model,
+             batch_size=batch_size,
+             device=device,
+             verbose=verbose
+         )
+         self.model.nlg_eval_mode = evaluation_mode
+
+     def score(self, contexts: List[str], claims: List[str]) -> List[float]:
+         return self.model.nlg_eval(contexts, claims)[1].tolist()
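For context, `AlignScore` above is the package's public entry point: it wraps `Inferencer` and exposes `score(contexts, claims)`. A minimal usage sketch follows; the checkpoint path and backbone name are placeholders (a trained AlignScore checkpoint must be supplied), and `device` is passed as an int to match the annotation above.

```python
from alignscore import AlignScore

# Hypothetical checkpoint path and backbone name; substitute a real trained checkpoint.
scorer = AlignScore(
    model='roberta-base',
    batch_size=32,
    device=0,
    ckpt_path='/path/to/AlignScore.ckpt',
    evaluation_mode='nli_sp',  # default mode, shown explicitly
)

scores = scorer.score(
    contexts=['A man is playing a guitar on stage.'],
    claims=['Someone is performing music.'],
)
print(scores)  # one float per (context, claim) pair
```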
alignscore/src/alignscore/dataloader.py ADDED
@@ -0,0 +1,610 @@
1
+ import json
2
+ import logging
3
+ import random
4
+ from typing import Optional, Sized
5
+ import numpy as np
6
+
7
+ import torch
8
+ from pytorch_lightning import LightningDataModule
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+ from transformers import (
12
+ AutoConfig,
13
+ AutoTokenizer,
14
+ )
15
+ from torch.utils.data import Dataset, Sampler
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+ class DSTDataSet(Dataset):
20
+ def __init__(self, dataset, model_name='bert-base-uncased', need_mlm=True, tokenizer_max_length=512) -> None:
21
+ super().__init__()
22
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+ self.tokenizer_max_length = tokenizer_max_length
24
+ self.config = AutoConfig.from_pretrained(model_name)
25
+ self.dataset_type_dict = dict()
26
+
27
+ self.dataset = dataset
28
+
29
+ self.need_mlm = need_mlm
30
+
31
+ self.dataset_type_dict_init()
32
+
33
+ def dataset_type_dict_init(self):
34
+ for i, example in enumerate(self.dataset):
35
+ try:
36
+ self.dataset_type_dict[example['task']].append(i)
37
+ except KeyError:
38
+ self.dataset_type_dict[example['task']] = [i]
39
+ def random_word(self, tokens):
40
+ """
41
+ Masks some random tokens for the masked language modeling task, with probabilities as in the original BERT paper.
42
+ :param tokens: list of int, token ids of the tokenized sentence pair.
43
+ Uses self.tokenizer for masking (we need its vocab and mask token id here).
44
+ :return: (list of int, list of int), masked token ids and related labels for MLM prediction
45
+ """
46
+ if not self.need_mlm: # disable masked language modeling
47
+ return tokens, [-100] * len(tokens)
48
+
49
+ output_label = []
50
+
51
+ for i, token in enumerate(tokens):
52
+ if token == self.tokenizer.pad_token_id:
53
+ output_label.append(-100) # ignore PAD tokens
54
+ continue
55
+ prob = random.random()
56
+ # mask token with 15% probability
57
+ if prob < 0.15:
58
+ prob /= 0.15
59
+
60
+ # 80% randomly change token to mask token
61
+ if prob < 0.8:
62
+ tokens[i] = self.tokenizer.mask_token_id
63
+
64
+ # 10% randomly change token to random token
65
+ elif prob < 0.9:
66
+ tokens[i] = random.randrange(self.tokenizer.vocab_size)
67
+
68
+ # -> rest 10% randomly keep current token
69
+
70
+ # append current token to output (we will predict these later)
71
+ output_label.append(token)
72
+ else:
73
+ # no masking token (will be ignored by loss function later)
74
+ output_label.append(-100)
75
+
76
+ return tokens, output_label
77
+
78
+ def process_nli(self, index):
79
+ text_a = self.dataset[index]['text_a']
80
+ text_b = self.dataset[index]['text_b'][0]
81
+ tri_label = self.dataset[index]['orig_label'] if self.dataset[index]['orig_label'] != -1 else 1
82
+
83
+ rand_self_align = random.random()
84
+ if rand_self_align > 0.95: ### random self alignment
85
+ text_b = self.dataset[index]['text_a']
86
+ tri_label = 0
87
+ elif self.dataset[index]['orig_label'] == 2 and random.random() > 0.95:
88
+ text_a = self.dataset[index]['text_b'][0]
89
+ text_b = self.dataset[index]['text_a']
90
+
91
+
92
+ try:
93
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
94
+ except:
95
+ logging.warning('text_b too long...')
96
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
97
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
98
+ return (
99
+ torch.tensor(input_ids),
100
+ torch.tensor(tokenized_pair['attention_mask']),
101
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
102
+ torch.tensor(-100), # align label, 2 class
103
+ torch.tensor(mlm_labels), # mlm label
104
+ torch.tensor(tri_label), # tri label, 3 class
105
+ torch.tensor(-100.0) # reg label, float
106
+ )
107
+
108
+ def process_paraphrase(self, index):
109
+ text_a = self.dataset[index]['text_a']
110
+ text_b = self.dataset[index]['text_b'][0]
111
+ label = self.dataset[index]['orig_label']
112
+
113
+ rand_self_align = random.random()
114
+ if rand_self_align > 0.95: ### random self alignment
115
+ text_b = self.dataset[index]['text_a']
116
+ label = 1
117
+ elif random.random() > 0.95:
118
+ text_a = self.dataset[index]['text_b'][0]
119
+ text_b = self.dataset[index]['text_a']
120
+
121
+ try:
122
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
123
+ except:
124
+ logging.warning('text_b too long...')
125
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
126
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
127
+ return (
128
+ torch.tensor(input_ids),
129
+ torch.tensor(tokenized_pair['attention_mask']),
130
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
131
+ torch.tensor(label), # align label, 2 class
132
+ torch.tensor(mlm_labels), # mlm label
133
+ torch.tensor(-100), # tri label, 3 class
134
+ torch.tensor(-100.0) # reg label, float
135
+ )
136
+
137
+ def process_qa(self, index):
138
+ text_a = self.dataset[index]['text_a']
139
+ if len(self.dataset[index]['text_c']) > 0:
140
+ text_b = self.dataset[index]['text_b'][0] + ' ' + self.dataset[index]['text_c'][0]
141
+ else:
142
+ text_b = self.dataset[index]['text_b'][0]
143
+ label = self.dataset[index]['orig_label']
144
+
145
+ try:
146
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
147
+ except:
148
+ logging.warning('text_b too long...')
149
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
150
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
151
+ return (
152
+ torch.tensor(input_ids),
153
+ torch.tensor(tokenized_pair['attention_mask']),
154
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
155
+ torch.tensor(label), # align label, 2 class
156
+ torch.tensor(mlm_labels), # mlm label
157
+ torch.tensor(-100), # tri label, 3 class
158
+ torch.tensor(-100.0) # reg label, float
159
+ )
160
+
161
+ def process_coreference(self, index):
162
+ text_a = self.dataset[index]['text_a']
163
+ if len(self.dataset[index]['text_c']) > 0:
164
+ text_b = self.dataset[index]['text_b'][0] + ' ' + self.dataset[index]['text_c'][0]
165
+ else:
166
+ text_b = self.dataset[index]['text_b'][0]
167
+ label = self.dataset[index]['orig_label']
168
+
169
+ try:
170
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
171
+ except:
172
+ logging.warning('text_b too long...')
173
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
174
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
175
+ return (
176
+ torch.tensor(input_ids),
177
+ torch.tensor(tokenized_pair['attention_mask']),
178
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
179
+ torch.tensor(label), # align label, 2 class
180
+ torch.tensor(mlm_labels), # mlm label
181
+ torch.tensor(-100), # tri label, 3 class
182
+ torch.tensor(-100.0) # reg label, float
183
+ )
184
+
185
+ def process_bin_nli(self, index):
186
+ text_a = self.dataset[index]['text_a']
187
+ text_b = self.dataset[index]['text_b'][0]
188
+ label = self.dataset[index]['orig_label']
189
+
190
+ try:
191
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
192
+ except:
193
+ logging.warning('text_b too long...')
194
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
195
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
196
+ return (
197
+ torch.tensor(input_ids),
198
+ torch.tensor(tokenized_pair['attention_mask']),
199
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
200
+ torch.tensor(label), # align label, 2 class
201
+ torch.tensor(mlm_labels), # mlm label
202
+ torch.tensor(-100), # tri label, 3 class
203
+ torch.tensor(-100.0) # reg label, float
204
+ )
205
+
206
+ def process_fact_checking(self, index):
207
+ text_a = self.dataset[index]['text_a']
208
+ text_b = self.dataset[index]['text_b'][0]
209
+ tri_label = self.dataset[index]['orig_label'] if self.dataset[index]['orig_label'] != -1 else 1
210
+
211
+ rand_self_align = random.random()
212
+ if rand_self_align > 0.95: ### random self alignment
213
+ text_b = self.dataset[index]['text_a']
214
+ tri_label = 0
215
+ elif self.dataset[index]['orig_label'] == 2 and random.random() > 0.95:
216
+ text_a = self.dataset[index]['text_b'][0]
217
+ text_b = self.dataset[index]['text_a']
218
+
219
+ try:
220
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
221
+ except:
222
+ logging.warning('text_b too long...')
223
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
224
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
225
+ return (
226
+ torch.tensor(input_ids),
227
+ torch.tensor(tokenized_pair['attention_mask']),
228
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
229
+ torch.tensor(-100), # align label, 2 class
230
+ torch.tensor(mlm_labels), # mlm label
231
+ torch.tensor(tri_label), # tri label, 3 class
232
+ torch.tensor(-100.0) # reg label, float
233
+ )
234
+
235
+ def process_summarization(self, index):
236
+ text_a = self.dataset[index]['text_a']
237
+ if random.random() > 0.5: # this will be a positive pair
238
+ random_pos_sample_id = random.randint(0, len(self.dataset[index]['text_b'])-1)
239
+ text_b = self.dataset[index]['text_b'][random_pos_sample_id]
240
+ label = 1
241
+ else: # this will be a negative pair
242
+ label = 0
243
+ if len(self.dataset[index]['text_c']) > 0:
244
+ random_neg_sample_id = random.randint(0, len(self.dataset[index]['text_c'])-1)
245
+ text_b = self.dataset[index]['text_c'][random_neg_sample_id]
246
+ else:
247
+ random_choose_from_entire_dataset_text_b = random.choice(self.dataset_type_dict['summarization'])
248
+ text_b = self.dataset[random_choose_from_entire_dataset_text_b]['text_b'][0]
249
+
250
+ try:
251
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
252
+ except:
253
+ logging.warning('text_b too long...')
254
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
255
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
256
+
257
+ return (
258
+ torch.tensor(input_ids),
259
+ torch.tensor(tokenized_pair['attention_mask']),
260
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
261
+ torch.tensor(label), # align label, 2 class
262
+ torch.tensor(mlm_labels), # mlm label
263
+ torch.tensor(-100), # tri label, 3 class
264
+ torch.tensor(-100.0) # reg label, float
265
+ )
266
+
267
+ def process_multiple_choice_qa(self, index):
268
+ text_a = self.dataset[index]['text_a']
269
+ if random.random() > 0.5: # this will be a positive pair
270
+ text_b = self.dataset[index]['text_b'][0]
271
+ label = 1
272
+ else: # this will be a negative pair
273
+ label = 0
274
+ if len(self.dataset[index]['text_c']) > 0:
275
+ random_neg_sample_id = random.randint(0, len(self.dataset[index]['text_c'])-1)
276
+ text_b = self.dataset[index]['text_c'][random_neg_sample_id]
277
+ else:
278
+ random_choose_from_entire_dataset_text_b = random.choice(self.dataset_type_dict['multiple_choice_qa'])
279
+ text_b = self.dataset[random_choose_from_entire_dataset_text_b]['text_b'][0]
280
+
281
+ try:
282
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
283
+ except:
284
+ logging.warning('text_b too long...')
285
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
286
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
287
+
288
+ return (
289
+ torch.tensor(input_ids),
290
+ torch.tensor(tokenized_pair['attention_mask']),
291
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
292
+ torch.tensor(label), # align label, 2 class
293
+ torch.tensor(mlm_labels), # mlm label
294
+ torch.tensor(-100), # tri label, 3 class
295
+ torch.tensor(-100.0) # reg label, float
296
+ )
297
+
298
+ def process_extractive_qa(self, index):
299
+ text_a = self.dataset[index]['text_a']
300
+ if random.random() > 0.5: # this will be a positive pair
301
+ random_pos_sample_id = random.randint(0, len(self.dataset[index]['text_b'])-1)
302
+ text_b = self.dataset[index]['text_b'][random_pos_sample_id]
303
+ label = 1
304
+ else: # this will be a negative pair
305
+ label = 0
306
+ if len(self.dataset[index]['text_c']) > 0:
307
+ random_neg_sample_id = random.randint(0, len(self.dataset[index]['text_c'])-1)
308
+ text_b = self.dataset[index]['text_c'][random_neg_sample_id]
309
+ else:
310
+ random_choose_from_entire_dataset_text_b = random.choice(self.dataset_type_dict['extractive_qa'])
311
+ text_b = self.dataset[random_choose_from_entire_dataset_text_b]['text_b'][0]
312
+
313
+ try:
314
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
315
+ except:
316
+ logging.warning('text_b too long...')
317
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
318
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
319
+
320
+ return (
321
+ torch.tensor(input_ids),
322
+ torch.tensor(tokenized_pair['attention_mask']),
323
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
324
+ torch.tensor(label), # align label, 2 class
325
+ torch.tensor(mlm_labels), # mlm label
326
+ torch.tensor(-100), # tri label, 3 class
327
+ torch.tensor(-100.0) # reg label, float
328
+ )
329
+
330
+ def process_ir(self, index):
331
+ text_a = self.dataset[index]['text_a']
332
+ text_b = self.dataset[index]['text_b'][random.randint(0, len(self.dataset[index]['text_b'])-1)]
333
+ label = self.dataset[index]['orig_label']
334
+
335
+ try:
336
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
337
+ except:
338
+ logging.warning('text_b too long...')
339
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
340
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
341
+
342
+ return (
343
+ torch.tensor(input_ids),
344
+ torch.tensor(tokenized_pair['attention_mask']),
345
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
346
+ torch.tensor(label), # align label, 2 class
347
+ torch.tensor(mlm_labels), # mlm label
348
+ torch.tensor(-100), # tri label, 3 class
349
+ torch.tensor(-100.0) # reg label, float
350
+ )
351
+
352
+ def process_wmt(self, index):
353
+ text_a = self.dataset[index]['text_a']
354
+ text_b = self.dataset[index]['text_b'][0]
355
+ reg_label = self.dataset[index]['orig_label']
356
+
357
+ try:
358
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
359
+ except:
360
+ logging.warning('text_b too long...')
361
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
362
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
363
+
364
+ return (
365
+ torch.tensor(input_ids),
366
+ torch.tensor(tokenized_pair['attention_mask']),
367
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
368
+ torch.tensor(-100), # align label, 2 class
369
+ torch.tensor(mlm_labels), # mlm label
370
+ torch.tensor(-100), # tri label, 3 class
371
+ torch.tensor(reg_label) # reg label, float
372
+ )
373
+
374
+ def process_sts(self, index):
375
+ text_a = self.dataset[index]['text_a']
376
+ text_b = self.dataset[index]['text_b'][0]
377
+ reg_label = self.dataset[index]['orig_label']
378
+
379
+ try:
380
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
381
+ except:
382
+ logging.warning('text_b too long...')
383
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
384
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
385
+
386
+ return (
387
+ torch.tensor(input_ids),
388
+ torch.tensor(tokenized_pair['attention_mask']),
389
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
390
+ torch.tensor(-100), # align label, 2 class
391
+ torch.tensor(mlm_labels), # mlm label
392
+ torch.tensor(-100), # tri label, 3 class
393
+ torch.tensor(reg_label) # reg label, float
394
+ )
395
+
396
+ def process_ctc(self, index):
397
+ text_a = self.dataset[index]['text_a']
398
+ text_b = self.dataset[index]['text_b'][0]
399
+ reg_label = self.dataset[index]['orig_label']
400
+
401
+ try:
402
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation='only_first')
403
+ except:
404
+ logging.warning('text_b too long...')
405
+ tokenized_pair = self.tokenizer(text_a, text_b, padding='max_length', max_length=self.tokenizer_max_length, truncation=True)
406
+ input_ids, mlm_labels = self.random_word(tokenized_pair['input_ids'])
407
+
408
+ return (
409
+ torch.tensor(input_ids),
410
+ torch.tensor(tokenized_pair['attention_mask']),
411
+ torch.tensor(tokenized_pair['token_type_ids']) if 'token_type_ids' in tokenized_pair.keys() else None,
412
+ torch.tensor(-100), # align label, 2 class
413
+ torch.tensor(mlm_labels), # mlm label
414
+ torch.tensor(-100), # tri label, 3 class
415
+ torch.tensor(reg_label) # reg label, float
416
+ )
417
+
418
+ def __getitem__(self, index):
419
+ if self.dataset[index]['task'] == 'nli':
420
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_nli(index)
421
+
422
+ if self.dataset[index]['task'] == 'bin_nli':
423
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_bin_nli(index)
424
+
425
+ if self.dataset[index]['task'] == 'paraphrase':
426
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_paraphrase(index)
427
+
428
+ if self.dataset[index]['task'] == 'fact_checking':
429
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_fact_checking(index)
430
+
431
+ if self.dataset[index]['task'] == 'summarization':
432
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_summarization(index)
433
+
434
+ if self.dataset[index]['task'] == 'multiple_choice_qa':
435
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_multiple_choice_qa(index)
436
+
437
+ if self.dataset[index]['task'] == 'extractive_qa':
438
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_extractive_qa(index)
439
+
440
+ if self.dataset[index]['task'] == 'qa':
441
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_qa(index)
442
+
443
+ if self.dataset[index]['task'] == 'coreference':
444
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_coreference(index)
445
+
446
+ if self.dataset[index]['task'] == 'ir':
447
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_ir(index)
448
+
449
+ if self.dataset[index]['task'] == 'sts':
450
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_sts(index)
451
+
452
+ if self.dataset[index]['task'] == 'ctc':
453
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_ctc(index)
454
+
455
+ if self.dataset[index]['task'] == 'wmt':
456
+ input_ids, attention_mask, token_type_ids, align_label, mlm_labels, tri_label, reg_label = self.process_wmt(index)
457
+
458
+ if token_type_ids is not None:
459
+ return {
460
+ 'input_ids': input_ids,
461
+ 'attention_mask': attention_mask,
462
+ 'token_type_ids': token_type_ids,
463
+ 'align_label': align_label,
464
+ 'mlm_label': mlm_labels,
465
+ 'tri_label': tri_label,
466
+ 'reg_label': reg_label
467
+ }
468
+ else:
469
+ return {
470
+ 'input_ids': input_ids,
471
+ 'attention_mask': attention_mask,
472
+ 'align_label': align_label,
473
+ 'mlm_label': mlm_labels,
474
+ 'tri_label': tri_label,
475
+ 'reg_label': reg_label
476
+ }
477
+
478
+
479
+ def __len__(self):
480
+ return len(self.dataset)
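A small usage sketch for the dataset class above (illustrative only): it assumes the `bert-base-uncased` tokenizer can be downloaded and that the record follows the JSON Lines format produced by the data generator, with `DSTDataSet` imported from this module.

```python
records = [{
    'task': 'paraphrase',
    'text_a': 'A man is playing a guitar on stage.',
    'text_b': ['Someone plays the guitar.'],
    'text_c': [],
    'orig_label': 1,
}]

ds = DSTDataSet(dataset=records, model_name='bert-base-uncased', need_mlm=True)
item = ds[0]
print(item['input_ids'].shape)   # torch.Size([512]) after max_length padding
print(item['align_label'])       # tensor(1) for this paraphrase pair
print(item['mlm_label'].shape)   # torch.Size([512]); -100 marks unmasked positions
```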
481
+
482
+ class PropSampler(Sampler[int]):
483
+ def __init__(self, data_source: Optional[Sized]) -> None:
484
+ super().__init__(data_source)
485
+ self.K = 500000
486
+ print("Initializing Prop Sampler")
487
+
488
+ self.data_positions = dict()
489
+ for i, example in tqdm(enumerate(data_source), desc="Initializing Sampler"):
490
+ if example['dataset_name'] in self.data_positions.keys():
491
+ self.data_positions[example['dataset_name']].append(i)
492
+ else:
493
+ self.data_positions[example['dataset_name']] = [i]
494
+ self.all_dataset_names = list(self.data_positions.keys())
495
+ self.dataset_lengths = {each:len(self.data_positions[each]) for each in self.data_positions}
496
+
497
+ self.dataset_props = {each: min(self.dataset_lengths[each], self.K) for each in self.dataset_lengths}
498
+ self.dataset_props_sum = sum([self.dataset_props[each] for each in self.dataset_props])
499
+
500
+
501
+
502
+ print("Finish Prop Sampler initialization.")
503
+
504
+ def __iter__(self):
505
+ iter_list = []
506
+ for each in self.dataset_props:
507
+ iter_list.extend(np.random.choice(self.data_positions[each], size=self.dataset_props[each], replace=False).tolist())
508
+
509
+ random.shuffle(iter_list)
510
+
511
+ yield from iter_list
512
+
513
+ def __len__(self):
514
+ return self.dataset_props_sum
515
+
516
+ class DSTDataLoader(LightningDataModule):
517
+ def __init__(self,dataset_config, val_dataset_config=None, sample_mode='seq', model_name='bert-base-uncased', is_finetune=False, need_mlm=True, tokenizer_max_length=512, train_batch_size=32, eval_batch_size=4, num_workers=16, train_eval_split=0.95, **kwargs):
518
+ super().__init__(**kwargs)
519
+ assert sample_mode in ['seq', 'proportion']
520
+ self.sample_mode = sample_mode
521
+ self.dataset_config = dataset_config
522
+ self.val_dataset_config = val_dataset_config
523
+ self.num_workers = num_workers
524
+ self.train_eval_split = train_eval_split
525
+ self.tokenizer_max_length = tokenizer_max_length
526
+ self.model_name = model_name
527
+
528
+ self.need_mlm = need_mlm
529
+ self.is_finetune = is_finetune
530
+
531
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
532
+ self.config = AutoConfig.from_pretrained(model_name)
533
+
534
+ self.train_batch_size = train_batch_size
535
+ self.eval_batch_size = eval_batch_size
536
+
537
+ self.dataset = None
538
+
539
+ def setup(self, stage: Optional[str] = None) -> None:
540
+ if self.dataset is not None:
541
+ print("Already Initilized LightningDataModule!")
542
+ return
543
+
544
+ self.init_training_set()
545
+
546
+ self.dataset = dict()
547
+ if not self.is_finetune:
548
+ self.dataset['train'] = DSTDataSet(dataset=self.raw_dataset[:int(self.train_eval_split*len(self.raw_dataset))], model_name=self.model_name, need_mlm=self.need_mlm)
549
+ self.dataset['test'] = DSTDataSet(dataset=self.raw_dataset[int(self.train_eval_split*len(self.raw_dataset)):], model_name=self.model_name, need_mlm=self.need_mlm)
550
+ else:
551
+ self.dataset['train'] = DSTDataSet(dataset=self.raw_dataset[:], model_name=self.model_name, need_mlm=self.need_mlm)
552
+ self.dataset['test'] = DSTDataSet(dataset=self.val_raw_dataset[:], model_name=self.model_name, need_mlm=self.need_mlm)
553
+
554
+
555
+ def init_training_set(self):
556
+ self.raw_dataset = []
557
+ if self.sample_mode == 'seq':
558
+ for each_dataset in self.dataset_config:
559
+ dataset_length = sum([1 for line in open(self.dataset_config[each_dataset]['data_path'], 'r', encoding='utf8')])
560
+ dataset_length_limit = self.dataset_config[each_dataset]['size'] if isinstance(self.dataset_config[each_dataset]['size'], int) else int(self.dataset_config[each_dataset]['size'] * dataset_length)
561
+ with open(self.dataset_config[each_dataset]['data_path'], 'r', encoding='utf8') as f:
562
+ try:
563
+ for i, example in enumerate(f):
564
+ if i >= dataset_length_limit:
565
+ break
566
+ self.raw_dataset.append(json.loads(example)) ## + dataset_name
567
+ except Exception:
568
+ print(f"failed to load data from {each_dataset}.json, exiting...")
569
+ exit()
570
+
571
+ random.shuffle(self.raw_dataset)
572
+
573
+ elif self.sample_mode == 'proportion':
574
+ for each_dataset in tqdm(self.dataset_config, desc="Loading data from disk..."):
575
+ with open(self.dataset_config[each_dataset]['data_path'], 'r', encoding='utf8') as f:
576
+ try:
577
+ for i, example in enumerate(f):
578
+ jsonobj = json.loads(example)
579
+ jsonobj['dataset_name'] = each_dataset
580
+ self.raw_dataset.append(jsonobj) ## + dataset_name
581
+ except Exception:
582
+ print(f"failed to load data from {each_dataset}.json, exiting...")
583
+ exit()
584
+
585
+ random.shuffle(self.raw_dataset)
586
+
587
+ if self.is_finetune:
588
+ self.val_raw_dataset = []
589
+ for each_dataset in self.val_dataset_config:
590
+ dataset_length = sum([1 for line in open(self.val_dataset_config[each_dataset]['data_path'], 'r', encoding='utf8')])
591
+ dataset_length_limit = self.val_dataset_config[each_dataset]['size'] if isinstance(self.val_dataset_config[each_dataset]['size'], int) else int(self.val_dataset_config[each_dataset]['size'] * dataset_length)
592
+ with open(self.val_dataset_config[each_dataset]['data_path'], 'r', encoding='utf8') as f:
593
+ for i, example in enumerate(f):
594
+ if i >= dataset_length_limit:
595
+ break
596
+ self.val_raw_dataset.append(json.loads(example))
597
+
598
+ random.shuffle(self.val_raw_dataset)
599
+
600
+ def prepare_data(self) -> None:
601
+ AutoTokenizer.from_pretrained(self.model_name)
602
+
603
+ def train_dataloader(self):
604
+ if self.sample_mode == 'seq':
605
+ return DataLoader(self.dataset['train'], batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers)
606
+ elif self.sample_mode == 'proportion':
607
+ return DataLoader(self.dataset['train'], batch_size=self.train_batch_size, sampler=PropSampler(self.raw_dataset[:int(self.train_eval_split*len(self.raw_dataset))]), num_workers=self.num_workers)
608
+
609
+ def val_dataloader(self):
610
+ return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size, shuffle=False, num_workers=self.num_workers)
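A minimal usage sketch (not part of the uploaded files) for the data module above, assuming the package layout that `train.py` uses (`alignscore.dataloader`) and a hypothetical local file `data/mnli.json`; the `size` field may be an absolute sample count or a fraction of the file, mirroring how `init_training_set` interprets it.

```python
# Hedged sketch: instantiate the LightningDataModule defined above.
# 'data/mnli.json' is a placeholder path to one of the training JSON-lines files.
from alignscore.dataloader import DSTDataLoader

dataset_config = {
    'mnli': {'task_type': 'nli', 'data_path': 'data/mnli.json', 'size': 100000},
}

dm = DSTDataLoader(
    dataset_config=dataset_config,
    sample_mode='seq',            # 'proportion' routes batching through PropSampler instead
    model_name='roberta-large',
    train_batch_size=32,
    eval_batch_size=4,
    num_workers=4,
)
dm.setup()                        # splits the raw data into train/test via train_eval_split
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
```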
alignscore/src/alignscore/inference.py ADDED
@@ -0,0 +1,293 @@
1
+ from logging import warning
2
+ import spacy
3
+ from nltk.tokenize import sent_tokenize
4
+ import torch
5
+ from .model import BERTAlignModel
6
+ from transformers import AutoConfig, AutoTokenizer
7
+ import torch.nn as nn
8
+ from tqdm import tqdm
9
+
10
+ class Inferencer():
11
+ def __init__(self, ckpt_path, model='bert-base-uncased', batch_size=32, device='cuda', verbose=True) -> None:
12
+ self.device = device
13
+ if ckpt_path is not None:
14
+ self.model = BERTAlignModel(model=model).load_from_checkpoint(checkpoint_path=ckpt_path, strict=False).to(self.device)
15
+ else:
16
+ warning('loading UNTRAINED model!')
17
+ self.model = BERTAlignModel(model=model).to(self.device)
18
+ self.model.eval()
19
+ self.batch_size = batch_size
20
+
21
+ self.config = AutoConfig.from_pretrained(model)
22
+ self.tokenizer = AutoTokenizer.from_pretrained(model)
23
+ self.spacy = spacy.load('en_core_web_sm')
24
+
25
+ self.loss_fct = nn.CrossEntropyLoss(reduction='none')
26
+ self.softmax = nn.Softmax(dim=-1)
27
+
28
+ self.smart_type = 'smart-n'
29
+ self.smart_n_metric = 'f1'
30
+
31
+ self.disable_progress_bar_in_inference = False
32
+
33
+ self.nlg_eval_mode = None # bin, bin_sp, nli, nli_sp
34
+ self.verbose = verbose
35
+
36
+ def inference_example_batch(self, premise: list, hypo: list):
37
+ """
38
+ Run inference over paired lists of premises and hypotheses,
39
+ premise: list
40
+ hypo: list
41
+ using self.inference to batch the process
42
+
43
+ SummaC Style aggregation
44
+ """
45
+ self.disable_progress_bar_in_inference = True
46
+ assert len(premise) == len(hypo), "Premise must have the same length as Hypothesis!"
47
+
48
+ out_score = []
49
+ for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating", total=len(premise), disable=(not self.verbose)):
50
+ out_score.append(self.inference_per_example(one_pre, one_hypo))
51
+
52
+ return None, torch.tensor(out_score), None
53
+
54
+ def inference_per_example(self, premise:str, hypo: str):
55
+ """
56
+ Run inference on a single example,
57
+ premise: string
58
+ hypo: string
59
+ using self.inference to batch the process
60
+ """
61
+ def chunks(lst, n):
62
+ """Yield successive n-sized chunks from lst."""
63
+ for i in range(0, len(lst), n):
64
+ yield ' '.join(lst[i:i + n])
65
+
66
+ premise_sents = sent_tokenize(premise)
67
+ premise_sents = premise_sents or ['']
68
+
69
+ n_chunk = len(premise.strip().split()) // 350 + 1
70
+ n_chunk = max(len(premise_sents) // n_chunk, 1)
71
+ premise_sents = [each for each in chunks(premise_sents, n_chunk)]
72
+
73
+ hypo_sents = sent_tokenize(hypo)
74
+
75
+ premise_sent_mat = []
76
+ hypo_sents_mat = []
77
+ for i in range(len(premise_sents)):
78
+ for j in range(len(hypo_sents)):
79
+ premise_sent_mat.append(premise_sents[i])
80
+ hypo_sents_mat.append(hypo_sents[j])
81
+
82
+ if self.nlg_eval_mode is not None:
83
+ if self.nlg_eval_mode == 'nli_sp':
84
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] ### use NLI head OR ALIGN head
85
+ elif self.nlg_eval_mode == 'bin_sp':
86
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[1] ### use NLI head OR ALIGN head
87
+ elif self.nlg_eval_mode == 'reg_sp':
88
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[0] ### use NLI head OR ALIGN head
89
+
90
+ output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect
91
+ return output_score
92
+
93
+
94
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] ### use NLI head OR ALIGN head
95
+ output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect
96
+
97
+ return output_score
98
+
99
+
100
+ def inference(self, premise, hypo):
101
+ """
102
+ Run inference on lists of premises and hypotheses.
103
+
104
+ Standard aggregation
105
+ """
106
+ if isinstance(premise, str) and isinstance(hypo, str):
107
+ premise = [premise]
108
+ hypo = [hypo]
109
+
110
+ batch = self.batch_tokenize(premise, hypo)
111
+ output_score_reg = []
112
+ output_score_bin = []
113
+ output_score_tri = []
114
+
115
+ for mini_batch in tqdm(batch, desc="Evaluating", disable=not self.verbose or self.disable_progress_bar_in_inference):
116
+ mini_batch = mini_batch.to(self.device)
117
+ with torch.no_grad():
118
+ model_output = self.model(mini_batch)
119
+ model_output_reg = model_output.reg_label_logits.cpu()
120
+ model_output_bin = model_output.seq_relationship_logits # Temperature Scaling / 2.5
121
+ model_output_tri = model_output.tri_label_logits
122
+
123
+ model_output_bin = self.softmax(model_output_bin).cpu()
124
+ model_output_tri = self.softmax(model_output_tri).cpu()
125
+ output_score_reg.append(model_output_reg[:,0])
126
+ output_score_bin.append(model_output_bin[:,1])
127
+ output_score_tri.append(model_output_tri[:,:])
128
+
129
+ output_score_reg = torch.cat(output_score_reg)
130
+ output_score_bin = torch.cat(output_score_bin)
131
+ output_score_tri = torch.cat(output_score_tri)
132
+
133
+ if self.nlg_eval_mode is not None:
134
+ if self.nlg_eval_mode == 'nli':
135
+ output_score_nli = output_score_tri[:,0]
136
+ return None, output_score_nli, None
137
+ elif self.nlg_eval_mode == 'bin':
138
+ return None, output_score_bin, None
139
+ elif self.nlg_eval_mode == 'reg':
140
+ return None, output_score_reg, None
141
+ else:
142
+ ValueError("unrecognized nlg eval mode")
143
+
144
+
145
+ return output_score_reg, output_score_bin, output_score_tri
146
+
147
+ def inference_reg(self, premise, hypo):
148
+ """
149
+ Run inference on lists of premises and hypotheses.
150
+
151
+ Standard aggregation
152
+ """
153
+ self.model.is_reg_finetune = True
154
+ if isinstance(premise, str) and isinstance(hypo, str):
155
+ premise = [premise]
156
+ hypo = [hypo]
157
+
158
+ batch = self.batch_tokenize(premise, hypo)
159
+ output_score = []
160
+
161
+ for mini_batch in tqdm(batch, desc="Evaluating", disable=self.disable_progress_bar_in_inference):
162
+ mini_batch = mini_batch.to(self.device)
163
+ with torch.no_grad():
164
+ model_output = self.model(mini_batch).seq_relationship_logits.cpu().view(-1)
165
+ output_score.append(model_output)
166
+ output_score = torch.cat(output_score)
167
+ return output_score
168
+
169
+ def batch_tokenize(self, premise, hypo):
170
+ """
171
+ input premise and hypos are lists
172
+ """
173
+ assert isinstance(premise, list) and isinstance(hypo, list)
174
+ assert len(premise) == len(hypo), "premise and hypo should be in the same length."
175
+
176
+ batch = []
177
+ for mini_batch_pre, mini_batch_hypo in zip(self.chunks(premise, self.batch_size), self.chunks(hypo, self.batch_size)):
178
+ try:
179
+ mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation='only_first', padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
180
+ except Exception:
181
+ warning('text_b too long...')
182
+ mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation=True, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
183
+ batch.append(mini_batch)
184
+
185
+ return batch
186
+ def smart_doc(self, premise: list, hypo: list):
187
+ """
188
+ Run inference over paired lists of premises and hypotheses,
189
+ premise: list
190
+ hypo: list
191
+ using self.inference to batch the process
192
+
193
+ SMART Style aggregation
194
+ """
195
+ self.disable_progress_bar_in_inference = True
196
+ assert len(premise) == len(hypo), "Premise must have the same length as Hypothesis!"
197
+ assert self.smart_type in ['smart-n', 'smart-l']
198
+
199
+ out_score = []
200
+ for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating SMART", total=len(premise)):
201
+ out_score.append(self.smart_l(one_pre, one_hypo)[1] if self.smart_type == 'smart-l' else self.smart_n(one_pre, one_hypo)[1])
202
+
203
+ return None, torch.tensor(out_score), None
204
+
205
+ def smart_l(self, premise, hypo):
206
+ premise_sents = [each.text for each in self.spacy(premise).sents]
207
+ hypo_sents = [each.text for each in self.spacy(hypo).sents]
208
+
209
+ premise_sent_mat = []
210
+ hypo_sents_mat = []
211
+ for i in range(len(premise_sents)):
212
+ for j in range(len(hypo_sents)):
213
+ premise_sent_mat.append(premise_sents[i])
214
+ hypo_sents_mat.append(hypo_sents[j])
215
+
216
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0]
217
+ output_score = output_score.view(len(premise_sents), len(hypo_sents))
218
+
219
+ ### smart-l
220
+ lcs = [[0] * (len(hypo_sents)+1) for _ in range(len(premise_sents)+1)] # independent rows, avoiding the aliased-list bug
221
+ for i in range(len(premise_sents)+1):
222
+ for j in range(len(hypo_sents)+1):
223
+ if i != 0 and j != 0:
224
+ m = output_score[i-1, j-1]
225
+ lcs[i][j] = max([lcs[i-1][j-1]+m,
226
+ lcs[i-1][j]+m,
227
+ lcs[i][j-1]])
228
+
229
+ return None, lcs[-1][-1] / len(premise_sents), None
230
+
231
+ def smart_n(self, premise, hypo):
232
+ ### smart-n
233
+ n_gram = 1
234
+
235
+ premise_sents = [each.text for each in self.spacy(premise).sents]
236
+ hypo_sents = [each.text for each in self.spacy(hypo).sents]
237
+
238
+ premise_sent_mat = []
239
+ hypo_sents_mat = []
240
+ for i in range(len(premise_sents)):
241
+ for j in range(len(hypo_sents)):
242
+ premise_sent_mat.append(premise_sents[i])
243
+ hypo_sents_mat.append(hypo_sents[j])
244
+
245
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0]
246
+ output_score = output_score.view(len(premise_sents), len(hypo_sents))
247
+
248
+ prec = sum([max([sum([output_score[i+n, j+n]/n_gram for n in range(0, n_gram)]) for i in range(len(premise_sents)-n_gram+1)]) for j in range(len(hypo_sents)-n_gram+1)])
249
+ prec = prec / (len(hypo_sents) - n_gram + 1) if (len(hypo_sents) - n_gram + 1) > 0 else 0.
250
+
251
+
252
+ premise_sents = [each.text for each in self.spacy(hypo).sents]# simple change
253
+ hypo_sents = [each.text for each in self.spacy(premise).sents]#
254
+
255
+ premise_sent_mat = []
256
+ hypo_sents_mat = []
257
+ for i in range(len(premise_sents)):
258
+ for j in range(len(hypo_sents)):
259
+ premise_sent_mat.append(premise_sents[i])
260
+ hypo_sents_mat.append(hypo_sents[j])
261
+
262
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0]
263
+ output_score = output_score.view(len(premise_sents), len(hypo_sents))
264
+
265
+ recall = sum([max([sum([output_score[i+n, j+n]/n_gram for n in range(0, n_gram)]) for i in range(len(premise_sents)-n_gram+1)]) for j in range(len(hypo_sents)-n_gram+1)])
266
+ recall = recall / (len(hypo_sents) - n_gram + 1) if (len(hypo_sents) - n_gram + 1) > 0 else 0.
267
+
268
+ f1 = 2 * prec * recall / (prec + recall)
269
+
270
+ if self.smart_n_metric == 'f1':
271
+ return None, f1, None
272
+ elif self.smart_n_metric == 'precision':
273
+ return None, prec, None
274
+ elif self.smart_n_metric == 'recall':
275
+ return None, recall, None
276
+ else:
277
+ ValueError("SMART return type error")
278
+
279
+ def chunks(self, lst, n):
280
+ """Yield successive n-sized chunks from lst."""
281
+ for i in range(0, len(lst), n):
282
+ yield lst[i:i + n]
283
+
284
+ def nlg_eval(self, premise, hypo):
285
+ assert self.nlg_eval_mode is not None, "Select NLG Eval mode!"
286
+ if (self.nlg_eval_mode == 'bin') or (self.nlg_eval_mode == 'nli') or (self.nlg_eval_mode == 'reg'):
287
+ return self.inference(premise, hypo)
288
+
289
+ elif (self.nlg_eval_mode == 'bin_sp') or (self.nlg_eval_mode == 'nli_sp') or (self.nlg_eval_mode == 'reg_sp'):
290
+ return self.inference_example_batch(premise, hypo)
291
+
292
+ else:
293
+ ValueError("Unrecognized NLG Eval mode!")
alignscore/src/alignscore/model.py ADDED
@@ -0,0 +1,308 @@
1
+ import math
2
+ from typing import Optional, Tuple
3
+ from transformers import AdamW, get_linear_schedule_with_warmup, AutoConfig
4
+ from transformers import BertForPreTraining, BertModel, RobertaModel, AlbertModel, AlbertForMaskedLM, RobertaForMaskedLM
5
+ import torch
6
+ import torch.nn as nn
7
+ import pytorch_lightning as pl
8
+ from sklearn.metrics import f1_score
9
+ from dataclasses import dataclass
10
+
11
+
12
+
13
+ class BERTAlignModel(pl.LightningModule):
14
+ def __init__(self, model='bert-base-uncased', using_pretrained=True, *args, **kwargs) -> None:
15
+ super().__init__()
16
+ # Already defined in lightning: self.device
17
+ self.save_hyperparameters()
18
+ self.model = model
19
+
20
+ if 'muppet' in model:
21
+ assert using_pretrained, "Only pretrained muppet models are supported!"
22
+ self.base_model = RobertaModel.from_pretrained(model)
23
+ self.mlm_head = RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head
24
+
25
+ elif 'roberta' in model:
26
+ if using_pretrained:
27
+ self.base_model = RobertaModel.from_pretrained(model)
28
+ self.mlm_head = RobertaForMaskedLM.from_pretrained(model).lm_head
29
+ else:
30
+ self.base_model = RobertaModel(AutoConfig.from_pretrained(model))
31
+ self.mlm_head = RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head
32
+
33
+ elif 'albert' in model:
34
+ if using_pretrained:
35
+ self.base_model = AlbertModel.from_pretrained(model)
36
+ self.mlm_head = AlbertForMaskedLM.from_pretrained(model).predictions
37
+ else:
38
+ self.base_model = AlbertModel(AutoConfig.from_pretrained(model))
39
+ self.mlm_head = AlbertForMaskedLM(AutoConfig.from_pretrained(model)).predictions
40
+
41
+ elif 'bert' in model:
42
+ if using_pretrained:
43
+ self.base_model = BertModel.from_pretrained(model)
44
+ self.mlm_head = BertForPreTraining.from_pretrained(model).cls.predictions
45
+ else:
46
+ self.base_model = BertModel(AutoConfig.from_pretrained(model))
47
+ self.mlm_head = BertForPreTraining(AutoConfig.from_pretrained(model)).cls.predictions
48
+
49
+ elif 'electra' in model:
50
+ self.generator = BertModel(AutoConfig.from_pretrained('prajjwal1/bert-small'))
51
+ self.generator_mlm = BertForPreTraining(AutoConfig.from_pretrained('prajjwal1/bert-small')).cls.predictions
52
+
53
+ self.base_model = BertModel(AutoConfig.from_pretrained('bert-base-uncased'))
54
+ self.discriminator_predictor = ElectraDiscriminatorPredictions(self.base_model.config)
55
+
56
+
57
+ self.bin_layer = nn.Linear(self.base_model.config.hidden_size, 2)
58
+ self.tri_layer = nn.Linear(self.base_model.config.hidden_size, 3)
59
+ self.reg_layer = nn.Linear(self.base_model.config.hidden_size, 1)
60
+
61
+ self.dropout = nn.Dropout(p=0.1)
62
+
63
+ self.need_mlm = True
64
+ self.is_finetune = False
65
+ self.mlm_loss_factor = 0.5
66
+
67
+ self.softmax = nn.Softmax(dim=-1)
68
+
69
+ def forward(self, batch):
70
+ if 'electra' in self.model:
71
+ return self.electra_forward(batch)
72
+ base_model_output = self.base_model(
73
+ input_ids = batch['input_ids'],
74
+ attention_mask = batch['attention_mask'],
75
+ token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None
76
+ )
77
+
78
+ prediction_scores = self.mlm_head(base_model_output.last_hidden_state) ## sequence_output for mlm
79
+ seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output)) ## pooled output for classification
80
+ tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
81
+ reg_label_score = self.reg_layer(base_model_output.pooler_output)
82
+
83
+ total_loss = None
84
+ if 'mlm_label' in batch.keys(): ### 'mlm_label' and 'align_label' when training
85
+ ce_loss_fct = nn.CrossEntropyLoss(reduction='sum')
86
+ masked_lm_loss = ce_loss_fct(prediction_scores.view(-1, self.base_model.config.vocab_size), batch['mlm_label'].view(-1)) #/ self.con vocabulary
87
+ next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1)) / math.log(2)
88
+ tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1)) / math.log(3)
89
+ reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1), reduction='sum')
90
+
91
+ masked_lm_loss_num = torch.sum(batch['mlm_label'].view(-1) != -100)
92
+ next_sentence_loss_num = torch.sum(batch['align_label'].view(-1) != -100)
93
+ tri_label_loss_num = torch.sum(batch['tri_label'].view(-1) != -100)
94
+ reg_label_loss_num = torch.sum(batch['reg_label'].view(-1) != -100.0)
95
+
96
+ return ModelOutput(
97
+ loss=total_loss,
98
+ all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss] if 'mlm_label' in batch.keys() else None,
99
+ loss_nums=[masked_lm_loss_num, next_sentence_loss_num, tri_label_loss_num, reg_label_loss_num] if 'mlm_label' in batch.keys() else None,
100
+ prediction_logits=prediction_scores,
101
+ seq_relationship_logits=seq_relationship_score,
102
+ tri_label_logits=tri_label_score,
103
+ reg_label_logits=reg_label_score,
104
+ hidden_states=base_model_output.hidden_states,
105
+ attentions=base_model_output.attentions
106
+ )
107
+
108
+ def electra_forward(self, batch):
109
+ if 'mlm_label' in batch.keys():
110
+ ce_loss_fct = nn.CrossEntropyLoss()
111
+ generator_output = self.generator_mlm(self.generator(
112
+ input_ids = batch['input_ids'],
113
+ attention_mask = batch['attention_mask'],
114
+ token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None
115
+ ).last_hidden_state)
116
+ masked_lm_loss = ce_loss_fct(generator_output.view(-1, self.generator.config.vocab_size), batch['mlm_label'].view(-1))
117
+
118
+ hallucinated_tokens = batch['input_ids'].clone()
119
+
120
+ hallucinated_tokens[batch['mlm_label']!=-100] = torch.argmax(generator_output, dim=-1)[batch['mlm_label']!=-100]
121
+ replaced_token_label = (batch['input_ids'] == hallucinated_tokens).long()#.type(torch.LongTensor) #[batch['mlm_label'] == -100] = -100
122
+ replaced_token_label[batch['mlm_label']!=-100] = (batch['mlm_label'] == hallucinated_tokens)[batch['mlm_label']!=-100].long()
123
+ replaced_token_label[batch['input_ids'] == 0] = -100 ### ignore paddings
124
+
125
+ base_model_output = self.base_model(
126
+ input_ids = hallucinated_tokens if 'mlm_label' in batch.keys() else batch['input_ids'],
127
+ attention_mask = batch['attention_mask'],
128
+ token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None
129
+ )
130
+ hallu_detect_score = self.discriminator_predictor(base_model_output.last_hidden_state)
131
+ seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output)) ## pooled output for classification
132
+ tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
133
+ reg_label_score = self.reg_layer(base_model_output.pooler_output)
134
+
135
+ total_loss = None
136
+
137
+ if 'mlm_label' in batch.keys(): ### 'mlm_label' and 'align_label' when training
138
+ total_loss = []
139
+ ce_loss_fct = nn.CrossEntropyLoss()
140
+ hallu_detect_loss = ce_loss_fct(hallu_detect_score.view(-1,2),replaced_token_label.view(-1))
141
+ next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1))
142
+ tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1))
143
+ reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1))
144
+
145
+ total_loss.append(10.0 * hallu_detect_loss if not torch.isnan(hallu_detect_loss).item() else 0.)
146
+ total_loss.append(0.2 * masked_lm_loss if (not torch.isnan(masked_lm_loss).item() and self.need_mlm) else 0.)
147
+ total_loss.append(next_sentence_loss if not torch.isnan(next_sentence_loss).item() else 0.)
148
+ total_loss.append(tri_label_loss if not torch.isnan(tri_label_loss).item() else 0.)
149
+ total_loss.append(reg_label_loss if not torch.isnan(reg_label_loss).item() else 0.)
150
+
151
+ total_loss = sum(total_loss)
152
+
153
+ return ModelOutput(
154
+ loss=total_loss,
155
+ all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss, hallu_detect_loss] if 'mlm_label' in batch.keys() else None,
156
+ prediction_logits=hallu_detect_score,
157
+ seq_relationship_logits=seq_relationship_score,
158
+ tri_label_logits=tri_label_score,
159
+ reg_label_logits=reg_label_score,
160
+ hidden_states=base_model_output.hidden_states,
161
+ attentions=base_model_output.attentions
162
+ )
163
+
164
+ def training_step(self, train_batch, batch_idx):
165
+ output = self(train_batch)
166
+
167
+ return {'losses': output.all_loss, 'loss_nums': output.loss_nums}
168
+
169
+ def training_step_end(self, step_output):
170
+ losses = step_output['losses']
171
+ loss_nums = step_output['loss_nums']
172
+ assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses'
173
+
174
+ loss_mlm_num = torch.sum(loss_nums[0])
175
+ loss_bin_num = torch.sum(loss_nums[1])
176
+ loss_tri_num = torch.sum(loss_nums[2])
177
+ loss_reg_num = torch.sum(loss_nums[3])
178
+
179
+ loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0.
180
+ loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0.
181
+ loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0.
182
+ loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0.
183
+
184
+ total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg
185
+
186
+ self.log('train_loss', total_loss)# , sync_dist=True
187
+ self.log('mlm_loss', loss_mlm)
188
+ self.log('bin_label_loss', loss_bin)
189
+ self.log('tri_label_loss', loss_tri)
190
+ self.log('reg_label_loss', loss_reg)
191
+
192
+ return total_loss
193
+
194
+ def validation_step(self, val_batch, batch_idx):
195
+ if not self.is_finetune:
196
+ with torch.no_grad():
197
+ output = self(val_batch)
198
+
199
+ return {'losses': output.all_loss, 'loss_nums': output.loss_nums}
200
+
201
+ with torch.no_grad():
202
+ output = self(val_batch)['seq_relationship_logits']
203
+ output = self.softmax(output)[:, 1].tolist()
204
+ pred = [int(align_prob>0.5) for align_prob in output]
205
+
206
+ labels = val_batch['align_label'].tolist()
207
+
208
+ return {"pred": pred, 'labels': labels}#, "preds":preds, "labels":x['labels']}
209
+
210
+ def validation_step_end(self, step_output):
211
+ losses = step_output['losses']
212
+ loss_nums = step_output['loss_nums']
213
+ assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses'
214
+
215
+ loss_mlm_num = torch.sum(loss_nums[0])
216
+ loss_bin_num = torch.sum(loss_nums[1])
217
+ loss_tri_num = torch.sum(loss_nums[2])
218
+ loss_reg_num = torch.sum(loss_nums[3])
219
+
220
+ loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0.
221
+ loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0.
222
+ loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0.
223
+ loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0.
224
+
225
+ total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg
226
+
227
+ self.log('train_loss', total_loss)# , sync_dist=True
228
+ self.log('mlm_loss', loss_mlm)
229
+ self.log('bin_label_loss', loss_bin)
230
+ self.log('tri_label_loss', loss_tri)
231
+ self.log('reg_label_loss', loss_reg)
232
+
233
+ return total_loss
234
+
235
+ def validation_epoch_end(self, outputs):
236
+ if not self.is_finetune:
237
+ total_loss = torch.stack(outputs).mean()
238
+ self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
239
+
240
+ else:
241
+ all_predictions = []
242
+ all_labels = []
243
+ for each_output in outputs:
244
+ all_predictions.extend(each_output['pred'])
245
+ all_labels.extend(each_output['labels'])
246
+
247
+ self.log("f1", f1_score(all_labels, all_predictions), prog_bar=True, sync_dist=True)
248
+
249
+ def configure_optimizers(self):
250
+ """Prepare optimizer and schedule (linear warmup and decay)"""
251
+ no_decay = ["bias", "LayerNorm.weight"]
252
+ optimizer_grouped_parameters = [
253
+ {
254
+ "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
255
+ "weight_decay": self.hparams.weight_decay,
256
+ },
257
+ {
258
+ "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
259
+ "weight_decay": 0.0,
260
+ },
261
+ ]
262
+ optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
263
+
264
+ scheduler = get_linear_schedule_with_warmup(
265
+ optimizer,
266
+ num_warmup_steps=int(self.hparams.warmup_steps_portion * self.trainer.estimated_stepping_batches),
267
+ num_training_steps=self.trainer.estimated_stepping_batches,
268
+ )
269
+ scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
270
+ return [optimizer], [scheduler]
271
+
272
+ def mse_loss(self, input, target, ignored_index=-100.0, reduction='mean'):
273
+ mask = (target == ignored_index)
274
+ out = (input[~mask]-target[~mask])**2
275
+ if reduction == "mean":
276
+ return out.mean()
277
+ elif reduction == "sum":
278
+ return out.sum()
279
+
280
+ class ElectraDiscriminatorPredictions(nn.Module):
281
+ """Prediction module for the discriminator, made up of two dense layers."""
282
+
283
+ def __init__(self, config):
284
+ super().__init__()
285
+
286
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
287
+ self.dense_prediction = nn.Linear(config.hidden_size, 2)
288
+ self.config = config
289
+ self.gelu = nn.GELU()
290
+
291
+ def forward(self, discriminator_hidden_states):
292
+ hidden_states = self.dense(discriminator_hidden_states)
293
+ hidden_states = self.gelu(hidden_states)
294
+ logits = self.dense_prediction(hidden_states).squeeze(-1)
295
+
296
+ return logits
297
+
298
+ @dataclass
299
+ class ModelOutput():
300
+ loss: Optional[torch.FloatTensor] = None
301
+ all_loss: Optional[list] = None
302
+ loss_nums: Optional[list] = None
303
+ prediction_logits: torch.FloatTensor = None
304
+ seq_relationship_logits: torch.FloatTensor = None
305
+ tri_label_logits: torch.FloatTensor = None
306
+ reg_label_logits: torch.FloatTensor = None
307
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
308
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
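To make the `ModelOutput` fields above concrete, a small hedged sketch of a forward pass (not part of the uploaded files). `roberta-base` is used purely as an example backbone; the `bin_layer`, `tri_layer`, and `reg_layer` heads are randomly initialised here, so only the output shapes are meaningful.

```python
# Hedged sketch: inspect the three prediction heads returned by BERTAlignModel.forward.
import torch
from transformers import AutoTokenizer
from alignscore.model import BERTAlignModel

model = BERTAlignModel(model='roberta-base')   # pretrained encoder, untrained alignment heads
tokenizer = AutoTokenizer.from_pretrained('roberta-base')

batch = tokenizer(['A short premise.'], ['A short hypothesis.'],
                  truncation='only_first', padding=True, return_tensors='pt')
with torch.no_grad():
    out = model(batch)

print(out.seq_relationship_logits.shape)   # torch.Size([1, 2]) -- binary alignment head
print(out.tri_label_logits.shape)          # torch.Size([1, 3]) -- 3-way NLI head
print(out.reg_label_logits.shape)          # torch.Size([1, 1]) -- regression head
```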
alignscore/train.py ADDED
@@ -0,0 +1,144 @@
1
+ from pytorch_lightning import Trainer, seed_everything
2
+ from alignscore.dataloader import DSTDataLoader
3
+ from alignscore.model import BERTAlignModel
4
+ from pytorch_lightning.callbacks import ModelCheckpoint
5
+ from argparse import ArgumentParser
6
+ import os
7
+
8
+ def train(datasets, args):
9
+ dm = DSTDataLoader(
10
+ dataset_config=datasets,
11
+ model_name=args.model_name,
12
+ sample_mode='seq',
13
+ train_batch_size=args.batch_size,
14
+ eval_batch_size=16,
15
+ num_workers=args.num_workers,
16
+ train_eval_split=0.95,
17
+ need_mlm=args.do_mlm
18
+ )
19
+ dm.setup()
20
+
21
+ model = BERTAlignModel(model=args.model_name, using_pretrained=args.use_pretrained_model,
22
+ adam_epsilon=args.adam_epsilon,
23
+ learning_rate=args.learning_rate,
24
+ weight_decay=args.weight_decay,
25
+ warmup_steps_portion=args.warm_up_proportion
26
+ )
27
+ model.need_mlm = args.do_mlm
28
+
29
+ training_dataset_used = '_'.join(datasets.keys())
30
+ checkpoint_name = '_'.join((
31
+ f"{args.ckpt_comment}{args.model_name.replace('/', '-')}",
32
+ f"{'scratch_' if not args.use_pretrained_model else ''}{'no_mlm_' if not args.do_mlm else ''}{training_dataset_used}",
33
+ str(args.max_samples_per_dataset),
34
+ f"{args.batch_size}x{len(args.devices)}x{args.accumulate_grad_batch}"
35
+ ))
36
+
37
+ checkpoint_callback = ModelCheckpoint(
38
+ dirpath=args.ckpt_save_path,
39
+ filename=checkpoint_name + "_{epoch:02d}_{step}",
40
+ every_n_train_steps=10000,
41
+ save_top_k=1
42
+ )
43
+ trainer = Trainer(
44
+ accelerator='gpu',
45
+ max_epochs=args.num_epoch,
46
+ devices=args.devices,
47
+ strategy="dp",
48
+ precision=32,
49
+ callbacks=[checkpoint_callback],
50
+ accumulate_grad_batches=args.accumulate_grad_batch
51
+ )
52
+
53
+ trainer.fit(model, datamodule=dm)
54
+ trainer.save_checkpoint(os.path.join(args.ckpt_save_path, f"{checkpoint_name}_final.ckpt"))
55
+
56
+ print("Training is finished.")
57
+
58
+ if __name__ == "__main__":
59
+ ALL_TRAINING_DATASETS = {
60
+ ### NLI
61
+ 'mnli': {'task_type': 'nli', 'data_path': 'mnli.json'},
62
+ 'doc_nli': {'task_type': 'bin_nli', 'data_path': 'doc_nli.json'},
63
+ 'snli': {'task_type': 'nli', 'data_path': 'snli.json'},
64
+ 'anli_r1': {'task_type': 'nli', 'data_path': 'anli_r1.json'},
65
+ 'anli_r2': {'task_type': 'nli', 'data_path': 'anli_r2.json'},
66
+ 'anli_r3': {'task_type': 'nli', 'data_path': 'anli_r3.json'},
67
+
68
+ ### fact checking
69
+ 'nli_fever': {'task_type': 'fact_checking', 'data_path': 'nli_fever.json'},
70
+ 'vitaminc': {'task_type': 'fact_checking', 'data_path': 'vitaminc.json'},
71
+
72
+ ### paraphrase
73
+ 'paws': {'task_type': 'paraphrase', 'data_path': 'paws.json'},
74
+ 'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'paws_qqp.json'},
75
+ 'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'paws_unlabeled.json'},
76
+ 'qqp': {'task_type': 'paraphrase', 'data_path': 'qqp.json'},
77
+ 'wiki103': {'task_type': 'paraphrase', 'data_path': 'wiki103.json'},
78
+
79
+ ### QA
80
+ 'squad_v2': {'task_type': 'qa', 'data_path': 'squad_v2_new.json'},
81
+ 'race': {'task_type': 'qa', 'data_path': 'race.json'},
82
+ 'adversarial_qa': {'task_type': 'qa', 'data_path': 'adversarial_qa.json'},
83
+ 'drop': {'task_type': 'qa', 'data_path': 'drop.json'},
84
+ 'hotpot_qa_distractor': {'task_type': 'qa', 'data_path': 'hotpot_qa_distractor.json'},
85
+ 'hotpot_qa_fullwiki': {'task_type': 'qa', 'data_path': 'hotpot_qa_fullwiki.json'},
86
+ 'newsqa': {'task_type': 'qa', 'data_path': 'newsqa.json'},
87
+ 'quoref': {'task_type': 'qa', 'data_path': 'quoref.json'},
88
+ 'ropes': {'task_type': 'qa', 'data_path': 'ropes.json'},
89
+ 'boolq': {'task_type': 'qa', 'data_path': 'boolq.json'},
90
+ 'eraser_multi_rc': {'task_type': 'qa', 'data_path': 'eraser_multi_rc.json'},
91
+ 'quail': {'task_type': 'qa', 'data_path': 'quail.json'},
92
+ 'sciq': {'task_type': 'qa', 'data_path': 'sciq.json'},
93
+ 'strategy_qa': {'task_type': 'qa', 'data_path': 'strategy_qa.json'},
94
+
95
+ ### Coreference
96
+ 'gap': {'task_type': 'coreference', 'data_path': 'gap.json'},
97
+
98
+ ### Summarization
99
+ 'wikihow': {'task_type': 'summarization', 'data_path': 'wikihow.json'},
100
+
101
+ ### Information Retrieval
102
+ 'msmarco': {'task_type': 'ir', 'data_path': 'msmarco.json'},
103
+
104
+ ### STS
105
+ 'stsb': {'task_type': 'sts', 'data_path': 'stsb.json'},
106
+ 'sick': {'task_type': 'sts', 'data_path': 'sick.json'},
107
+ }
108
+
109
+ parser = ArgumentParser()
110
+ parser.add_argument('--seed', type=int, default=2022)
111
+ parser.add_argument('--batch-size', type=int, default=32)
112
+ parser.add_argument('--accumulate-grad-batch', type=int, default=1)
113
+ parser.add_argument('--num-epoch', type=int, default=3)
114
+ parser.add_argument('--num-workers', type=int, default=8)
115
+ parser.add_argument('--warm-up-proportion', type=float, default=0.06)
116
+ parser.add_argument('--adam-epsilon', type=float, default=1e-6)
117
+ parser.add_argument('--weight-decay', type=float, default=0.1)
118
+ parser.add_argument('--learning-rate', type=float, default=1e-5)
119
+ parser.add_argument('--val-check-interval', type=float, default=1. / 4)
120
+ parser.add_argument('--devices', nargs='+', type=int, required=True)
121
+ parser.add_argument('--model-name', type=str, default="roberta-large")
122
+ parser.add_argument('--ckpt-save-path', type=str, required=True)
123
+ parser.add_argument('--ckpt-comment', type=str, default="")
124
+ parser.add_argument('--training-datasets', nargs='+', type=str, default=list(ALL_TRAINING_DATASETS.keys()), choices=list(ALL_TRAINING_DATASETS.keys()))
125
+ parser.add_argument('--data-path', type=str, required=True)
126
+ parser.add_argument('--max-samples-per-dataset', type=int, default=500000)
127
+ parser.add_argument('--do-mlm', type=bool, default=False)
128
+ parser.add_argument('--use-pretrained-model', type=bool, default=True)
129
+
130
+ args = parser.parse_args()
131
+
132
+ seed_everything(args.seed)
133
+
134
+ datasets = {
135
+ name: {
136
+ **ALL_TRAINING_DATASETS[name],
137
+ "size": args.max_samples_per_dataset,
138
+ "data_path": os.path.join(args.data_path, ALL_TRAINING_DATASETS[name]['data_path'])
139
+ }
140
+ for name in args.training_datasets
141
+ }
142
+
143
+ train(datasets, args)
144
+
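Finally, a hedged note (not part of the uploaded files) on connecting the two scripts: `train()` writes `{checkpoint_name}_final.ckpt` into `--ckpt-save-path`, and that file is what `Inferencer` expects as `ckpt_path`; the `model` argument must match the `--model-name` the checkpoint was trained with. The path below is a placeholder.

```python
# Hedged sketch: load the final checkpoint written by train.py into the Inferencer.
from alignscore.inference import Inferencer

scorer = Inferencer(
    ckpt_path='checkpoints/roberta-large_..._final.ckpt',  # placeholder; written by train()
    model='roberta-large',                                  # must match --model-name
)
```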