PANH commited on
Commit
75681b7
1 Parent(s): 1f76ea6

Create modelalign

Browse files
Files changed (1) hide show
  1. modelalign +308 -0
modelalign ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+ from transformers import AdamW, get_linear_schedule_with_warmup, AutoConfig
4
+ from transformers import BertForPreTraining, BertModel, RobertaModel, AlbertModel, AlbertForMaskedLM, RobertaForMaskedLM
5
+ import torch
6
+ import torch.nn as nn
7
+ import pytorch_lightning as pl
8
+ from sklearn.metrics import f1_score
9
+ from dataclasses import dataclass
10
+
11
+
12
+
13
+ class BERTAlignModel(pl.LightningModule):
14
+ def __init__(self, model='bert-base-uncased', using_pretrained=True, *args, **kwargs) -> None:
15
+ super().__init__()
16
+ # Already defined in lightning: self.device
17
+ self.save_hyperparameters()
18
+ self.model = model
19
+
20
+ if 'muppet' in model:
21
+ assert using_pretrained == True, "Only support pretrained muppet!"
22
+ self.base_model = RobertaModel.from_pretrained(model)
23
+ self.mlm_head = RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head
24
+
25
+ elif 'roberta' in model:
26
+ if using_pretrained:
27
+ self.base_model = RobertaModel.from_pretrained(model)
28
+ self.mlm_head = RobertaForMaskedLM.from_pretrained(model).lm_head
29
+ else:
30
+ self.base_model = RobertaModel(AutoConfig.from_pretrained(model))
31
+ self.mlm_head = RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head
32
+
33
+ elif 'albert' in model:
34
+ if using_pretrained:
35
+ self.base_model = AlbertModel.from_pretrained(model)
36
+ self.mlm_head = AlbertForMaskedLM.from_pretrained(model).predictions
37
+ else:
38
+ self.base_model = AlbertModel(AutoConfig.from_pretrained(model))
39
+ self.mlm_head = AlbertForMaskedLM(AutoConfig.from_pretrained(model)).predictions
40
+
41
+ elif 'bert' in model:
42
+ if using_pretrained:
43
+ self.base_model = BertModel.from_pretrained(model)
44
+ self.mlm_head = BertForPreTraining.from_pretrained(model).cls.predictions
45
+ else:
46
+ self.base_model = BertModel(AutoConfig.from_pretrained(model))
47
+ self.mlm_head = BertForPreTraining(AutoConfig.from_pretrained(model)).cls.predictions
48
+
49
+ elif 'electra' in model:
50
+ self.generator = BertModel(AutoConfig.from_pretrained('prajjwal1/bert-small'))
51
+ self.generator_mlm = BertForPreTraining(AutoConfig.from_pretrained('prajjwal1/bert-small')).cls.predictions
52
+
53
+ self.base_model = BertModel(AutoConfig.from_pretrained('bert-base-uncased'))
54
+ self.discriminator_predictor = ElectraDiscriminatorPredictions(self.base_model.config)
55
+
56
+
57
+ self.bin_layer = nn.Linear(self.base_model.config.hidden_size, 2)
58
+ self.tri_layer = nn.Linear(self.base_model.config.hidden_size, 3)
59
+ self.reg_layer = nn.Linear(self.base_model.config.hidden_size, 1)
60
+
61
+ self.dropout = nn.Dropout(p=0.1)
62
+
63
+ self.need_mlm = True
64
+ self.is_finetune = False
65
+ self.mlm_loss_factor = 0.5
66
+
67
+ self.softmax = nn.Softmax(dim=-1)
68
+
69
+ def forward(self, batch):
70
+ if 'electra' in self.model:
71
+ return self.electra_forward(batch)
72
+ base_model_output = self.base_model(
73
+ input_ids = batch['input_ids'],
74
+ attention_mask = batch['attention_mask'],
75
+ token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None
76
+ )
77
+
78
+ prediction_scores = self.mlm_head(base_model_output.last_hidden_state) ## sequence_output for mlm
79
+ seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output)) ## pooled output for classification
80
+ tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
81
+ reg_label_score = self.reg_layer(base_model_output.pooler_output)
82
+
83
+ total_loss = None
84
+ if 'mlm_label' in batch.keys(): ### 'mlm_label' and 'align_label' when training
85
+ ce_loss_fct = nn.CrossEntropyLoss(reduction='sum')
86
+ masked_lm_loss = ce_loss_fct(prediction_scores.view(-1, self.base_model.config.vocab_size), batch['mlm_label'].view(-1)) #/ self.con vocabulary
87
+ next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1)) / math.log(2)
88
+ tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1)) / math.log(3)
89
+ reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1), reduction='sum')
90
+
91
+ masked_lm_loss_num = torch.sum(batch['mlm_label'].view(-1) != -100)
92
+ next_sentence_loss_num = torch.sum(batch['align_label'].view(-1) != -100)
93
+ tri_label_loss_num = torch.sum(batch['tri_label'].view(-1) != -100)
94
+ reg_label_loss_num = torch.sum(batch['reg_label'].view(-1) != -100.0)
95
+
96
+ return ModelOutput(
97
+ loss=total_loss,
98
+ all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss] if 'mlm_label' in batch.keys() else None,
99
+ loss_nums=[masked_lm_loss_num, next_sentence_loss_num, tri_label_loss_num, reg_label_loss_num] if 'mlm_label' in batch.keys() else None,
100
+ prediction_logits=prediction_scores,
101
+ seq_relationship_logits=seq_relationship_score,
102
+ tri_label_logits=tri_label_score,
103
+ reg_label_logits=reg_label_score,
104
+ hidden_states=base_model_output.hidden_states,
105
+ attentions=base_model_output.attentions
106
+ )
107
+
108
+ def electra_forward(self, batch):
109
+ if 'mlm_label' in batch.keys():
110
+ ce_loss_fct = nn.CrossEntropyLoss()
111
+ generator_output = self.generator_mlm(self.generator(
112
+ input_ids = batch['input_ids'],
113
+ attention_mask = batch['attention_mask'],
114
+ token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None
115
+ ).last_hidden_state)
116
+ masked_lm_loss = ce_loss_fct(generator_output.view(-1, self.generator.config.vocab_size), batch['mlm_label'].view(-1))
117
+
118
+ hallucinated_tokens = batch['input_ids'].clone()
119
+
120
+ hallucinated_tokens[batch['mlm_label']!=-100] = torch.argmax(generator_output, dim=-1)[batch['mlm_label']!=-100]
121
+ replaced_token_label = (batch['input_ids'] == hallucinated_tokens).long()#.type(torch.LongTensor) #[batch['mlm_label'] == -100] = -100
122
+ replaced_token_label[batch['mlm_label']!=-100] = (batch['mlm_label'] == hallucinated_tokens)[batch['mlm_label']!=-100].long()
123
+ replaced_token_label[batch['input_ids'] == 0] = -100 ### ignore paddings
124
+
125
+ base_model_output = self.base_model(
126
+ input_ids = hallucinated_tokens if 'mlm_label' in batch.keys() else batch['input_ids'],
127
+ attention_mask = batch['attention_mask'],
128
+ token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None
129
+ )
130
+ hallu_detect_score = self.discriminator_predictor(base_model_output.last_hidden_state)
131
+ seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output)) ## pooled output for classification
132
+ tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
133
+ reg_label_score = self.reg_layer(base_model_output.pooler_output)
134
+
135
+ total_loss = None
136
+
137
+ if 'mlm_label' in batch.keys(): ### 'mlm_label' and 'align_label' when training
138
+ total_loss = []
139
+ ce_loss_fct = nn.CrossEntropyLoss()
140
+ hallu_detect_loss = ce_loss_fct(hallu_detect_score.view(-1,2),replaced_token_label.view(-1))
141
+ next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1))
142
+ tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1))
143
+ reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1))
144
+
145
+ total_loss.append(10.0 * hallu_detect_loss if not torch.isnan(hallu_detect_loss).item() else 0.)
146
+ total_loss.append(0.2 * masked_lm_loss if (not torch.isnan(masked_lm_loss).item() and self.need_mlm) else 0.)
147
+ total_loss.append(next_sentence_loss if not torch.isnan(next_sentence_loss).item() else 0.)
148
+ total_loss.append(tri_label_loss if not torch.isnan(tri_label_loss).item() else 0.)
149
+ total_loss.append(reg_label_loss if not torch.isnan(reg_label_loss).item() else 0.)
150
+
151
+ total_loss = sum(total_loss)
152
+
153
+ return ModelOutput(
154
+ loss=total_loss,
155
+ all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss, hallu_detect_loss] if 'mlm_label' in batch.keys() else None,
156
+ prediction_logits=hallu_detect_score,
157
+ seq_relationship_logits=seq_relationship_score,
158
+ tri_label_logits=tri_label_score,
159
+ reg_label_logits=reg_label_score,
160
+ hidden_states=base_model_output.hidden_states,
161
+ attentions=base_model_output.attentions
162
+ )
163
+
164
+ def training_step(self, train_batch, batch_idx):
165
+ output = self(train_batch)
166
+
167
+ return {'losses': output.all_loss, 'loss_nums': output.loss_nums}
168
+
169
+ def training_step_end(self, step_output):
170
+ losses = step_output['losses']
171
+ loss_nums = step_output['loss_nums']
172
+ assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses'
173
+
174
+ loss_mlm_num = torch.sum(loss_nums[0])
175
+ loss_bin_num = torch.sum(loss_nums[1])
176
+ loss_tri_num = torch.sum(loss_nums[2])
177
+ loss_reg_num = torch.sum(loss_nums[3])
178
+
179
+ loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0.
180
+ loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0.
181
+ loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0.
182
+ loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0.
183
+
184
+ total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg
185
+
186
+ self.log('train_loss', total_loss)# , sync_dist=True
187
+ self.log('mlm_loss', loss_mlm)
188
+ self.log('bin_label_loss', loss_bin)
189
+ self.log('tri_label_loss', loss_tri)
190
+ self.log('reg_label_loss', loss_reg)
191
+
192
+ return total_loss
193
+
194
+ def validation_step(self, val_batch, batch_idx):
195
+ if not self.is_finetune:
196
+ with torch.no_grad():
197
+ output = self(val_batch)
198
+
199
+ return {'losses': output.all_loss, 'loss_nums': output.loss_nums}
200
+
201
+ with torch.no_grad():
202
+ output = self(val_batch)['seq_relationship_logits']
203
+ output = self.softmax(output)[:, 1].tolist()
204
+ pred = [int(align_prob>0.5) for align_prob in output]
205
+
206
+ labels = val_batch['align_label'].tolist()
207
+
208
+ return {"pred": pred, 'labels': labels}#, "preds":preds, "labels":x['labels']}
209
+
210
+ def validation_step_end(self, step_output):
211
+ losses = step_output['losses']
212
+ loss_nums = step_output['loss_nums']
213
+ assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses'
214
+
215
+ loss_mlm_num = torch.sum(loss_nums[0])
216
+ loss_bin_num = torch.sum(loss_nums[1])
217
+ loss_tri_num = torch.sum(loss_nums[2])
218
+ loss_reg_num = torch.sum(loss_nums[3])
219
+
220
+ loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0.
221
+ loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0.
222
+ loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0.
223
+ loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0.
224
+
225
+ total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg
226
+
227
+ self.log('train_loss', total_loss)# , sync_dist=True
228
+ self.log('mlm_loss', loss_mlm)
229
+ self.log('bin_label_loss', loss_bin)
230
+ self.log('tri_label_loss', loss_tri)
231
+ self.log('reg_label_loss', loss_reg)
232
+
233
+ return total_loss
234
+
235
+ def validation_epoch_end(self, outputs):
236
+ if not self.is_finetune:
237
+ total_loss = torch.stack(outputs).mean()
238
+ self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
239
+
240
+ else:
241
+ all_predictions = []
242
+ all_labels = []
243
+ for each_output in outputs:
244
+ all_predictions.extend(each_output['pred'])
245
+ all_labels.extend(each_output['labels'])
246
+
247
+ self.log("f1", f1_score(all_labels, all_predictions), prog_bar=True, sync_dist=True)
248
+
249
+ def configure_optimizers(self):
250
+ """Prepare optimizer and schedule (linear warmup and decay)"""
251
+ no_decay = ["bias", "LayerNorm.weight"]
252
+ optimizer_grouped_parameters = [
253
+ {
254
+ "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
255
+ "weight_decay": self.hparams.weight_decay,
256
+ },
257
+ {
258
+ "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
259
+ "weight_decay": 0.0,
260
+ },
261
+ ]
262
+ optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
263
+
264
+ scheduler = get_linear_schedule_with_warmup(
265
+ optimizer,
266
+ num_warmup_steps=int(self.hparams.warmup_steps_portion * self.trainer.estimated_stepping_batches),
267
+ num_training_steps=self.trainer.estimated_stepping_batches,
268
+ )
269
+ scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
270
+ return [optimizer], [scheduler]
271
+
272
+ def mse_loss(self, input, target, ignored_index=-100.0, reduction='mean'):
273
+ mask = (target == ignored_index)
274
+ out = (input[~mask]-target[~mask])**2
275
+ if reduction == "mean":
276
+ return out.mean()
277
+ elif reduction == "sum":
278
+ return out.sum()
279
+
280
+ class ElectraDiscriminatorPredictions(nn.Module):
281
+ """Prediction module for the discriminator, made up of two dense layers."""
282
+
283
+ def __init__(self, config):
284
+ super().__init__()
285
+
286
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
287
+ self.dense_prediction = nn.Linear(config.hidden_size, 2)
288
+ self.config = config
289
+ self.gelu = nn.GELU()
290
+
291
+ def forward(self, discriminator_hidden_states):
292
+ hidden_states = self.dense(discriminator_hidden_states)
293
+ hidden_states = self.gelu(hidden_states)
294
+ logits = self.dense_prediction(hidden_states).squeeze(-1)
295
+
296
+ return logits
297
+
298
+ @dataclass
299
+ class ModelOutput():
300
+ loss: Optional[torch.FloatTensor] = None
301
+ all_loss: Optional[list] = None
302
+ loss_nums: Optional[list] = None
303
+ prediction_logits: torch.FloatTensor = None
304
+ seq_relationship_logits: torch.FloatTensor = None
305
+ tri_label_logits: torch.FloatTensor = None
306
+ reg_label_logits: torch.FloatTensor = None
307
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
308
+ attentions: Optional[Tuple[torch.FloatTensor]] = None