Callidior committed on
Commit
2d6d7a5
1 Parent(s): 0aa55e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -1
app.py CHANGED
@@ -1,3 +1,98 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- gr.Interface.load("models/Callidior/bert2bert-base-arxiv-titlegen").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import string
2
+ import re
3
+ from typing import List
4
+
5
+ import torch
6
+ from transformers import EncoderDecoderModel, BertTokenizerFast
7
+
8
  import gradio as gr
9
 
10
+
11
def str2title(str):
    """Format a whitespace-separated token string as a display title.

    Every word is capitalised with ``string.capwords`` (which also collapses
    runs of whitespace to single spaces), then punctuation that is spaced
    out as separate tokens is re-attached:

    * `` - - - `` becomes an em dash and `` - - `` an en dash,
    * spaces just inside parentheses are dropped,
    * a hyphen between two word characters loses its surrounding spaces,
    * spaces before a colon are dropped,
    * spaces just inside a pair of double quotes are dropped.

    Args:
        str: The raw text to format. The name shadows the builtin and is
            kept only so existing keyword callers continue to work.

    Returns:
        The formatted title.
    """
    # Rebind immediately so the builtin-shadowing parameter name is not
    # used any further inside the function.
    title = string.capwords(str)

    # Long dashes must be handled before single hyphens, longest first.
    title = title.replace(' - - - ', ' — ')
    title = title.replace(' - - ', ' – ')

    title = title.replace('( ', '(')
    title = title.replace(' )', ')')

    # Lookarounds leave the neighbouring word characters unconsumed, so
    # chains such as "2 - D - Cnn" are joined completely ("2-D-Cnn").  A
    # capturing pattern would swallow the "D" and miss the second hyphen.
    title = re.sub(r'(?<=\w)\s+-\s+(?=\w)', '-', title)

    title = re.sub(r'(\w|")\s+:', r'\1:', title)
    title = re.sub(r'"\s+([^"]+)\s+"', r'"\1"', title)
    return title
22
+
23
+
24
class Predictor:
    """Callable that generates several candidate titles for an abstract.

    The instance stores a model, its tokenizer, the target device and the
    generation limits; calling it tokenizes the abstract, runs
    ``model.generate`` and formats each decoded sequence with ``str2title``.
    """

    def __init__(
        self,
        model: EncoderDecoderModel,
        tokenizer: BertTokenizerFast,
        device: torch.device,
        num_titles: int,
        encoder_max_length: int = 512,
        decoder_max_length: int = 32,
    ) -> None:
        """Store the generation components and limits.

        Args:
            model: Model whose ``generate`` method produces title token ids.
            tokenizer: Tokenizer used to encode abstracts and decode titles.
            device: Device the encoded input ids are moved to.
            num_titles: Number of sequences requested per call.
            encoder_max_length: Abstracts are truncated to this many tokens.
            decoder_max_length: ``max_length`` passed to ``generate``.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.num_titles = num_titles
        self.encoder_max_length = encoder_max_length
        self.decoder_max_length = decoder_max_length

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Generate formatted titles for ``abstract``.

        Args:
            abstract: Text of the paper abstract.
            temperature: Sampling temperature; values below 1.0 are raised
                to 1.0. Sampling is enabled only when the clamped value is
                greater than 1.

        Returns:
            The decoded sequences, each passed through ``str2title``.
        """
        temperature = max(1.0, float(temperature))

        encoded = self.tokenizer(
            abstract,
            truncation=True,
            max_length=self.encoder_max_length,
            return_tensors='pt',
        )
        input_token_ids = encoded.input_ids.to(self.device)

        pred = self.model.generate(
            input_token_ids,
            decoder_start_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=(temperature > 1),
            num_beams=10,
            max_length=self.decoder_max_length,
            no_repeat_ngram_size=2,
            temperature=temperature,
            top_k=50,
            num_return_sequences=self.num_titles,
        )

        # Decode with the tokenizer owned by this instance, not a
        # module-level global, so the class also works when imported.
        decoded = self.tokenizer.batch_decode(pred, skip_special_tokens=True)
        return [str2title(title) for title in decoded]
61
+
62
+
63
def create_gradio_ui(predictor):
    """Build the Gradio interface that wraps ``predictor``.

    The interface has two inputs (a multi-line textbox for the abstract and
    a "Creativity" slider from 1.0 to 2.0) and one text output per title,
    as given by ``predictor.num_titles``.

    Args:
        predictor: Callable used as the interface function; it must expose
            a ``num_titles`` attribute.

    Returns:
        The constructed ``gr.Interface`` (not yet launched).
    """
    inputs = [
        gr.Textbox(label="Paper Abstract", lines=10),
        gr.Slider(label="Creativity", minimum=1.0, maximum=2.0, step=0.1, value=1.5),
    ]
    outputs = ["text"] * predictor.num_titles

    # Fixed user-facing typo: the model generates paper "titles", not "tiles".
    description = (
        "<center>"
        "Bert2Bert model trained on computer science papers from arXiv to generate "
        "paper titles from abstracts."
        "</center>"
    )

    ui = gr.Interface(
        fn=predictor,
        inputs=inputs,
        outputs=outputs,
        title="Paper Title Generator",
        description=description,
    )
    return ui
86
+
87
+
88
if __name__ == '__main__':
    # Script entry point: load the pretrained checkpoint, wrap it in a
    # Predictor and serve it through the Gradio UI.
    print('Loading model...')
    model_path = "Callidior/bert2bert-base-arxiv-titlegen"

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    model = EncoderDecoderModel.from_pretrained(model_path)
    model = model.to(device)
    print(f'Ready - running on {device}.')

    predictor = Predictor(model, tokenizer, num_titles=5, device=device)
    interface = create_gradio_ui(predictor)
    interface.launch()