import streamlit as st
import GPTHelper
from sentence_transformers import CrossEncoder
from pymed import PubMed
import pandas as pd
import plotly.express as px
import logging
from langdetect import detect
from typing import Dict, List
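
# GPTHelper is the project's local helper module (see the GitHub repo linked in the sidebar).
# This file only relies on the two functions used below, GPTHelper.gpt35_rephrase and
# GPTHelper.gpt35_check_fact, which wrap the GPT-3.5 calls; a sketch of their assumed
# interface is given at the bottom of this file.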

if 'valid_inputs_received' not in st.session_state:
    st.session_state['valid_inputs_received'] = False


def get_articles(query, fetcher) -> Dict[str, List[str]]:
    # Fetches articles using pymed. Increasing max_results leads to longer loading times.
    results = fetcher.query(query, max_results=50)
    conclusions = []
    titles = []
    links = []
    for article in results:
        article_id = 0  # Fallback value in case the PubMed search fails to return an id.
        try:
            article_id = article.pubmed_id[:8]  # pymed sometimes wrongly returns a long list of ids; use only the first.
            # Square brackets are replaced because '[]' can cause the cross-encoder to misinterpret the string as a list.
            title = article.title.replace('[', '(').replace(']', ')')
            conclusion = article.conclusions
            abstract = article.abstract
            article_url = f'https://pubmed.ncbi.nlm.nih.gov/{article_id}/'
            article_link = f'<a href="{article_url}" style="color: black; font-size: 16px; ' \
                           f'text-decoration: underline;">PubMed ID: {article_id}</a>'  # Injects a link into the plotly chart.
            if conclusion:
                conclusion = conclusion.replace('[', '(').replace(']', ')')
                conclusions.append(title + '\n' + conclusion)  # The title is prepended to improve relevance ranking.
                titles.append(title)
                links.append(article_link)
            elif abstract:
                # Not all articles provide conclusions; the abstract is used as a fallback.
                abstract = abstract.replace('[', '(').replace(']', ')')
                conclusions.append(title + '\n' + abstract)
                titles.append(title)
                links.append(article_link)
        except Exception as e:
            logging.warning(f'Error reading article: {article_id}: ', exc_info=e)
    return {
        'Conclusions': conclusions,
        'Links': links
    }


def load_cross_encoder():
    # The pretrained cross-encoder model used for reranking. Can be substituted with a different one.
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    return cross_encoder


def load_pubmed_fetcher():
    pubmed = PubMed(tool='PubmedFactChecker', email='[email protected]')
    return pubmed
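
# Note: an optional optimization, not part of the original flow shown here. Because Streamlit
# reruns the whole script on every interaction, the two loaders above are natural candidates for
# Streamlit's st.cache_resource decorator, so the cross-encoder weights and the PubMed client are
# created only once per session, e.g.:
#
#   @st.cache_resource
#   def load_cross_encoder():
#       return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')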


def run_ui():
    # This function controls the whole app flow.
    st.set_page_config(page_title='PUBMED FACT-CHECKER', page_icon='📖')
    sidebar = st.sidebar
    sidebar.title('ABOUT')
    sidebar.write('''
        The PubMed fact-checker app enables users to verify biomedical claims by comparing them against
        research papers available on PubMed. \n
        As the number of self-proclaimed experts continues to rise, so does the risk of harmful
        misinformation. This app showcases the potential of Large Language Models to provide accurate
        and valuable information to people.
        ''')
    sidebar.title('EXAMPLES')
    sidebar.write('Try one of the examples below to see the PubMed fact-checker in action.')
    st.title('PubMed FACT CHECKER')

    with st.form(key='fact_form'):
        fact = st.text_input('Fact:', placeholder='Enter your fact', key='form_input')
        submitted = st.form_submit_button('Fact-Check')

    if sidebar.button('Mediterranean diet helps with weight loss.', use_container_width=True):
        submitted = True
        fact = 'Mediterranean diet helps with weight loss.'
    if sidebar.button('Low Carb High Fat diet is healthy in long term.', use_container_width=True):
        submitted = True
        fact = 'Low Carb High Fat diet is healthy in long term.'
    if sidebar.button('Vaccines are a cause of autism.', use_container_width=True):
        submitted = True
        fact = 'Vaccines are a cause of autism.'

    sidebar.title('HOW IT WORKS')
    sidebar.write('Source code and an in-depth app description are available at:')
    sidebar.info('**GitHub: [@jacinthes](https://github.com/jacinthes/PubMed-fact-checker)**', icon="💻")
    sidebar.title('DISCLAIMER')
    sidebar.write('This project is meant for educational and research purposes. \n'
                  'PubMed fact-checker may provide inaccurate information.')

    if not submitted and not st.session_state.valid_inputs_received:
        st.stop()
    elif submitted and not fact:
        st.warning('Please enter your fact before fact-checking.')
        st.session_state.valid_inputs_received = False
        st.stop()
    elif submitted and detect(fact) != 'en':
        st.warning('Please enter valid text in English. For short inputs, language detection is sometimes inaccurate.'
                   ' Try making the fact more verbose.')
        st.session_state.valid_inputs_received = False
        st.stop()
    elif submitted and len(fact) >= 75:
        st.warning('To ensure accurate searching, please keep your fact under 75 characters.')
        st.session_state.valid_inputs_received = False
        st.stop()
    elif submitted and '?' in fact:
        st.warning('Please state a fact. PubMed Fact Checker is good at verifying facts; '
                   'it is not meant to answer questions.')
        st.session_state.valid_inputs_received = False
        st.stop()
    elif submitted or st.session_state.valid_inputs_received:
        pubmed_query = GPTHelper.gpt35_rephrase(fact)  # Call GPT-3.5 to rephrase the fact as a PubMed query.
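        # Illustrative only (the exact output depends on GPTHelper's prompt): a fact such as
        # 'Mediterranean diet helps with weight loss.' might be rephrased into a query like
        # 'Mediterranean diet AND weight loss' before it is sent to PubMed.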
        pubmed = load_pubmed_fetcher()
        with st.spinner('Fetching articles...'):
            articles = get_articles(pubmed_query, pubmed)
        article_conclusions = articles['Conclusions']
        article_links = articles['Links']
        if len(article_conclusions) == 0:
            # If nothing is returned by pymed, inform the user.
            st.info(
                "Unfortunately, I couldn't find anything for your search.\n"
                "Don't let that discourage you, I have over 35 million citations in my database.\n"
                "I am sure your next search will be more successful."
            )
            st.stop()
        cross_inp = [[fact, conclusion] for conclusion in article_conclusions]
        with st.spinner('Assessing article relevancy...'):
            cross_encoder = load_cross_encoder()
            cross_scores = cross_encoder.predict(cross_inp)  # Calculate relevancy using the defined cross-encoder.
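        # CrossEncoder.predict scores each (fact, article text) pair; for the ms-marco cross-encoder
        # models these scores are raw logits, where higher values mean the pair is more relevant.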
        df = pd.DataFrame({
            'Link': article_links,
            'Conclusion': article_conclusions,
            'Score': cross_scores
        })
        df.sort_values(by=['Score'], ascending=False, inplace=True)
        df = df[df['Score'] > 0]  # Only keep articles with a relevancy score above 0.
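        # The 0 cutoff is a heuristic: with raw logit scores, values above 0 roughly correspond to
        # pairs the model considers more likely relevant than not.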
        if df.shape[0] == 0:  # If no relevant article is found, inform the user.
            st.info(
                "Unfortunately, I couldn't find anything for your search.\n"
                "Don't let that discourage you, I have over 35 million citations in my database.\n"
                "I am sure your next search will be more successful."
            )
            st.stop()
        df = df.head(10)  # Keep only the 10 most relevant articles, to control OpenAI costs and load time.

        progress_text = 'Assessing the validity of the fact based on relevant research papers.'
        fact_checking_bar = st.progress(0, text=progress_text)
        step = 100 / df.shape[0]
        percent_complete = 0
        predictions = []
        for index, row in df.iterrows():
            prediction = GPTHelper.gpt35_check_fact(row['Conclusion'], fact)  # Prompt GPT-3.5 to fact-check the claim against this article.
            # For output purposes the labels True, False and Undetermined are used.
            if prediction == 'Entails':
                predictions.append('True')
            elif prediction == 'Contradicts':
                predictions.append('False')
            elif prediction == 'Undetermined':
                predictions.append(prediction)
            else:
                # GPT-3.5 returned an unexpected response; this has not happened during testing.
                predictions.append('Invalid')
                logging.warning(f'Unexpected prediction: {prediction}')
            percent_complete += step / 100
            fact_checking_bar.progress(round(percent_complete, 2), text=progress_text)
        fact_checking_bar.empty()
        df['Prediction'] = predictions
        df = df[df.Prediction != 'Invalid']  # Drop rows with an invalid prediction.

        # Prepare the DataFrame for the plotly sunburst chart.
        totals = df.groupby('Prediction').size().to_dict()
        df['Total'] = df['Prediction'].map(totals)
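        # 'Total' holds the number of articles per verdict; using it as the sunburst value column
        # gives every article leaf within a verdict the same size, so a verdict's ring segment
        # grows with the number of articles supporting it.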
        fig = px.sunburst(df, path=['Prediction', 'Link'], values='Total', height=600, width=600, color='Prediction',
                          color_discrete_map={
                              'False': "#FF8384",
                              'True': "#A5D46A",
                              'Undetermined': "#FFDF80"
                          }
                          )
        fig.update_layout(
            margin=dict(l=20, r=20, t=20, b=20),
            font_size=32,
            font_color='#000000'
        )
        st.write(f'According to PubMed "{fact}" is:')
        st.plotly_chart(fig, use_container_width=True)


if __name__ == "__main__":
    run_ui()
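
# Assumed interface of the local GPTHelper module, inferred from how it is called above. This is a
# sketch, not the actual implementation; see the GitHub repo for the real prompts and API handling.
#
#   def gpt35_rephrase(fact: str) -> str:
#       """Return a PubMed search query for the given fact."""
#
#   def gpt35_check_fact(evidence: str, fact: str) -> str:
#       """Return 'Entails', 'Contradicts' or 'Undetermined' for the fact given the article text."""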