anakin87 committed
Commit d6bdb02
1 Parent(s): a027256

class entailment_checker

Browse files
- README.md +1 -1
- app.py → Rock_fact_checker.py +0 -0
- app_utils/entailment_checker.py +66 -0
- pages/{app.py → Info.py} +0 -0
- pages/info.py +0 -3
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
 colorTo: blue
 sdk: streamlit
 sdk_version: 1.10.0
-app_file:
+app_file: rock_fact_checker.py
 pinned: false
 license: apache-2.0
 ---
app.py → Rock_fact_checker.py
RENAMED
File without changes
app_utils/entailment_checker.py
ADDED
@@ -0,0 +1,66 @@
from typing import List, Optional

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

from haystack.nodes.base import BaseComponent
from haystack.modeling.utils import initialize_device_settings
from haystack.schema import Document, Answer, Span


class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between every document content and the query.
    It enriches the documents' metadata with entailment_info.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for a full list of available models.
        :param model_version: The version of the model to use from the Hugging Face model hub. Can be a tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as the model).
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time (currently stored but not used for batching).
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.model.to(str(self.devices[0]))

        # The model config maps class ids to the NLI labels (e.g. contradiction, neutral, entailment)
        id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError("The model config must contain an entailment value in the id2label dict.")

    def run(self, query: str, documents: List[Document]):
        # Each document content is the premise; the query is the hypothesis to verify
        for doc in documents:
            entailment_dict = self.get_entailment(premise=doc.content, hypothesis=query)
            doc.meta["entailment_info"] = entailment_dict
        return {"documents": documents}, "output_1"

    def run_batch(self):
        pass

    def get_entailment(self, premise, hypothesis):
        with torch.no_grad():
            # Encode the premise/hypothesis pair separated by the tokenizer's sep token
            inputs = self.tokenizer(
                f"{premise}{self.tokenizer.sep_token}{hypothesis}", return_tensors="pt"
            ).to(self.devices[0])
            out = self.model(**inputs)
            logits = out.logits
            # Softmax over the logits yields one probability per NLI label
            probs = torch.nn.functional.softmax(logits, dim=-1)[0, :].cpu().detach().numpy()
            entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
        return entailment_dict
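For reference, a minimal sketch of how the new node could be exercised on its own, outside a full Haystack pipeline. The query and document texts below are made up for illustration, and the import path simply mirrors the file location added in this commit:

from haystack.schema import Document

from app_utils.entailment_checker import EntailmentChecker

checker = EntailmentChecker(model_name_or_path="roberta-large-mnli", use_gpu=False)

# Hypothetical evidence documents and statement to check (illustrative only)
docs = [
    Document(content="Freddie Mercury was the lead singer of Queen."),
    Document(content="Freddie Mercury was born in Zanzibar in 1946."),
]
output, edge = checker.run(query="Freddie Mercury was Queen's lead singer", documents=docs)

for doc in output["documents"]:
    # entailment_info maps each NLI label (entailment / neutral / contradiction) to its probability
    print(doc.content, doc.meta["entailment_info"])

Downstream code can then read doc.meta["entailment_info"] to decide how strongly each retrieved document supports or contradicts the statement.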
pages/{app.py → Info.py}
RENAMED
File without changes
pages/info.py
DELETED
@@ -1,3 +0,0 @@
-import streamlit as st
-
-st.title("Test")