# paper-matching / app.py
# jskim · commit 963bf46 ("update") · 7.86 kB
# (header copied from the hosting page; "raw", "history", "blame" were page links)
import gradio as gr
import os
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import pickle
import nltk
nltk.download('punkt') # tokenizer
nltk.download('averaged_perceptron_tagger') # postagger
from input_format import *
from score import *
# load document scoring model
# Force CPU execution: replace torch.cuda.is_available with a function that
# always returns False. The patch is on the torch.cuda module itself, so any
# code that looks up torch.cuda.is_available at call time after this point
# also sees False.
# NOTE(review): `torch` is not imported explicitly in this file; it presumably
# arrives via one of the `from ... import *` lines above — confirm.
torch.cuda.is_available = lambda : False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # always 'cpu' given the patch above
# 'allenai/specter' tokenizer + model: passed to compute_document_score in
# get_similar_paper for document-level affinity scores
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model)
doc_model.to(device)
# load sentence model
# Passed to get_highlight_info in get_highlights for the sentence- and
# phrase-level affinity scores
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
sent_model.to(device)
def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10
):
    """Rank a reviewer's papers by affinity to the submission abstract.

    Fetches the papers of the reviewer identified by ``author_id_input``,
    scores each one against ``abstract_text_input`` with the document model,
    caches every title/abstract/score in ``paper_info.pkl`` (read back by
    ``change_paper``) and returns Gradio updates for the result widgets.

    Args:
        abstract_text_input: The submission abstract text.
        pdf_file_input: Uploaded submission PDF; not supported yet.
        author_id_input: The reviewer's Semantic Scholar author id.
        num_papers_show: Maximum number of papers to list.

    Returns:
        A 5-tuple of ``gr.update`` objects: the paper choices, the submission
        sentence choices, and ``visible=True`` for the title row, the
        abstract and the explain button.

    Raises:
        ValueError: If a PDF file is supplied (PDF input is not handled yet).
    """
    print('retrieving similar papers')
    input_sentences = sent_tokenize(abstract_text_input)

    # TODO handle pdf file input
    if pdf_file_input is not None:
        raise ValueError('Use submission abstract instead.')

    # Get author papers from id (the author name is not needed here)
    _, papers = get_text_from_author_id(author_id_input)

    # Compute Doc-level affinity scores for the Papers
    print('computing scores')
    titles, abstracts, doc_scores = compute_document_score(
        doc_model,
        tokenizer,
        abstract_text_input,
        papers,
        batch=50
    )

    # Cache the full result so change_paper can look up the selected paper.
    # NOTE(review): this is a single fixed file in the working directory, so
    # concurrent sessions overwrite each other's results.
    with open('paper_info.pkl', 'wb') as f:
        pickle.dump(
            {
                'titles': titles,
                'abstracts': abstracts,
                'doc_scores': doc_scores
            },
            f
        )

    # Select top K choices of papers to show.
    # NOTE(review): assumes compute_document_score returns papers sorted by
    # descending score — confirm.
    titles = titles[:num_papers_show]
    doc_scores = doc_scores[:num_papers_show]
    # This format must match the one change_paper rebuilds for lookup.
    display_title = ['[ %0.3f ] %s'%(s, t) for t, s in zip(titles, doc_scores)]
    print('retrieval done')
    return (
        gr.update(choices=display_title, interactive=True, visible=True), # set of papers
        gr.update(choices=input_sentences, interactive=True), # submission sentences
        gr.update(visible=True), # title row
        gr.update(visible=True), # abstract
        gr.update(visible=True) # button
    )
def get_highlights(
    abstract_text_input,
    pdf_file_input,
    abstract,
    K=2
):
    """Compute per-sentence word highlights of a paper abstract.

    For every sentence of the submission abstract, pairs each word of
    ``abstract`` with its affinity score, then caches the result in
    ``highlight_info.pkl`` (read back by ``change_output_highlight``).

    Args:
        abstract_text_input: The submission abstract text.
        pdf_file_input: Unused; kept so the event's input list is unchanged.
        abstract: The abstract of the currently selected reviewer paper.
        K: Forwarded to ``get_highlight_info`` (presumably the number of top
            matches to keep — confirm against score.py).

    Returns:
        ``gr.update(visible=True)``, which reveals the submission-sentence radio.
    """
    print('obtaining highlights')
    # Compute sent-level and phrase-level affinity scores for each paper
    _sent_ids, _sent_scores, info = get_highlight_info(
        sent_model,
        abstract_text_input,
        abstract,
        K=K
    )
    input_sentences = sent_tokenize(abstract_text_input)
    # One highlight per submission sentence, keyed by the sentence index as a
    # string, in the format fed to the Gradio Interpretation component.
    # NOTE(review): assumes info has an integer entry for every sentence index
    # produced by sent_tokenize here — confirm against get_highlight_info.
    word_scores = {
        str(i): {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores']))
        }
        for i in range(len(input_sentences))
    }
    with open('highlight_info.pkl', 'wb') as f:
        pickle.dump(
            {
                'source_sentences': input_sentences,
                'highlight': word_scores
            },
            f
        )
    print('done')
    # update the visibility of radio choices
    return gr.update(visible=True)
