Spaces:
Runtime error
Runtime error
File size: 4,578 Bytes
fae5c57 2d6d7a5 c9518d3 2d6d7a5 0aa55e8 2d6d7a5 c9518d3 2d6d7a5 c9518d3 2d6d7a5 c9518d3 fae5c57 2d6d7a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import string
import re
import json
import requests
from typing import List, Optional
import torch
from transformers import EncoderDecoderModel, BertTokenizerFast
import gradio as gr
def str2title(str):
    """Turn a decoded, space-separated token string into a display title.

    The input is capitalised word by word with ``string.capwords`` (which
    also collapses runs of whitespace to single spaces), and then the
    spacing around punctuation is tightened:

    * ``' - - - '`` becomes an em dash and ``' - - '`` an en dash,
    * spaces just inside parentheses are removed,
    * a spaced hyphen between two word characters is joined (``a - b`` ->
      ``a-b``), including chains such as ``state - of - the - art``,
    * whitespace before a colon that follows a word character or a double
      quote is removed,
    * whitespace just inside a pair of double quotes is removed.

    Args:
        str: The raw title text. The name shadows the builtin ``str``; it is
            kept unchanged so existing keyword callers continue to work.

    Returns:
        The cleaned-up, capitalised title.
    """
    # Rebind to a local name right away so the builtin is not shadowed for
    # longer than necessary.
    title = string.capwords(str)
    title = title.replace(' - - - ', ' — ')
    title = title.replace(' - - ', ' – ')
    title = title.replace('( ', '(')
    title = title.replace(' )', ')')
    # Lookarounds instead of capturing groups: the neighbouring word
    # characters are not consumed, so every hyphen in a chain like
    # "a - b - c" is matched rather than only every other one.
    title = re.sub(r'(?<=\w)\s+-\s+(?=\w)', '-', title)
    title = re.sub(r'(\w|")\s+:', r'\1:', title)
    title = re.sub(r'"\s+([^"]+)\s+"', r'"\1"', title)
    return title
class Predictor:
    """Generate paper titles from an abstract with a local encoder-decoder model.

    Instances are callable with ``(abstract, temperature)`` and return a list
    of title strings.
    """

    def __init__(
        self,
        model: EncoderDecoderModel,
        tokenizer: BertTokenizerFast,
        device: torch.device,
        num_titles: int,
        encoder_max_length: int = 512,
        decoder_max_length: int = 32,
    ) -> None:
        """Store the model, tokenizer and generation settings.

        Args:
            model: Encoder-decoder model used for generation.
            tokenizer: Tokenizer used both to encode the abstract and to
                decode the generated token ids.
            device: Device the input token ids are moved to before generation.
            num_titles: Number of sequences requested from ``generate``.
            encoder_max_length: Abstracts are truncated to this many tokens.
            decoder_max_length: ``max_length`` passed to ``generate``.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.num_titles = num_titles
        self.encoder_max_length = encoder_max_length
        self.decoder_max_length = decoder_max_length

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Generate titles for ``abstract``.

        Args:
            abstract: Text of the paper abstract.
            temperature: Sampling temperature. Values below 1.0 are clamped
                to 1.0; sampling is enabled only when the value exceeds 1.

        Returns:
            The decoded sequences, each post-processed with ``str2title``.
        """
        temperature = max(1.0, float(temperature))
        input_token_ids = self.tokenizer(
            abstract,
            truncation=True,
            max_length=self.encoder_max_length,
            return_tensors='pt',
        ).input_ids.to(self.device)
        pred = self.model.generate(
            input_token_ids,
            decoder_start_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=(temperature > 1),
            num_beams=10,
            max_length=self.decoder_max_length,
            no_repeat_ngram_size=2,
            temperature=temperature,
            top_k=50,
            num_return_sequences=self.num_titles,
        )
        # Decode with the tokenizer owned by this instance. Relying on a
        # module-level ``tokenizer`` only works when the script happens to
        # define one and raises NameError otherwise.
        decoded = self.tokenizer.batch_decode(pred, skip_special_tokens=True)
        return [str2title(title) for title in decoded]
