# Source: Hugging Face Space app.py, commit 5f5d98b ("Update app.py") by titipata.
import gradio as gr
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
def get_matches(query, db_name="miread_contrastive", k=60):
    """
    Run a scored similarity search against one of the loaded FAISS indexes.

    Parameters
    ----------
    query : str
        Text to search for.
    db_name : str
        Name of the index to search; must be an entry of the module-level
        ``index_names`` list (``vecdbs`` is kept in the same order).
    k : int
        Number of nearest matches to request (defaults to 60).

    Returns
    -------
    list
        ``(document, score)`` pairs as returned by the vector store's
        ``similarity_search_with_score``.

    Raises
    ------
    ValueError
        If ``db_name`` is not one of ``index_names``.
    """
    try:
        position = index_names.index(db_name)
    except ValueError:
        # Same exception type as before, but with a message that names the
        # valid choices instead of "'x' is not in list".
        raise ValueError(
            f"Unknown index {db_name!r}; expected one of {index_names}"
        ) from None
    return vecdbs[position].similarity_search_with_score(query, k=k)
def inference(query, model="miread_contrastive"):
    """
    Search an index and build the abstracts, journals and authors tables.

    Processes the matches returned by ``get_matches()`` and turns them into
    Gradio update commands for the three tabular outputs.

    Parameters
    ----------
    query : str
        Text entered by the user (``title[SEP]abstract`` or ``abstract``).
    model : str
        Name of the index to search; passed to ``get_matches`` as ``db_name``.

    Returns
    -------
    list
        Three Gradio updates, in the order (abstracts, journals, authors),
        each setting the table's value and making it visible. An empty search
        result yields three empty tables.
    """
    matches = get_matches(query, model)

    auth_counts = {}     # author name -> rows already added to author_table
    journal_bucket = {}  # journal name -> summed normalised score
    author_table = []    # rows: [rank, score, author, title, link, date]
    abstract_table = []  # rows: [rank, title, author, submitter, journal, date, link, score]

    # Normalise raw search scores relative to this result set: the smallest
    # raw score maps to 1.0 and larger ones are reduced by their distance from
    # the minimum divided by the maximum. float() accepts both Python floats
    # and NumPy scalars.
    raw_scores = [round(float(score), 3) for _, score in matches]
    min_score = min(raw_scores, default=0.0)
    max_score = max(raw_scores, default=0.0)

    def normaliser(x):
        # A zero maximum would divide by zero; treat every match as equal.
        if max_score == 0:
            return 1.0
        return round(1 - (x - min_score) / max_score, 3)

    for rank, ((doc, _), raw_score) in enumerate(zip(matches, raw_scores), start=1):
        norm_score = normaliser(raw_score)
        metadata = doc.metadata

        title = metadata['title']
        author = metadata['authors'][0].title()
        date = metadata.get('date', 'None')
        link = metadata.get('link', 'None')
        submitter = metadata.get('submitter', 'None')
        # Missing, empty or whitespace-only journal names become 'None'.
        journal = (metadata.get('journal') or '').strip() or 'None'

        # Accumulate per-journal scores; matches without a journal are skipped
        # so they never appear in the journal table.
        if journal != 'None':
            journal_bucket[journal] = journal_bucket.get(journal, 0) + norm_score

        # Keep at most two rows per author in the author table.
        if auth_counts.get(author, 0) < 2:
            author_table.append([rank, norm_score, author, title, link, date])
            auth_counts[author] = auth_counts.get(author, 0) + 1

        abstract_table.append(
            [rank, title, author, submitter, journal, date, link, norm_score]
        )

    # Journals ranked by total score, highest first; totals rounded for
    # display like the per-match scores.
    ranked_journals = sorted(
        journal_bucket.items(), key=lambda item: item[1], reverse=True
    )
    journal_table = [
        [position, name, round(total, 3)]
        for position, (name, total) in enumerate(ranked_journals, start=1)
    ]

    # gr.update works on Gradio 3.x and 4.x; the component-specific
    # gr.Dataframe.update helper no longer exists in Gradio 4.
    return [
        gr.update(value=abstract_table, visible=True),
        gr.update(value=journal_table, visible=True),
        gr.update(value=author_table, visible=True),
    ]
