File size: 4,578 Bytes
fae5c57
2d6d7a5
 
c9518d3
 
 
2d6d7a5
 
 
 
0aa55e8
 
2d6d7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9518d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d6d7a5
 
 
 
c9518d3
2d6d7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9518d3
 
 
 
 
 
 
 
 
fae5c57
2d6d7a5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import string
import re
import json
import requests
from typing import List, Optional

import torch
from transformers import EncoderDecoderModel, BertTokenizerFast

import gradio as gr


def str2title(str):
    """Convert a detokenized model output into a nicely formatted title.

    The input is expected to be a whitespace-separated token string in which
    punctuation is surrounded by spaces (e.g. ``"state - of - the - art"``).

    Args:
        str: The raw title text. NOTE: the name shadows the builtin ``str``;
            it is kept only so existing keyword callers keep working.

    Returns:
        The title with every word capitalized and the spacing around dashes,
        hyphens, parentheses, colons and double quotes cleaned up.
    """
    # Rebind immediately so the shadowed builtin name is not used any further.
    # capwords() also collapses whitespace runs and strips both ends.
    title = string.capwords(str)

    # Runs of spaced hyphens stand for dashes; the longer run must be handled
    # first, otherwise ' - - ' would already match inside ' - - - '.
    title = title.replace(' - - - ', ' — ').replace(' - - ', ' – ')

    # Remove the padding inside parentheses.
    title = title.replace('( ', '(').replace(' )', ')')

    # Join hyphenated words. Lookarounds keep the neighbouring characters out
    # of the match, so chains with one-letter parts ("x - y - z") are joined
    # completely instead of only at every other hyphen.
    title = re.sub(r'(?<=\w)\s+-\s+(?=\w)', '-', title)

    # No whitespace before a colon that follows a word or a closing quote.
    title = re.sub(r'(\w|")\s+:', r'\1:', title)

    # Remove the padding inside a pair of double quotes.
    title = re.sub(r'"\s+([^"]+)\s+"', r'"\1"', title)
    return title


class Predictor:
    """Generates paper titles from an abstract with a locally loaded model.

    Instances are callable with ``(abstract, temperature)`` and return a list
    of ``num_titles`` formatted title strings.
    """

    def __init__(
        self,
        model: EncoderDecoderModel,
        tokenizer: BertTokenizerFast,
        device: torch.device,
        num_titles: int,
        encoder_max_length: int = 512,
        decoder_max_length: int = 32,
    ) -> None:
        """Store the model, tokenizer and generation settings.

        Args:
            model: The sequence-to-sequence model used for generation.
            tokenizer: Tokenizer used to encode abstracts and decode titles.
            device: Device the encoded input is moved to before generation.
                Presumably the same device the model lives on.
            num_titles: Number of titles to return per call.
            encoder_max_length: Abstracts are truncated to this many tokens.
            decoder_max_length: Maximum length passed to ``generate``.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.num_titles = num_titles
        self.encoder_max_length = encoder_max_length
        self.decoder_max_length = decoder_max_length

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Generate ``num_titles`` titles for the given abstract.

        Args:
            abstract: The abstract text to generate titles for.
            temperature: Sampling temperature. Values below 1.0 are raised to
                1.0; sampling is only enabled for values greater than 1.

        Returns:
            The decoded titles, each passed through ``str2title``.
        """
        temperature = max(1.0, float(temperature))
        input_token_ids = self.tokenizer(
            abstract,
            truncation=True,
            max_length=self.encoder_max_length,
            return_tensors='pt',
        ).input_ids.to(self.device)
        pred = self.model.generate(
            input_token_ids,
            decoder_start_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=(temperature > 1),
            num_beams=10,
            max_length=self.decoder_max_length,
            no_repeat_ngram_size=2,
            temperature=temperature,
            top_k=50,
            num_return_sequences=self.num_titles,
        )
        # Decode with the tokenizer owned by this instance, not with a
        # module-level global that only exists when run as a script.
        decoded = self.tokenizer.batch_decode(pred, skip_special_tokens=True)
        return [str2title(title) for title in decoded]


