Create model.py
model.py
ADDED
@@ -0,0 +1,63 @@
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

def process_nli(premise: str, hypothesis: str):
    """Process to the required XNLI format with the task prefix."""
    return "".join(['xnli: premise: ', premise, ' hypothesis: ', hypothesis])

def setModel(model_name):
    tokenizer = MT5Tokenizer.from_pretrained(model_name)
    model = MT5ForConditionalGeneration.from_pretrained(model_name)
    model.eval()
    return model, tokenizer

def runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_template):
    ENTAILS_LABEL = "▁0"
    NEUTRAL_LABEL = "▁1"
    CONTRADICTS_LABEL = "▁2"

    # load the model and tokenizer before either is used
    model, tokenizer = setModel(model_name)
    label_inds = tokenizer.convert_tokens_to_ids([ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])

    # construct sequence of premise, hypothesis pairs
    pairs = [(sequence_to_classify, hypothesis_template.format(label)) for label in candidate_labels]
    # format for the mt5 xnli task
    seqs = [process_nli(premise=premise, hypothesis=hypothesis) for premise, hypothesis in pairs]

    inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
    out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, num_beams=1)

    # sanity check that our sequences are the expected length (label token + start token + end token = 3)
    for seq in out.sequences:
        assert len(seq) == 3

    # get the scores for our only token of interest;
    # we'll now treat these like the output logits of a `*ForSequenceClassification` model
    scores = out.scores[0]

    # scores has the size of the model's vocab, but for this task we have a fixed set of labels;
    # sanity check that these labels are always the top 3 scoring
    for sequence_scores in scores:
        top_scores = sequence_scores.argsort()[-3:]
        assert set(top_scores.tolist()) == set(label_inds)

    # cut the scores down to our task labels
    scores = scores[:, label_inds]

    # new indices of entailment and contradiction in scores
    entailment_ind = 0
    contradiction_ind = 2

    # we can show, per item, the entailment vs contradiction probabilities
    entail_vs_contra_scores = scores[:, [entailment_ind, contradiction_ind]]
    entail_vs_contra_probas = softmax(entail_vs_contra_scores, dim=1)

    # or we can show probabilities similar to `ZeroShotClassificationPipeline`:
    # a softmax over the entailment scores gives a zero-shot classification style output across labels
    entail_scores = scores[:, entailment_ind]
    entail_probas = softmax(entail_scores, dim=0)

    dd = dict(zip(candidate_labels, entail_probas.tolist()))
    ddd = dict(sorted(dd.items(), key=lambda x: x[1], reverse=True))
    return ddd
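For reference, a minimal sketch of how this file might be called. The checkpoint name, input text, labels, and template below are illustrative placeholders and not part of this commit; any mT5 checkpoint fine-tuned on XNLI with the "▁0"/"▁1"/"▁2" label tokens is assumed to work the same way.

# hypothetical usage of model.py; checkpoint name and inputs are assumptions for illustration
from model import runModel

model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"  # assumed XNLI-style mT5 checkpoint
sequence_to_classify = "Who are you voting for in 2020?"
candidate_labels = ["politics", "public health", "economics"]
hypothesis_template = "This example is {}."

probas = runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_template)
print(probas)  # candidate labels mapped to entailment probabilities, sorted highest first

Note that runModel reloads the checkpoint on every call; for repeated queries it may be worth calling setModel once and reusing the returned model and tokenizer.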