Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -131,36 +131,31 @@ model = torch.load("./model.pt", map_location=torch.device('cpu'))
|
|
131 |
|
132 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
133 |
|
134 |
-
text
|
135 |
|
136 |
-
encoding = tokenizer(text, return_tensors="pt")
|
137 |
-
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}
|
138 |
|
139 |
-
outputs = trainer.model(**encoding)
|
140 |
|
141 |
-
logits = outputs.logits
|
142 |
-
logits.shape
|
143 |
-
|
144 |
-
|
145 |
-
# apply sigmoid + threshold
|
146 |
-
sigmoid = torch.nn.Sigmoid()
|
147 |
-
probs = sigmoid(logits.squeeze().cpu())
|
148 |
-
predictions = np.zeros(probs.shape)
|
149 |
-
predictions[np.where(probs >= 0.5)] = 1
|
150 |
-
# turn predicted id's into actual label names
|
151 |
-
predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
|
152 |
-
console.log("a")
|
153 |
-
console.log(predicted_labels)
|
154 |
-
console.log("a")
|
155 |
|
156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
|
159 |
|
160 |
-
inp = [gr.Textbox(label='
|
161 |
out = gr.Textbox(label='Output')
|
162 |
text_button = gr.Button("Flip")
|
163 |
-
text_button.click(
|
164 |
|
165 |
interface = gr.Interface.load(input=inp,output=out,
|
166 |
title = title,
|
|
|
131 |
|
132 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
133 |
|
134 |
+
def predict(text):
|
135 |
|
136 |
+
encoding = tokenizer(text, return_tensors="pt")
|
137 |
+
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}
|
138 |
|
139 |
+
outputs = trainer.model(**encoding)
|
140 |
|
141 |
+
logits = outputs.logits
|
142 |
+
logits.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
|
145 |
+
# apply sigmoid + threshold
|
146 |
+
sigmoid = torch.nn.Sigmoid()
|
147 |
+
probs = sigmoid(logits.squeeze().cpu())
|
148 |
+
predictions = np.zeros(probs.shape)
|
149 |
+
predictions[np.where(probs >= 0.5)] = 1
|
150 |
+
# turn predicted id's into actual label names
|
151 |
+
return([id2label[idx] for idx, label in enumerate(predictions) if label == 1.0])
|
152 |
|
153 |
|
154 |
|
155 |
+
inp = [gr.Textbox(label='Text or tweet text', placeholder="Insert text")]
|
156 |
out = gr.Textbox(label='Output')
|
157 |
text_button = gr.Button("Flip")
|
158 |
+
text_button.click(predict, inputs=inp, outputs=out)
|
159 |
|
160 |
interface = gr.Interface.load(input=inp,output=out,
|
161 |
title = title,
|