Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
from transformers import AutoTokenizer, AutoModel | |
from sentence_transformers import SentenceTransformer | |
import pickle | |
import nltk | |
# Fetch the NLTK resources at import time: 'punkt' (tokenizer) and
# 'averaged_perceptron_tagger' (part-of-speech tagger).
for _nltk_resource in ('punkt', 'averaged_perceptron_tagger'):
    nltk.download(_nltk_resource)
from input_format import * | |
from score import * | |
# ---- Model setup: executed once, when the module is imported ----

# Pin everything to CPU: replacing the CUDA probe makes every later call to
# `torch.cuda.is_available()` report False, so `device` below is always 'cpu'.
# NOTE(review): `torch` is not imported explicitly in this file; presumably it
# arrives through one of the star imports above -- confirm.
torch.cuda.is_available = lambda: False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Document-level scoring model and its matching tokenizer.
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model).to(device)

# Sentence-level model used for the highlights.
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base').to(device)
def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10
):
    """Score a reviewer's papers against the submission abstract.

    Looks up the reviewer's papers from the author ID, computes a
    document-level affinity score for each one, stores the complete result
    in ``paper_info.pkl`` (read back by ``change_paper``) and returns the
    Gradio updates that reveal the paper-selection part of the UI.

    Args:
        abstract_text_input: Text of the submission abstract.
        pdf_file_input: Uploaded PDF file, or ``None``. PDF input is not
            supported yet.
        author_id_input: Reviewer ID handed to ``get_text_from_author_id``.
        num_papers_show: Maximum number of papers listed as radio choices.

    Returns:
        A 5-tuple of ``gr.update`` objects, in the order of the outputs wired
        to the search button: paper choices, submission-sentence choices and
        three "make visible" updates.

    Raises:
        ValueError: If ``pdf_file_input`` is not ``None``.
    """
    print('retrieving similar papers')

    # TODO handle pdf file input
    # Reject the unsupported input before doing any tokenizing or lookups.
    if pdf_file_input is not None:
        raise ValueError('Use submission abstract instead.')

    input_sentences = sent_tokenize(abstract_text_input)

    # Get author papers from id (the author name is not needed here).
    _, papers = get_text_from_author_id(author_id_input)

    # Compute doc-level affinity scores for the papers.
    print('computing scores')
    titles, abstracts, doc_scores = compute_document_score(
        doc_model,
        tokenizer,
        abstract_text_input,
        papers,
        batch=50
    )

    # Persist the full (untruncated) result so change_paper can look up the
    # abstract and score of whichever paper gets selected.
    paper_info = {
        'titles': titles,
        'abstracts': abstracts,
        'doc_scores': doc_scores
    }
    with open('paper_info.pkl', 'wb') as f:
        pickle.dump(paper_info, f)

    # Select the first K papers to show.
    # NOTE(review): assumes compute_document_score returns its results ordered
    # best-first -- confirm.
    shown_titles = titles[:num_papers_show]
    shown_scores = doc_scores[:num_papers_show]
    # This format must stay identical to the one change_paper rebuilds.
    display_title = [
        '[ %0.3f ] %s' % (s, t) for t, s in zip(shown_titles, shown_scores)
    ]
    print('retrieval done')

    return (
        gr.update(choices=display_title, interactive=True, visible=True),  # set of papers
        gr.update(choices=input_sentences, interactive=True),  # submission sentences
        gr.update(visible=True),  # title row
        gr.update(visible=True),  # abstract
        gr.update(visible=True)  # button
    )
def get_highlights(
    abstract_text_input,
    pdf_file_input,
    abstract,
    K=2
):
    """Compute per-sentence word highlights for the selected paper abstract.

    Calls ``get_highlight_info`` with the sentence model, builds one
    highlight entry per sentence of the submission abstract, and stores the
    result in ``highlight_info.pkl`` (read back by
    ``change_output_highlight``).

    Args:
        abstract_text_input: Text of the submission abstract.
        pdf_file_input: Unused; kept so the signature matches the inputs
            wired to the button.
        abstract: Abstract of the reviewer paper currently selected.
        K: Forwarded unchanged to ``get_highlight_info``.

    Returns:
        A ``gr.update`` that makes the sentence radio choices visible.
    """
    print('obtaining highlights')

    # Only the third returned value is used here.
    _, _, info = get_highlight_info(
        sent_model,
        abstract_text_input,
        abstract,
        K=K
    )

    input_sentences = sent_tokenize(abstract_text_input)

    # One highlight per input sentence, keyed by the sentence index as a
    # string, in the format fed to the Gradio Interpretation component.
    word_scores = {
        str(i): {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores']))
        }
        for i in range(len(input_sentences))
    }

    highlight_info = {
        'source_sentences': input_sentences,
        'highlight': word_scores
    }
    with open('highlight_info.pkl', 'wb') as f:
        pickle.dump(highlight_info, f)
    print('done')

    # update the visibility of radio choices
    return gr.update(visible=True)