class HostedInference:
    """Generate paper titles through the Hugging Face hosted inference API.

    Instances are callable with ``(abstract, temperature)`` and return a list
    of title strings.
    """

    def __init__(self, model: str, num_titles: int, api_key: Optional[str] = None) -> None:
        """Remember the model id, the number of titles and the optional API key."""
        super().__init__()
        self.model = model
        self.num_titles = num_titles
        self.api_key = api_key

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Request titles for ``abstract`` from the hosted model.

        Args:
            abstract: Text of the paper abstract.
            temperature: Sampling temperature. Values below 1.0 are clamped
                to 1.0; sampling is requested only when the value exceeds 1.

        Returns:
            The ``summary_text`` of every returned item, post-processed with
            ``str2title``.

        Raises:
            RuntimeError: If the decoded response is a dict containing an
                ``error`` entry.
        """
        temperature = max(1.0, float(temperature))

        generation_params = {
            'do_sample': (temperature > 1),
            'num_beams': 10,
            'temperature': temperature,
            'top_k': 50,
            'no_repeat_ngram_size': 2,
            'num_return_sequences': self.num_titles,
        }
        request_options = {'use_cache': False, 'wait_for_model': True}
        payload = json.dumps({
            'inputs': abstract,
            'parameters': generation_params,
            'options': request_options,
        })

        # The Authorization header is only sent when a key was configured.
        if self.api_key is None:
            headers = {}
        else:
            headers = {"Authorization": f"Bearer {self.api_key}"}

        api_url = "https://api-inference.huggingface.co/models/" + self.model
        raw_response = requests.request("POST", api_url, headers=headers, data=payload)
        result = json.loads(raw_response.content.decode("utf-8"))

        # Errors are reported as a JSON object carrying an 'error' entry.
        if isinstance(result, dict) and 'error' in result:
            raise RuntimeError(result['error'])

        titles = []
        for item in result:
            titles.append(str2title(item['summary_text']))
        return titles
def create_gradio_ui(predictor):
    """Build the Gradio interface around a title predictor.

    Args:
        predictor: Callable taking ``(abstract, temperature)``. It must expose
            a ``num_titles`` attribute, which determines how many text output
            fields are created.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    inputs = [
        gr.Textbox(label="Paper Abstract", lines=10),
        gr.Slider(label="Creativity", minimum=1.0, maximum=2.5, step=0.1, value=1.5),
    ]
    # One plain-text output component per generated title.
    outputs = ["text"] * predictor.num_titles
    description = (
        "<center>"
        "Bert2Bert model trained on computer science papers from arXiv to generate "
        "paper titles from abstracts."
        "</center>"
    )
    ui = gr.Interface(
        fn=predictor,
        inputs=inputs,
        outputs=outputs,
        title="Paper Title Generator",
        description=description,
    )
    return ui
if __name__ == '__main__':
    # Script entry point: choose a prediction backend, then serve the UI.
    model_path = "Callidior/bert2bert-base-arxiv-titlegen"

    if not torch.cuda.is_available():
        # No GPU: use the hosted inference API, with the token taken from the
        # HF_TOKEN environment variable when it is set.
        print('No GPU available - using hosted inference API.')
        predictor = HostedInference(
            model_path,
            num_titles=5,
            api_key=os.environ.get("HF_TOKEN"),
        )
    else:
        # GPU available: load tokenizer and model and run generation locally.
        print('Loading model...')
        tokenizer = BertTokenizerFast.from_pretrained(model_path)
        model = EncoderDecoderModel.from_pretrained(model_path).cuda()
        predictor = Predictor(model, tokenizer, device="cuda", num_titles=5)
        print('Ready - running on GPU.')

    interface = create_gradio_ui(predictor)
    interface.launch()
|