Callidior committed on
Commit
c9518d3
1 Parent(s): 973eb6d

Use hosted inference when on CPU

Browse files
Files changed (1) hide show
  1. app.py +48 -8
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import string
2
  import re
3
- from typing import List
 
 
4
 
5
  import torch
6
  from transformers import EncoderDecoderModel, BertTokenizerFast
@@ -60,11 +62,45 @@ class Predictor:
60
  return titles
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def create_gradio_ui(predictor):
64
 
65
  inputs = [
66
  gr.Textbox(label="Paper Abstract", lines=10),
67
- gr.Slider(label="Creativity", minimum=1.0, maximum=2.0, step=0.1, value=1.5),
68
  ]
69
  outputs = ["text"] * predictor.num_titles
70
 
@@ -86,13 +122,17 @@ def create_gradio_ui(predictor):
86
 
87
 
88
  if __name__ == '__main__':
89
- print('Loading model...')
90
  model_path = "Callidior/bert2bert-base-arxiv-titlegen"
91
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
92
- tokenizer = BertTokenizerFast.from_pretrained(model_path)
93
- model = EncoderDecoderModel.from_pretrained(model_path).to(device)
94
- print(f'Ready - running on {device}.')
 
 
 
 
 
 
95
 
96
- predictor = Predictor(model, tokenizer, device=device, num_titles=5)
97
  interface = create_gradio_ui(predictor)
98
  interface.launch()
 
1
  import string
2
  import re
3
+ import json
4
+ import requests
5
+ from typing import List, Optional
6
 
7
  import torch
8
  from transformers import EncoderDecoderModel, BertTokenizerFast
 
62
  return titles
63
 
64
 
class HostedInference:
    """Generate paper titles through the Hugging Face hosted inference API.

    Instances are callable with ``(abstract, temperature)`` and expose
    ``num_titles``, which ``create_gradio_ui`` reads to size its outputs.
    """

    # Base endpoint; the model identifier is appended to it.
    API_BASE_URL = "https://api-inference.huggingface.co/models/"

    # (connect, read) timeouts in seconds. The read timeout is generous
    # because the request sets 'wait_for_model', so the call may block while
    # the remote model is being loaded.
    REQUEST_TIMEOUT = (10.0, 300.0)

    def __init__(self, model: str, num_titles: int, api_key: Optional[str] = None) -> None:
        """Store the request configuration.

        Args:
            model: Model identifier appended to the inference API URL.
            num_titles: Number of sequences requested per call.
            api_key: Optional token sent as a ``Bearer`` authorization
                header; when ``None`` no authorization header is sent.
        """
        super().__init__()
        self.model = model
        self.num_titles = num_titles
        self.api_key = api_key

    def __call__(self, abstract: str, temperature: float) -> List[str]:
        """Request titles for ``abstract`` from the hosted model.

        Args:
            abstract: Text sent as the ``inputs`` field of the request.
            temperature: Sampling temperature; values below 1.0 are raised
                to 1.0. Sampling is only enabled for values above 1.0.

        Returns:
            One string per returned item: its ``'summary_text'`` field passed
            through ``str2title`` (defined elsewhere in this file).

        Raises:
            RuntimeError: If the API reports an error, answers with a failing
                HTTP status, or returns a body that is not the expected JSON
                list of objects with a ``'summary_text'`` field.
            requests.RequestException: On connection failures or timeouts.
        """
        temperature = max(1.0, float(temperature))
        data = json.dumps({
            'inputs': abstract,
            'parameters': {
                'do_sample': (temperature > 1),
                'num_beams': 10,
                'temperature': temperature,
                'top_k': 50,
                'no_repeat_ngram_size': 2,
                'num_return_sequences': self.num_titles,
            },
            'options': {'use_cache': False, 'wait_for_model': True},
        })
        api_url = self.API_BASE_URL + self.model
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key is not None else {}

        # Without a timeout a stalled connection would block the caller forever.
        response = requests.post(
            api_url, headers=headers, data=data, timeout=self.REQUEST_TIMEOUT
        )

        try:
            result = json.loads(response.content.decode("utf-8"))
        except ValueError as err:
            # Covers both invalid JSON and invalid UTF-8 (e.g. an HTML error
            # page from a gateway); report it like any other API failure.
            raise RuntimeError(
                f"Hosted inference API returned an unreadable response "
                f"(HTTP {response.status_code})."
            ) from err

        if isinstance(result, dict) and ('error' in result):
            raise RuntimeError(result['error'])

        if not response.ok:
            raise RuntimeError(
                f"Hosted inference API request failed (HTTP {response.status_code})."
            )

        if not isinstance(result, list):
            raise RuntimeError(
                f"Unexpected response from hosted inference API: {result!r}"
            )

        try:
            return [str2title(item['summary_text']) for item in result]
        except (KeyError, TypeError) as err:
            # An item was not a mapping with a 'summary_text' entry.
            raise RuntimeError(
                f"Unexpected response from hosted inference API: {result!r}"
            ) from err
97
+
98
+
99
  def create_gradio_ui(predictor):
100
 
101
  inputs = [
102
  gr.Textbox(label="Paper Abstract", lines=10),
103
+ gr.Slider(label="Creativity", minimum=1.0, maximum=2.5, step=0.1, value=1.5),
104
  ]
105
  outputs = ["text"] * predictor.num_titles
106
 
 
122
 
123
 
if __name__ == '__main__':
    # Script entry point: choose an inference backend, then launch the UI.
    # The same checkpoint identifier is handed to whichever backend is chosen.
    model_path = "Callidior/bert2bert-base-arxiv-titlegen"
    gpu_available = torch.cuda.is_available()

    if not gpu_available:
        # No CUDA device: delegate generation to the hosted inference API.
        print('No GPU available - using hosted inference API.')
        predictor = HostedInference(model_path, num_titles=5)
    else:
        # CUDA device present: load tokenizer and model locally on the GPU.
        print('Loading model...')
        tokenizer = BertTokenizerFast.from_pretrained(model_path)
        model = EncoderDecoderModel.from_pretrained(model_path)
        model = model.cuda()
        predictor = Predictor(model, tokenizer, device="cuda", num_titles=5)
        print('Ready - running on GPU.')

    interface = create_gradio_ui(predictor)
    interface.launch()