anakin87 committed
Commit a147158 • 1 Parent(s): 55e565f

little code improvements

Files changed:
- Rock_fact_checker.py +1 -4
- app_utils/backend_utils.py +4 -6
- app_utils/frontend_utils.py +17 -8
- pages/Info.py +3 -2

Rock_fact_checker.py
CHANGED
@@ -12,15 +12,13 @@ from app_utils.frontend_utils import (
     entailment_html_messages,
     create_df_for_relevant_snippets,
     create_ternary_plot,
-    build_sidebar
+    build_sidebar,
 )
 from app_utils.config import RETRIEVER_TOP_K


 def main():
-
     statements = load_statements()
-
     build_sidebar()

     # Persistent state
@@ -120,7 +118,6 @@ def main():
         st.markdown(f"###### Most Relevant snippets:")
         df, urls = create_df_for_relevant_snippets(docs)
         st.dataframe(df)
-
         str_wiki_pages = "Wikipedia source pages: "
         for doc, url in urls.items():
             str_wiki_pages += f"[{doc}]({url}) "

app_utils/backend_utils.py
CHANGED
@@ -31,7 +31,7 @@ def load_statements():
 )
 def start_haystack():
     """
-    load document store, retriever,
+    load document store, retriever, entailment checker and create pipeline
     """
     shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")
     document_store = FAISSDocumentStore(
@@ -39,13 +39,11 @@ def start_haystack():
         faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json",
     )
     print(f"Index size: {document_store.get_document_count()}")
-
     retriever = EmbeddingRetriever(
         document_store=document_store,
         embedding_model=RETRIEVER_MODEL,
         model_format=RETRIEVER_MODEL_FORMAT,
     )
-
     entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)

     pipe = Pipeline()
@@ -84,8 +82,8 @@ def query(statement: str, retriever_top_k: int = 5):
             break

     results["agg_entailment_info"] = {
-        "contradiction":
-        "neutral":
-        "entailment":
+        "contradiction": round(agg_con / scores, 2),
+        "neutral": round(agg_neu / scores, 2),
+        "entailment": round(agg_ent / scores, 2),
     }
     return results
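
For context, the rewritten dict reports relevance-weighted averages of the per-document entailment scores, rounded to two decimals. A minimal sketch of how agg_con, agg_neu, agg_ent and scores are presumably accumulated earlier in query() (these surrounding lines are not part of the diff, so the variable handling here is an assumption):

# Hypothetical sketch, not the committed code: accumulate entailment
# probabilities weighted by retrieval score, then normalize as in the diff.
agg_con = agg_neu = agg_ent = scores = 0.0
for doc in results["documents"]:
    entailment_info = doc.meta["entailment_info"]  # assumed location of the NLI output
    scores += doc.score
    agg_con += entailment_info["contradiction"] * doc.score
    agg_neu += entailment_info["neutral"] * doc.score
    agg_ent += entailment_info["entailment"] * doc.score

results["agg_entailment_info"] = {
    "contradiction": round(agg_con / scores, 2),
    "neutral": round(agg_neu / scores, 2),
    "entailment": round(agg_ent / scores, 2),
}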

app_utils/frontend_utils.py
CHANGED
@@ -9,8 +9,9 @@ entailment_html_messages = {
     "neutral": 'The knowledge base is <span style="color:darkgray">neutral</span> about your statement',
 }

+
 def build_sidebar():
-    sidebar="""
+    sidebar = """
     <h1 style='text-align: center'>Fact Checking 🎸 Rocks!</h1>
     <div style='text-align: center'>
     <i>Fact checking baseline combining dense retrieval and textual entailment</i>
@@ -20,6 +21,7 @@ def build_sidebar():
     """
     st.sidebar.markdown(sidebar, unsafe_allow_html=True)

+
 def set_state_if_absent(key, value):
     if key not in st.session_state:
         st.session_state[key] = value
@@ -33,6 +35,9 @@ def reset_results(*args):


 def create_ternary_plot(entailment_data):
+    """
+    Create a Plotly ternary plot for the given entailment dict.
+    """
     hover_text = ""
     for label, value in entailment_data.items():
         hover_text += f"{label}: {value}<br>"
@@ -83,14 +88,11 @@ def makeAxis(title, tickangle):
     }


-def highlight_cols(s):
-    coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
-    if s.name in coldict.keys():
-        return ["background-color: {}".format(coldict[s.name])] * len(s)
-    return [""] * len(s)
-
-
 def create_df_for_relevant_snippets(docs):
+    """
+    Create a dataframe that contains all relevant snippets.
+    Also returns the URLs
+    """
     rows = []
     urls = {}
     for doc in docs:
@@ -106,3 +108,10 @@ def create_df_for_relevant_snippets(docs):
         rows.append(row)
     df = pd.DataFrame(rows).style.apply(highlight_cols)
     return df, urls
+
+
+def highlight_cols(s):
+    coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
+    if s.name in coldict.keys():
+        return ["background-color: {}".format(coldict[s.name])] * len(s)
+    return [""] * len(s)
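
A quick note on the relocated highlight_cols: pandas Styler.apply with the default axis=0 passes each column to the function as a Series, so columns named "con", "neu" and "ent" get a background color while every other column stays unstyled. A minimal usage sketch with made-up rows (not taken from the repo):

import pandas as pd

rows = [
    {"Title": "Elvis Presley", "con": 0.02, "neu": 0.08, "ent": 0.90},
    {"Title": "Freddie Mercury", "con": 0.70, "neu": 0.25, "ent": 0.05},
]
# highlight_cols is called once per column; only con/neu/ent are colored.
styled = pd.DataFrame(rows).style.apply(highlight_cols)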

pages/Info.py
CHANGED
@@ -1,9 +1,10 @@
 import streamlit as st
+
 from app_utils.frontend_utils import build_sidebar

 build_sidebar()

-with open(
-    readme = fin.read().rpartition(
+with open("README.md", "r") as fin:
+    readme = fin.read().rpartition("---")[-1]

 st.markdown(readme, unsafe_allow_html=True)
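
The new file-reading lines presumably keep only the body of README.md: rpartition("---")[-1] returns the text after the last "---", which drops the Hugging Face YAML front matter before the page is rendered. A small illustration with assumed README content (not from the repo):

text = "---\ntitle: Fact Checking rocks!\nemoji: 🎸\n---\n# Fact Checking 🎸 Rocks!\n..."
readme = text.rpartition("---")[-1]  # keeps only the text after the last "---"
# readme == "\n# Fact Checking 🎸 Rocks!\n..."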