import os import string import re import json import requests from typing import List, Optional import torch from transformers import EncoderDecoderModel, BertTokenizerFast import gradio as gr def str2title(str): str = string.capwords(str) str = str.replace(' - - - ', ' — ') str = str.replace(' - - ', ' – ') str = str.replace('( ', '(') str = str.replace(' )', ')') str = re.sub(r'(\w)\s+-\s+(\w)', r'\1-\2', str) str = re.sub(r'(\w|")\s+:', r'\1:', str) str = re.sub(r'"\s+([^"]+)\s+"', r'"\1"', str) return str class Predictor: def __init__( self, model: EncoderDecoderModel, tokenizer: BertTokenizerFast, device: torch.device, num_titles: int, encoder_max_length: int = 512, decoder_max_length: int = 32, ) -> None: super().__init__() self.model = model self.tokenizer = tokenizer self.device = device self.num_titles = num_titles self.encoder_max_length = encoder_max_length self.decoder_max_length = decoder_max_length def __call__(self, abstract: str, temperature: float) -> List[str]: temperature = max(1.0, float(temperature)) input_token_ids = self.tokenizer(abstract, truncation=True, max_length=self.encoder_max_length, return_tensors='pt').input_ids.to(self.device) pred = self.model.generate( input_token_ids, decoder_start_token_id=self.tokenizer.cls_token_id, eos_token_id=self.tokenizer.sep_token_id, pad_token_id=self.tokenizer.pad_token_id, do_sample=(temperature > 1), num_beams=10, max_length=self.decoder_max_length, no_repeat_ngram_size=2, temperature=temperature, top_k=50, num_return_sequences=self.num_titles ) titles = [str2title(title) for title in tokenizer.batch_decode(pred, True)] return titles class HostedInference: def __init__(self, model: str, num_titles: int, api_key: Optional[str] = None) -> None: super().__init__() self.model = model self.num_titles = num_titles self.api_key = api_key def __call__(self, abstract: str, temperature: float) -> List[str]: temperature = max(1.0, float(temperature)) data = json.dumps({ 'inputs' : abstract, 'parameters' : { 'do_sample': (temperature > 1), 'num_beams': 10, 'temperature': temperature, 'top_k': 50, 'no_repeat_ngram_size': 2, 'num_return_sequences': self.num_titles, }, 'options' : { 'use_cache' : False, 'wait_for_model' : True } }) api_url = "https://api-inference.huggingface.co/models/" + self.model headers = { "Authorization": f"Bearer {self.api_key}" } if self.api_key is not None else {} response = requests.request("POST", api_url, headers=headers, data=data) response = json.loads(response.content.decode("utf-8")) if isinstance(response, dict) and ('error' in response): raise RuntimeError(response['error']) titles = [str2title(title['summary_text']) for title in response] return titles def create_gradio_ui(predictor): inputs = [ gr.Textbox(label="Paper Abstract", lines=10), gr.Slider(label="Creativity", minimum=1.0, maximum=2.5, step=0.1, value=1.5), ] outputs = ["text"] * predictor.num_titles description = ( "