Hong committed on
Commit 1931fdb · 1 parent: f5a1b52

Upload BART_utils.py

Files changed (1):
  1. BART_utils.py +52 -0
BART_utils.py ADDED
@@ -0,0 +1,52 @@
+ # Zero-shot tagging utilities built on the facebook/bart-large-mnli NLI model.
+
+ import numpy as np
+ from load_data import *
+ import streamlit as st
+ import pickle
+ import torch
+
+ import requests
+ import json
+
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ from transformers import AutoTokenizer
+ from transformers import AutoModelForSequenceClassification
+
+ tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
+ # Load the MNLI-finetuned BART model and move it to the GPU when available.
+ nli_model = AutoModelForSequenceClassification.from_pretrained(
+     "facebook/bart-large-mnli"
+ ).to(device)
+
+
+ def get_prob(sequence, label):
+     """Return the probability that `label` applies to `sequence`."""
+     premise = sequence
+     hypothesis = f"This example is {label}."
+
+     # run the premise/hypothesis pair through the model pre-trained on MNLI
+     x = tokenizer.encode(
+         premise, hypothesis, return_tensors="pt", truncation="only_first"
+     )
+     logits = nli_model(x.to(device))[0]
+
+     # we throw away "neutral" (dim 1) and take the probability of
+     # "entailment" (2) as the probability of the label being true
+     entail_contradiction_logits = logits[:, [0, 2]]
+     probs = entail_contradiction_logits.softmax(dim=1)
+     prob_label_is_true = probs[:, 1]
+     return prob_label_is_true[0].item()
+
+
+ def get_taggs(sequence, labels, thred=0.5):
+     """Return (label, probability) pairs scoring at least `thred`, best first."""
+     out = []
+     for l in labels:
+         temp = get_prob(sequence, l)
+         if temp >= thred:
+             out.append((l, temp))
+     out = sorted(out, key=lambda x: x[1], reverse=True)
+     return out
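
For reference, a minimal sketch of how these helpers might be called, presumably from the Streamlit app that imports this module; the sample text and candidate labels below are invented for illustration:

    # Hypothetical usage: the text and labels are made up, not part of the commit.
    candidate_labels = ["politics", "sports", "technology"]
    text = "The new accelerator doubles inference throughput for transformer models."
    tags = get_taggs(text, candidate_labels, thred=0.5)
    # `tags` is a list of (label, probability) pairs sorted best-first,
    # e.g. [("technology", 0.97)]; labels scoring below `thred` are dropped.
    print(tags)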