# Each FAISS index paired with the Hugging Face model used to embed queries
# for it. The public lists below are derived from this table, so
# index_names[i], model_names[i], faiss_embedders[i] and vecdbs[i] all refer
# to the same index.
_INDEX_MODEL_PAIRS = (
    ("miread_large", "biodatlab/MIReAD-Neuro-Large"),
    ("miread_contrastive", "biodatlab/MIReAD-Neuro-Contrastive"),
    ("scibert_contrastive", "biodatlab/SciBERT-Neuro-Contrastive"),
)
index_names = [pair[0] for pair in _INDEX_MODEL_PAIRS]
model_names = [pair[1] for pair in _INDEX_MODEL_PAIRS]

# Run the embedding models on CPU and leave their output vectors unnormalised.
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}

faiss_embedders = [
    HuggingFaceEmbeddings(
        model_name=hf_model,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    for hf_model in model_names
]

# Load each saved index from the local folder named after it, pairing it with
# its model's embedder.
# NOTE(review): load_local unpickles the stored docstore; newer LangChain
# releases require allow_dangerous_deserialization=True here. Only load
# trusted index folders.
vecdbs = [
    FAISS.load_local(folder, embedder)
    for folder, embedder in zip(index_names, faiss_embedders)
]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine for Editors")
    # Adjacent string literals are concatenated, so every sentence carries its
    # own trailing space (the old backslash continuations dropped them).
    gr.Markdown(
        "NBDT Recommendation Engine for Editors is a tool for neuroscience "
        "author, abstract, and journal recommendation built for NBDT journal editors. "
        "It aims to help an editor find similar reviewers, abstracts, and journals "
        "for a given submitted abstract. "
        "To find a recommendation, paste a `title[SEP]abstract` or `abstract` in the "
        "text box below and click on the appropriate \"Find Matches\" button. "
        "Then, open the Authors, Abstracts, or Journals tab to see the suggested list. "
        "The data in our current demo includes authors associated with the NBDT Journal. "
        "We will update the data monthly to keep the publications up to date."
    )

    abst = gr.Textbox(label="Abstract", lines=10)
    # One button per index; the click handlers below fix which index each
    # button searches, following the order of index_names.
    action_btn1 = gr.Button(value="Find Matches with MIReAD-Neuro-Large")
    action_btn2 = gr.Button(value="Find Matches with MIReAD-Neuro-Contrastive")
    action_btn3 = gr.Button(value="Find Matches with SciBERT-Neuro-Contrastive")

    # Result tables start hidden and are revealed by the updates that
    # inference() returns.
    with gr.Tab("Authors"):
        n_output = gr.Dataframe(
            headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
            datatype=['number', 'number', 'str', 'str', 'str', 'str'],
            col_count=(6, "fixed"),
            wrap=True,
            visible=False,
        )
    with gr.Tab("Abstracts"):
        a_output = gr.Dataframe(
            headers=['No.', 'Title', 'Author', 'Corresponding Author',
                     'Journal', 'Date', 'Link', 'Score'],
            datatype=['number', 'str', 'str', 'str',
                      'str', 'str', 'str', 'number'],
            col_count=(8, "fixed"),
            wrap=True,
            visible=False,
        )
    with gr.Tab("Journals"):
        j_output = gr.Dataframe(
            headers=['No.', 'Name', 'Score'],
            datatype=['number', 'str', 'number'],
            col_count=(3, "fixed"),
            wrap=True,
            visible=False,
        )

    # inference() returns (abstracts, journals, authors), so the outputs are
    # listed in that order. Each endpoint needs a unique api_name; these are
    # the suffixed names Gradio assigns (with a warning) when "neurojane" is
    # repeated, so existing API clients keep working.
    action_btn1.click(
        fn=lambda x: inference(x, index_names[0]),
        inputs=[abst],
        outputs=[a_output, j_output, n_output],
        api_name="neurojane",
    )
    action_btn2.click(
        fn=lambda x: inference(x, index_names[1]),
        inputs=[abst],
        outputs=[a_output, j_output, n_output],
        api_name="neurojane_1",
    )
    action_btn3.click(
        fn=lambda x: inference(x, index_names[2]),
        inputs=[abst],
        outputs=[a_output, j_output, n_output],
        api_name="neurojane_2",
    )

demo.launch(debug=True)