def update_name(author_id_input):
    """Show the reviewer name that belongs to the entered author id.

    The papers returned alongside the name are discarded; only the name is
    pushed into the confirmation textbox.
    """
    author_name, _unused_papers = get_text_from_author_id(author_id_input)
    name_update = gr.update(value=author_name)
    return name_update
def change_output_highlight(source_sent_choice):
    """Return the cached highlight for the selected submission sentence.

    Reads ``highlight_info.pkl`` (written by ``get_highlights``) and looks up
    the first cached sentence equal to ``source_sent_choice``.

    Returns:
        The stored Interpretation payload for that sentence, or ``None`` when
        the cache file does not exist or no sentence matches.
    """
    fname = 'highlight_info.pkl'
    if not os.path.exists(fname):
        return None
    # pickle is acceptable only because this app writes the file itself;
    # never point this at externally supplied data.
    with open(fname, 'rb') as f:
        tmp = pickle.load(f)
    source_sents = tmp['source_sentences']
    highlights = tmp['highlight']
    for i, sent in enumerate(source_sents):
        if source_sent_choice == sent:
            return highlights[str(i)]
    return None
def change_paper(selected_papers_radio):
# change the paper to show based on the paper selected
fname = 'paper_info.pkl'
if os.path.exists(fname):
tmp = pickle.load(open(fname, 'rb'))
for title, abstract, aff_score in zip(tmp['titles'], tmp['abstracts'], tmp['doc_scores']):
display_title = '[ %0.3f ] %s'%(aff_score, title)
if display_title == selected_papers_radio:
#print('changing paper')
return title, abstract, aff_score # update title, abstract, and affinity score fields
else:
return
# Build the Gradio UI: inputs, the reviewer's top papers, the selected
# paper's details, and the sentence-level highlights. Result sections start
# hidden and are revealed by the event handlers.
with gr.Blocks() as demo:
    ### INPUT
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
            # refresh the reviewer name whenever the id text changes
            author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')

    ### PAPER INFORMATION
    # show multiple papers in radio check box to select from
    with gr.Row():
        selected_papers_radio = gr.Radio(
            choices=[], # will be updated with the button click
            visible=False, # also will be updated with the button click
            label='Selected Top Papers from the Reviewer'
        )
    # selected paper information (rows hidden until a search has run)
    with gr.Row(visible=False) as title_row:
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity = gr.Number(label='Affinity', interactive=False, value=0)
    # Hidden like title_row and revealed by the search button's 4th output;
    # previously misspelled `visibe=False`, so the row was never hidden.
    with gr.Row(visible=False) as abstract_row:
        paper_abstract = gr.Textbox(label='Abstract', interactive=False)
    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')

    ### RELEVANT PARTS (HIGHLIGHTS)
    with gr.Row():
        with gr.Column(scale=2): # text from submission
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        with gr.Column(scale=3): # highlighted text from paper
            highlight = gr.components.Interpretation(paper_abstract)

    ### EVENT LISTENERS
    # retrieve similar papers
    compute_btn.click(
        fn=get_similar_paper,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            author_id_input
        ],
        outputs=[
            selected_papers_radio,
            source_sentences,
            title_row,
            abstract_row,
            explain_button_row,
        ]
    )
    # get highlights
    explain_btn.click(
        fn=get_highlights,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            paper_abstract
        ],
        outputs=source_sentences
    )
    # change highlight based on selected sentences from submission
    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight
    )
    # change paper to show based on selected papers
    selected_papers_radio.change(
        fn=change_paper,
        inputs=selected_papers_radio,
        outputs=[
            paper_title,
            paper_abstract,
            affinity
        ]
    )

if __name__ == "__main__":
    demo.launch()