arxiv-titlegen / app.py
Callidior's picture
Update app.py
2d6d7a5
raw
history blame
No virus
3.03 kB
import string
import re
from typing import List
import torch
from transformers import EncoderDecoderModel, BertTokenizerFast
import gradio as gr
def str2title(str):
    """Turn a decoded, space-separated title string into display form.

    Every word is capitalized with ``string.capwords``. This also collapses
    whitespace runs to single spaces and strips the ends. Spacing artifacts
    around punctuation are then removed:

    - spaced hyphen runs become em/en dashes;
    - spaces inside parentheses, around word hyphens, before colons and
      inside double quotes are dropped.

    NOTE(review): the artifacts presumably come from the tokenizer's decode
    step spacing out punctuation; confirm against the model output.
    """
    title = string.capwords(str)

    # Literal replacements, applied in order: the three-hyphen run must be
    # handled before the two-hyphen run, or it would be split wrongly.
    literal_fixes = (
        (' - - - ', ' — '),
        (' - - ', ' – '),
        ('( ', '('),
        (' )', ')'),
    )
    for spaced, joined in literal_fixes:
        title = title.replace(spaced, joined)

    # Regex clean-ups, also applied in order.
    pattern_fixes = (
        (r'(\w)\s+-\s+(\w)', r'\1-\2'),  # "Deep - Learning" -> "Deep-Learning"
        (r'(\w|")\s+:', r'\1:'),         # "Bert : A" -> "Bert: A"
        (r'"\s+([^"]+)\s+"', r'"\1"'),   # '" Word "' -> '"Word"'
    )
    for pattern, replacement in pattern_fixes:
        title = re.sub(pattern, replacement, title)

    return title
class Predictor:
    """Generate candidate paper titles from an abstract with a seq2seq model.

    Instances are callable, so one can be passed directly as the ``fn`` of a
    Gradio interface.

    Args:
        model: Encoder-decoder model used for generation.
        tokenizer: Tokenizer matching ``model``. Its CLS, SEP and PAD token
            ids are used as the decoder start, end and padding tokens.
        device: Target of ``.to(...)`` for the encoded input ids. The script
            in this file passes the strings ``'cuda'`` or ``'cpu'``.
        num_titles: Number of titles returned per call.
        encoder_max_length: Abstracts are truncated to this many tokens.
        decoder_max_length: Maximum length of a generated title, in tokens.
        num_beams: Beam width for generation. Values smaller than
            ``num_titles`` are raised to ``num_titles``, because beam search
            cannot return more sequences than it has beams.
    """

    def __init__(
        self,
        model: EncoderDecoderModel,
        tokenizer: BertTokenizerFast,
        device: torch.device,
        num_titles: int,
        encoder_max_length: int = 512,
        decoder_max_length: int = 32,
        num_beams: int = 10,
    ) -> None:
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.num_titles = num_titles
        self.encoder_max_length = encoder_max_length
        self.decoder_max_length = decoder_max_length
        self.num_beams = num_beams

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Return ``num_titles`` title suggestions for ``abstract``.

        Args:
            abstract: Paper abstract text. It is truncated to
                ``encoder_max_length`` tokens.
            temperature: Sampling temperature ("Creativity" in the UI).
                Values below 1.0 are clamped to 1.0. At exactly 1.0,
                sampling is off and plain beam search is used. Above 1.0,
                sampling is enabled (with ``top_k=50``).

        Returns:
            A list of ``num_titles`` strings post-processed by ``str2title``.
        """
        # The UI slider starts at 1.0; clamp anyway for direct callers.
        temperature = max(1.0, float(temperature))
        input_token_ids = self.tokenizer(
            abstract,
            truncation=True,
            max_length=self.encoder_max_length,
            return_tensors='pt',
        ).input_ids.to(self.device)
        pred = self.model.generate(
            input_token_ids,
            # Decoding starts at the tokenizer's CLS token and stops at SEP.
            decoder_start_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=(temperature > 1),
            # Beam search rejects num_return_sequences > num_beams, so widen
            # the beam when more titles than beams are requested.
            num_beams=max(self.num_beams, self.num_titles),
            max_length=self.decoder_max_length,
            no_repeat_ngram_size=2,
            temperature=temperature,
            top_k=50,
            num_return_sequences=self.num_titles,
        )
        # Decode with this instance's tokenizer. The module-level `tokenizer`
        # global only exists when the file is run as a script.
        decoded = self.tokenizer.batch_decode(pred, skip_special_tokens=True)
        return [str2title(title) for title in decoded]
def create_gradio_ui(predictor):
    """Build (but do not launch) the Gradio interface around ``predictor``.

    Args:
        predictor: Callable invoked as ``predictor(abstract, creativity)``
            with the values of the two input widgets. It must expose
            ``num_titles``, which sets how many text output boxes are
            created, and should return that many strings.

    Returns:
        The configured ``gr.Interface``.
    """
    inputs = [
        gr.Textbox(label="Paper Abstract", lines=10),
        # Forwarded to the predictor as its second argument (temperature).
        gr.Slider(label="Creativity", minimum=1.0, maximum=2.0, step=0.1, value=1.5),
    ]
    # One text output per generated title.
    outputs = ["text"] * predictor.num_titles
    description = (
        "<center>"
        "Bert2Bert model trained on computer science papers from arXiv to generate "
        "paper titles from abstracts."
        "</center>"
    )
    ui = gr.Interface(
        fn=predictor,
        inputs=inputs,
        outputs=outputs,
        title="Paper Title Generator",
        description=description,
    )
    return ui
if __name__ == '__main__':
    model_path = "Callidior/bert2bert-base-arxiv-titlegen"

    # Load the tokenizer and model for `model_path`, placing the model on
    # the GPU when torch reports one is available.
    print('Loading model...')
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    model = EncoderDecoderModel.from_pretrained(model_path)
    model = model.to(device)
    print(f'Ready - running on {device}.')

    # Wrap the model in a callable predictor and serve it through Gradio.
    predictor = Predictor(model, tokenizer, device=device, num_titles=5)
    interface = create_gradio_ui(predictor)
    interface.launch()