class HostedInference:
    """Generates paper titles through the Hugging Face hosted inference API.

    Instances are callable with ``(abstract, temperature)`` and return a list
    of formatted title strings.
    """

    def __init__(
        self,
        model: str,
        num_titles: int,
        api_key: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> None:
        """Store the request settings.

        Args:
            model: Model identifier appended to the inference API base URL.
            num_titles: Number of titles requested per call.
            api_key: Optional token sent as a ``Bearer`` authorization header.
            timeout: Optional timeout in seconds for the HTTP request.
                ``None`` (the default) waits indefinitely.
        """
        super().__init__()
        self.model = model
        self.num_titles = num_titles
        self.api_key = api_key
        self.timeout = timeout

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Request titles for the given abstract from the hosted API.

        Args:
            abstract: The abstract text to generate titles for.
            temperature: Sampling temperature. Values below 1.0 are raised to
                1.0; sampling is only requested for values greater than 1.

        Returns:
            The returned summaries, each passed through ``str2title``.

        Raises:
            RuntimeError: If the API reports an error or the response is not
                the expected JSON list.
        """
        temperature = max(1.0, float(temperature))
        data = json.dumps({
            'inputs': abstract,
            'parameters': {
                'do_sample': (temperature > 1),
                'num_beams': 10,
                'temperature': temperature,
                'top_k': 50,
                'no_repeat_ngram_size': 2,
                'num_return_sequences': self.num_titles,
            },
            'options': {'use_cache': False, 'wait_for_model': True},
        })
        api_url = "https://api-inference.huggingface.co/models/" + self.model
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key is not None else {}
        response = requests.post(api_url, headers=headers, data=data, timeout=self.timeout)

        try:
            payload = json.loads(response.content.decode("utf-8"))
        except ValueError as err:
            # Covers both invalid JSON and undecodable bytes, e.g. an HTML
            # error page from a gateway; report it like any other API error.
            raise RuntimeError(
                f"Inference API returned a non-JSON response (HTTP {response.status_code})"
            ) from err

        return [str2title(summary) for summary in self._extract_summaries(payload)]

    @staticmethod
    def _extract_summaries(payload: object) -> List[str]:
        """Validate a decoded API response and return its raw summary texts."""
        if isinstance(payload, dict) and ('error' in payload):
            raise RuntimeError(payload['error'])
        if not isinstance(payload, list):
            raise RuntimeError(f"Unexpected response from inference API: {payload!r}")
        return [item['summary_text'] for item in payload]


def create_gradio_ui(predictor):
    """Build the Gradio interface around a title predictor.

    Args:
        predictor: A callable taking ``(abstract, temperature)``. It must
            expose a ``num_titles`` attribute, which determines how many text
            outputs the interface has.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for
        launching it.
    """
    inputs = [
        gr.Textbox(label="Paper Abstract", lines=10),
        gr.Slider(label="Creativity", minimum=1.0, maximum=2.5, step=0.1, value=1.5),
    ]
    # One text output per generated title.
    outputs = ["text"] * predictor.num_titles

    description = (
        "<center>"
        "Bert2Bert model trained on computer science papers from arXiv to generate "
        "paper titles from abstracts."
        "</center>"
    )

    return gr.Interface(
        fn=predictor,
        inputs=inputs,
        outputs=outputs,
        title="Paper Title Generator",
        description=description,
    )


if __name__ == '__main__':
    model_path = "Callidior/bert2bert-base-arxiv-titlegen"

    if torch.cuda.is_available():
        # A GPU is present: load the model and run generation locally.
        print('Loading model...')
        device = torch.device("cuda")
        tokenizer = BertTokenizerFast.from_pretrained(model_path)
        model = EncoderDecoderModel.from_pretrained(model_path).to(device)
        predictor = Predictor(model, tokenizer, device=device, num_titles=5)
        print('Ready - running on GPU.')
    else:
        # No GPU: delegate generation to the hosted inference API. The token
        # is optional and read from the HF_TOKEN environment variable.
        print('No GPU available - using hosted inference API.')
        predictor = HostedInference(model_path, num_titles=5, api_key=os.environ.get("HF_TOKEN"))

    interface = create_gradio_ui(predictor)
    interface.launch()