akdeniz27 commited on
Commit
3fb5cbe
·
1 Parent(s): c8f4a50

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +63 -0
model.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.functional import softmax
2
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer
3
+
4
def process_nli(premise: str, hypothesis: str):
    """Format a (premise, hypothesis) pair as the xnli-task input string
    expected by the mT5 XNLI checkpoint (task prefix included)."""
    return f"xnli: premise: {premise} hypothesis: {hypothesis}"
7
+
8
def setModel(model_name):
    """Load the MT5 tokenizer and conditional-generation model for
    `model_name` and put the model into evaluation mode.

    Returns a `(model, tokenizer)` pair.
    """
    tok = MT5Tokenizer.from_pretrained(model_name)
    net = MT5ForConditionalGeneration.from_pretrained(model_name)
    net.eval()  # inference only — disable dropout etc.
    return net, tok
13
+
14
def runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_template):
    """Zero-shot classify `sequence_to_classify` across `candidate_labels`
    using an mT5 model fine-tuned on XNLI.

    Parameters
    ----------
    model_name : str
        Model identifier passed to `setModel` (a Hugging Face repo id).
    sequence_to_classify : str
        The premise text to classify.
    candidate_labels : iterable of str
        Labels substituted into `hypothesis_template`.
    hypothesis_template : str
        Format string with one `{}` slot, e.g. "This text is about {}."

    Returns
    -------
    dict
        Mapping label -> entailment probability, sorted descending by probability.
    """
    # The XNLI-tuned mT5 emits its verdict as a single sentencepiece token.
    ENTAILS_LABEL = "▁0"
    NEUTRAL_LABEL = "▁1"
    CONTRADICTS_LABEL = "▁2"

    # BUG FIX: load model/tokenizer BEFORE using the tokenizer. The original
    # called `tokenizer.convert_tokens_to_ids(...)` prior to the
    # `model, tokenizer = setModel(model_name)` assignment -> NameError.
    model, tokenizer = setModel(model_name)

    label_inds = tokenizer.convert_tokens_to_ids([ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])

    # construct sequence of premise, hypothesis pairs
    pairs = [(sequence_to_classify, hypothesis_template.format(label)) for label in candidate_labels]
    # format for mt5 xnli task
    seqs = [process_nli(premise=premise, hypothesis=hypothesis) for premise, hypothesis in pairs]

    inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
    out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, num_beams=1)

    # sanity check: each generated sequence should be start token + label + end token
    for seq in out.sequences:
        assert len(seq) == 3

    # scores for the single token of interest — treat these like the output
    # logits of a `*ForSequenceClassification` model
    scores = out.scores[0]

    # `scores` is vocab-sized; sanity check that the three task labels are
    # always the top-3 scoring ids for every item in the batch
    for sequence_scores in scores:
        top_scores = sequence_scores.argsort()[-3:]
        assert set(top_scores.tolist()) == set(label_inds)

    # cut the vocab-sized scores down to just the three task labels
    scores = scores[:, label_inds]

    # index of entailment within the reduced score columns
    entailment_ind = 0

    # `ZeroShotClassificationPipeline`-style output: softmax of the
    # entailment score across all candidate labels
    entail_scores = scores[:, entailment_ind]
    entail_probas = softmax(entail_scores, dim=0)

    # label -> probability, sorted by descending probability
    results = dict(zip(candidate_labels, entail_probas.tolist()))
    return dict(sorted(results.items(), key=lambda kv: kv[1], reverse=True))