File size: 3,028 Bytes
2d6d7a5
 
 
 
 
 
 
0aa55e8
 
2d6d7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import string
import re
from typing import List

import torch
from transformers import EncoderDecoderModel, BertTokenizerFast

import gradio as gr


def str2title(str):
    """Turn decoded, space-separated model output into a display title.

    The tokenizer's decode step leaves punctuation surrounded by spaces
    (e.g. ``"end - to - end ( e2e )"``). This capitalizes every word and
    tightens the spacing around dashes, parentheses, colons and double quotes.

    Args:
        str: Decoded title text. The name shadows the builtin but is kept so
            existing keyword callers keep working.

    Returns:
        The prettified title.

    >>> str2title("end - to - end ( e2e ) models : a survey")
    'End-To-End (E2e) Models: A Survey'
    """
    # capwords also collapses runs of whitespace into single spaces.
    title = string.capwords(str)
    # Spaced-out dash sequences: three hyphens -> em dash, two -> en dash.
    title = title.replace(' - - - ', ' — ')
    title = title.replace(' - - ', ' – ')
    title = title.replace('( ', '(').replace(' )', ')')
    # Join hyphenated words. Lookarounds (instead of capturing the neighbouring
    # characters) leave one-character words such as the "2" in "1 - 2 - 3"
    # available to the next match, so every hyphen in a chain is joined.
    title = re.sub(r'(?<=\w)\s+-\s+(?=\w)', '-', title)
    # Remove the space before a colon that follows a word or a closing quote.
    title = re.sub(r'(\w|")\s+:', r'\1:', title)
    # Remove the padding just inside a pair of double quotes.
    title = re.sub(r'"\s+([^"]+)\s+"', r'"\1"', title)
    return title


class Predictor:
    """Generate candidate paper titles from an abstract.

    Instances are callable with ``(abstract, temperature)``, so they can be
    passed directly as the ``fn`` of a Gradio interface.
    """

    def __init__(
        self,
        model: EncoderDecoderModel,
        tokenizer: BertTokenizerFast,
        device: torch.device,
        num_titles: int,
        encoder_max_length: int = 512,
        decoder_max_length: int = 32,
    ) -> None:
        """Store the model, tokenizer and generation settings.

        Args:
            model: Encoder-decoder model whose ``generate`` produces titles.
            tokenizer: Tokenizer used both to encode the abstract and to
                decode the generated ids.
            device: Device the encoded input ids are moved to; presumably the
                same device ``model`` lives on (not checked here).
            num_titles: Number of titles returned per call.
            encoder_max_length: Abstracts are truncated to this many tokens.
            decoder_max_length: Passed to ``generate`` as ``max_length``.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.num_titles = num_titles
        self.encoder_max_length = encoder_max_length
        self.decoder_max_length = decoder_max_length

    def __call__(self, abstract: str, temperature: float = 1.0) -> List[str]:
        """Generate ``self.num_titles`` titles for ``abstract``.

        Args:
            abstract: Paper abstract; truncated to ``encoder_max_length`` tokens.
            temperature: Sampling temperature. Values below 1 are clamped to 1,
                and exactly 1 disables sampling (plain beam search).

        Returns:
            A list of ``self.num_titles`` titles, each post-processed by
            ``str2title``.
        """
        temperature = max(1.0, float(temperature))
        input_token_ids = self.tokenizer(
            abstract,
            truncation=True,
            max_length=self.encoder_max_length,
            return_tensors='pt',
        ).input_ids.to(self.device)
        pred = self.model.generate(
            input_token_ids,
            decoder_start_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            # Only sample when the user asked for more "creativity" than 1.0.
            do_sample=(temperature > 1),
            # Beam search cannot return more sequences than it has beams, so
            # widen the beam when more than 10 titles are requested.
            num_beams=max(10, self.num_titles),
            max_length=self.decoder_max_length,
            no_repeat_ngram_size=2,
            temperature=temperature,
            top_k=50,
            num_return_sequences=self.num_titles,
        )
        # Use this instance's tokenizer (not a module-level global) so the
        # class works when imported and with any tokenizer it was given.
        decoded = self.tokenizer.batch_decode(pred, skip_special_tokens=True)
        return [str2title(title) for title in decoded]


def create_gradio_ui(predictor):
    """Build (but do not launch) a Gradio interface around ``predictor``.

    Args:
        predictor: Callable taking ``(abstract, creativity)`` and returning
            ``predictor.num_titles`` strings. One text output box is created
            per title.

    Returns:
        The configured ``gr.Interface``.
    """
    inputs = [
        gr.Textbox(label="Paper Abstract", lines=10),
        # Passed as the second argument to ``predictor``; the Predictor class
        # in this file uses it as the sampling temperature.
        gr.Slider(label="Creativity", minimum=1.0, maximum=2.0, step=0.1, value=1.5),
    ]
    # One plain-text output component per generated title.
    outputs = ["text"] * predictor.num_titles

    description = (
        "<center>"
        "Bert2Bert model trained on computer science papers from arXiv to generate "
        "paper titles from abstracts."
        "</center>"
    )

    ui = gr.Interface(
        fn=predictor,
        inputs=inputs,
        outputs=outputs,
        title="Paper Title Generator",
        description=description,
    )
    return ui


if __name__ == '__main__':
    # Script entry point: load the pretrained title generator, wrap it in a
    # Predictor and serve it through a Gradio UI. The names below stay
    # module-level globals on purpose.
    print('Loading model...')
    model_path = "Callidior/bert2bert-base-arxiv-titlegen"

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'  # prefer the GPU whenever torch can see one

    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    model = EncoderDecoderModel.from_pretrained(model_path)
    model = model.to(device)
    print(f'Ready - running on {device}.')

    predictor = Predictor(model, tokenizer, device=device, num_titles=5)
    interface = create_gradio_ui(predictor)
    interface.launch()