def update_name(author_id_input):
    """Return a Gradio update carrying the author name for the given ID."""
    # get_text_from_author_id yields a pair; only its first item (the name)
    # is shown in the UI.
    reviewer_name, _unused = get_text_from_author_id(author_id_input)
    name_update = gr.update(value=reviewer_name)
    return name_update
def change_output_highlight(source_sent_choice):
    """Return the stored highlight for the selected submission sentence.

    Reads ``highlight_info.pkl`` from the current working directory and
    returns the highlight entry whose source sentence equals
    ``source_sent_choice`` (first match wins).

    Args:
        source_sent_choice: Sentence picked in the submission-sentence radio.

    Returns:
        The highlight entry stored for that sentence, or ``None`` when the
        file does not exist or no sentence matches.
    """
    fname = 'highlight_info.pkl'
    if not os.path.exists(fname):
        return None

    # NOTE: pickle can execute arbitrary code on load; this is only acceptable
    # while the file is produced by this application itself.
    with open(fname, 'rb') as f:
        saved = pickle.load(f)

    highlights = saved['highlight']
    for i, sentence in enumerate(saved['source_sentences']):
        if source_sent_choice == sentence:
            # Highlights are keyed by the sentence index as a string.
            return highlights[str(i)]
    return None
def change_paper(selected_papers_radio):
    """Return title, abstract and affinity score of the selected paper.

    Reads ``paper_info.pkl`` from the current working directory, rebuilds the
    display label of every stored paper and returns the first one whose label
    equals ``selected_papers_radio``.

    Args:
        selected_papers_radio: Label picked in the paper radio, in the form
            ``'[ <score with 3 decimals> ] <title>'``.

    Returns:
        A ``(title, abstract, affinity_score)`` tuple, or ``None`` when the
        file does not exist or no label matches.
    """
    fname = 'paper_info.pkl'
    if not os.path.exists(fname):
        return None

    # NOTE: pickle can execute arbitrary code on load; this is only acceptable
    # while the file is produced by this application itself.
    with open(fname, 'rb') as f:
        saved = pickle.load(f)

    rows = zip(saved['titles'], saved['abstracts'], saved['doc_scores'])
    for title, abstract, aff_score in rows:
        # Must mirror the label format produced by get_similar_paper.
        display_title = '[ %0.3f ] %s' % (aff_score, title)
        if display_title == selected_papers_radio:
            # update title, abstract, and affinity score fields
            return title, abstract, aff_score
    # NOTE(review): this handler feeds three output components; confirm the
    # Gradio version in use accepts a bare None here.
    return None
# Gradio UI: inputs at the top, then the reviewer's papers, then the
# highlighted parts of the selected paper. Event listeners are wired at the end.
with gr.Blocks() as demo:
    ### INPUT
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
            # refresh the displayed reviewer name whenever the ID changes
            author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')

    ### PAPER INFORMATION
    # show multiple papers in radio check box to select from
    with gr.Row():
        selected_papers_radio = gr.Radio(
            choices=[],  # will be updated with the button click
            visible=False,  # also will be updated with the button click
            label='Selected Top Papers from the Reviewer'
        )
    # selected paper information
    with gr.Row(visible=False) as title_row:
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity = gr.Number(label='Affinity', interactive=False, value=0)
    # This row itself stays visible: visibility is controlled on the Textbox
    # inside it, which is what the search button reveals. No listener ever
    # updates abstract_row, so hiding the row would hide the abstract for good.
    with gr.Row() as abstract_row:
        paper_abstract = gr.Textbox(label='Abstract', interactive=False, visible=False)
    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')

    ### RELEVANT PARTS (HIGHLIGHTS)
    with gr.Row():
        with gr.Column(scale=2):  # text from submission
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        with gr.Column(scale=3):  # highlighted text from paper
            highlight = gr.components.Interpretation(paper_abstract)

    ### EVENT LISTENERS
    # retrieve similar papers; the outputs line up one-to-one with the tuple
    # returned by get_similar_paper
    compute_btn.click(
        fn=get_similar_paper,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            author_id_input
        ],
        outputs=[
            selected_papers_radio,
            source_sentences,
            title_row,
            paper_abstract,
            explain_button_row,
        ]
    )
    # get highlights
    explain_btn.click(
        fn=get_highlights,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            paper_abstract
        ],
        outputs=source_sentences
    )
    # change highlight based on selected sentences from submission
    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight
    )
    # change paper to show based on selected papers
    selected_papers_radio.change(
        fn=change_paper,
        inputs=selected_papers_radio,
        outputs=[
            paper_title,
            paper_abstract,
            affinity
        ]
    )

if __name__ == "__main__":
    demo.launch()