import string import re from typing import List import torch from transformers import EncoderDecoderModel, BertTokenizerFast import gradio as gr def str2title(str): str = string.capwords(str) str = str.replace(' - - - ', ' — ') str = str.replace(' - - ', ' – ') str = str.replace('( ', '(') str = str.replace(' )', ')') str = re.sub(r'(\w)\s+-\s+(\w)', r'\1-\2', str) str = re.sub(r'(\w|")\s+:', r'\1:', str) str = re.sub(r'"\s+([^"]+)\s+"', r'"\1"', str) return str class Predictor: def __init__( self, model: EncoderDecoderModel, tokenizer: BertTokenizerFast, device: torch.device, num_titles: int, encoder_max_length: int = 512, decoder_max_length: int = 32, ) -> None: super().__init__() self.model = model self.tokenizer = tokenizer self.device = device self.num_titles = num_titles self.encoder_max_length = encoder_max_length self.decoder_max_length = decoder_max_length def __call__(self, abstract: str, temperature: float) -> List[str]: temperature = max(1.0, float(temperature)) input_token_ids = self.tokenizer(abstract, truncation=True, max_length=self.encoder_max_length, return_tensors='pt').input_ids.to(self.device) pred = self.model.generate( input_token_ids, decoder_start_token_id=self.tokenizer.cls_token_id, eos_token_id=self.tokenizer.sep_token_id, pad_token_id=self.tokenizer.pad_token_id, do_sample=(temperature > 1), num_beams=10, max_length=self.decoder_max_length, no_repeat_ngram_size=2, temperature=temperature, top_k=50, num_return_sequences=self.num_titles ) titles = [str2title(title) for title in tokenizer.batch_decode(pred, True)] return titles def create_gradio_ui(predictor): inputs = [ gr.Textbox(label="Paper Abstract", lines=10), gr.Slider(label="Creativity", minimum=1.0, maximum=2.0, step=0.1, value=1.5), ] outputs = ["text"] * predictor.num_titles description = ( "