arxiv-titlegen / app.py
Callidior's picture
Update app.py
2d6d7a5
raw
history blame
3.03 kB
import string
import re
from typing import List
import torch
from transformers import EncoderDecoderModel, BertTokenizerFast
import gradio as gr
def str2title(str):
    """Convert a space-tokenized string into a title-cased, tidied title.

    Every whitespace-separated word is capitalized via ``string.capwords``
    (which also collapses runs of whitespace), then spacing artifacts around
    punctuation are removed: spaced-out hyphens become dashes or plain
    hyphens, and padding inside parentheses, before colons and inside
    double quotes is dropped.

    Args:
        str: The raw text to clean up. (The name shadows the builtin but is
            kept unchanged for callers.)

    Returns:
        The cleaned-up title string.
    """
    title = string.capwords(str)

    # Literal substitutions. Order matters: the triple-hyphen form must be
    # handled before the double-hyphen form, which is a substring of it.
    literal_fixes = (
        (' - - - ', ' — '),
        (' - - ', ' – '),
        ('( ', '('),
        (' )', ')'),
    )
    for needle, replacement in literal_fixes:
        title = title.replace(needle, replacement)

    # Pattern substitutions, applied in sequence:
    #   1. "word - word"  -> "word-word"
    #   2. "word :"       -> "word:"   (also after a closing double quote)
    #   3. '" text "'     -> '"text"'
    pattern_fixes = (
        (r'(\w)\s+-\s+(\w)', r'\1-\2'),
        (r'(\w|")\s+:', r'\1:'),
        (r'"\s+([^"]+)\s+"', r'"\1"'),
    )
    for pattern, replacement in pattern_fixes:
        title = re.sub(pattern, replacement, title)

    return title
class Predictor:
    """Callable that generates candidate titles for a paper abstract.

    Wraps an encoder-decoder model and its tokenizer; calling the instance
    tokenizes the abstract, runs ``model.generate`` and returns the decoded
    sequences after post-processing them with ``str2title``.
    """

    def __init__(
        self,
        model: EncoderDecoderModel,
        tokenizer: BertTokenizerFast,
        device: torch.device,
        num_titles: int,
        encoder_max_length: int = 512,
        decoder_max_length: int = 32,
    ) -> None:
        """Store the model, tokenizer and generation settings.

        Args:
            model: Model whose ``generate`` method produces title token ids.
            tokenizer: Tokenizer used both to encode the abstract and to
                decode the generated ids.
            device: Device the input token ids are moved to before
                generation.
            num_titles: Number of sequences requested per call
                (``num_return_sequences``).
            encoder_max_length: Abstracts are truncated to this many tokens.
            decoder_max_length: ``max_length`` passed to ``generate``.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.num_titles = num_titles
        self.encoder_max_length = encoder_max_length
        self.decoder_max_length = decoder_max_length

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Generate titles for ``abstract``.

        Args:
            abstract: The abstract text to generate titles for.
            temperature: Sampling temperature. Values below 1.0 are clamped
                to 1.0; sampling is only enabled for values above 1.0.

        Returns:
            The decoded sequences, each passed through ``str2title``.
        """
        # Clamp to at least 1.0 so the temperature never drops below 1.
        temperature = max(1.0, float(temperature))
        encoding = self.tokenizer(
            abstract,
            truncation=True,
            max_length=self.encoder_max_length,
            return_tensors='pt',
        )
        input_token_ids = encoding.input_ids.to(self.device)
        pred = self.model.generate(
            input_token_ids,
            decoder_start_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            # At exactly 1.0 sampling is switched off.
            do_sample=(temperature > 1),
            num_beams=10,
            max_length=self.decoder_max_length,
            no_repeat_ngram_size=2,
            temperature=temperature,
            top_k=50,
            num_return_sequences=self.num_titles,
        )
        # Decode with the tokenizer owned by this instance. The previous code
        # referenced a module-level ``tokenizer`` name, which only exists when
        # the file is executed as a script and raises NameError otherwise.
        decoded = self.tokenizer.batch_decode(pred, skip_special_tokens=True)
        return [str2title(title) for title in decoded]
def create_gradio_ui(predictor):
    """Build the Gradio interface around ``predictor``.

    Args:
        predictor: Callable used as the interface function. It receives the
            abstract text and the slider value, and must expose a
            ``num_titles`` attribute that determines how many text outputs
            the interface has.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    inputs = [
        gr.Textbox(label="Paper Abstract", lines=10),
        gr.Slider(label="Creativity", minimum=1.0, maximum=2.0, step=0.1, value=1.5),
    ]
    # One plain-text output component per generated title.
    outputs = ["text"] * predictor.num_titles
    # Fixed user-facing typo: the model generates paper "titles", not "tiles".
    description = (
        "<center>"
        "Bert2Bert model trained on computer science papers from arXiv to generate "
        "paper titles from abstracts."
        "</center>"
    )
    ui = gr.Interface(
        fn=predictor,
        inputs=inputs,
        outputs=outputs,
        title="Paper Title Generator",
        description=description,
    )
    return ui
if __name__ == '__main__':
    # Script entry point: load the pretrained model, then serve the UI.
    print('Loading model...')
    model_path = "Callidior/bert2bert-base-arxiv-titlegen"

    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    model = EncoderDecoderModel.from_pretrained(model_path).to(device)
    print(f'Ready - running on {device}.')

    # Five candidate titles are generated per request.
    predictor = Predictor(model, tokenizer, device=device, num_titles=5)
    interface = create_gradio_ui(predictor)
    interface.launch()