katielink committed on
Commit
bad215c
·
1 Parent(s): 6f47ef1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +400 -0
app.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import random
4
+ import re
5
+ from io import StringIO
6
+
7
+ import gradio as gr
8
+ import pandas as pd
9
+ from huggingface_hub import upload_file
10
+ from text_generation import Client
11
+
12
+ from dialogues import DialogueTemplate
13
+ from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
14
+ share_js)
15
+
16
# Tokens read from the environment (both optional; None when unset):
# HF_TOKEN is for Hub uploads, API_TOKEN authenticates Inference API calls.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_TOKEN = os.environ.get("API_TOKEN", None)

# Maps the short model name shown in the UI radio to its Inference API endpoint.
model2endpoint = {
    "zephyr-7b-beta": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
    "llama2-13b-chat": "https://api-inference.huggingface.co/models/meta-llama/Llama-2-13b-chat-hf",
    "mistral-7b-v0.1": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
}
# Radio choices, in the insertion order of the dict above.
model_names = list(model2endpoint.keys())
25
+
26
+
27
def randomize_seed_generator():
    """Return a fresh random seed in [0, 1000000] for sampling on retries."""
    return random.randint(0, 1000000)
30
+
31
+
32
+
33
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Assemble the full prompt string for the model.

    Concatenates `preprompt`, every past (user, assistant) turn from
    `chatbot`, then the new `inputs`, finishing with the assistant prefix so
    the model continues as the assistant.
    """

    def _with_prefix(text, prefix):
        # Prepend `prefix` unless `text` already carries it.
        return text if text.startswith(prefix) else prefix + text

    turns = []
    for user_data, model_data in chatbot:
        user_part = _with_prefix(user_data, user_name)
        model_part = _with_prefix(model_data, sep + assistant_name)
        turns.append(user_part + model_part.rstrip() + sep)

    current = _with_prefix(inputs, user_name)

    return preprompt + "".join(turns) + current + sep + assistant_name.rstrip()
51
+
52
+
53
def wrap_html_code(text):
    """Wrap `text` in a Markdown code fence when it contains HTML-like tags,
    so the chat widget shows the markup literally instead of rendering it."""
    if re.search(r"<.*?>", text):
        return f"```{text}```"
    return text
60
+
61
+
62
def has_no_history(chatbot, history):
    """Return True only when both the chat display and the raw history are empty."""
    if chatbot:
        return False
    return not history
64
+
65
+
66
def generate(
    RETRY_FLAG,
    model_name,
    system_message,
    user_message,
    chatbot,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    # do_save=True,
):
    """Stream a chat completion from the selected model's Inference endpoint.

    Builds a prompt from `system_message`, the prior `chatbot` turns, and
    `user_message`, then streams tokens from the endpoint, yielding updated
    UI state as (chat_pairs, history, last_user_message, "") after each token.

    Parameters mirror the Gradio inputs; `RETRY_FLAG` selects a random seed
    (retry) versus the fixed seed 42 (first attempt).
    """
    client = Client(
        model2endpoint[model_name],
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        timeout=60,
    )
    # Don't return meaningless message when the input is empty
    if not user_message:
        print("Empty input")

    if not RETRY_FLAG:
        history.append(user_message)
        seed = 42  # deterministic seed on the first attempt
    else:
        seed = randomize_seed_generator()  # fresh seed so retries differ

    # Re-express the displayed conversation as role-tagged messages.
    past_messages = []
    for user_data, model_data in chatbot:
        past_messages.extend(
            [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
        )

    # The original's two branches (empty vs. non-empty past) built identical
    # structures; `past_messages + [...]` covers both cases.
    dialogue_template = DialogueTemplate(
        system=system_message, messages=past_messages + [{"role": "user", "content": user_message}]
    )
    prompt = dialogue_template.get_inference_prompt()

    # Clamp sampling parameters before building the request.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # the endpoint rejects a temperature of exactly 0
    top_p = float(top_p)

    # The original built a first generate_kwargs dict (containing top_k) and
    # then immediately overwrote it with this one; the dead dict is removed.
    # NOTE(review): top_k is therefore accepted but never forwarded to the
    # API — kept as-is to preserve behavior; confirm whether it should be sent.
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=4096,
        seed=seed,
        stop_sequences=["<|end|>"],
    )

    stream = client.generate_stream(
        prompt,
        **generate_kwargs,
    )

    output = ""
    chat = []  # defined up front so the final return is safe on an empty stream
    for idx, response in enumerate(stream):
        if response.token.special:
            continue
        output += response.token.text
        if idx == 0:
            history.append(" " + output)
        else:
            # NOTE(review): if the very first token were "special", idx 0 is
            # skipped and this write clobbers the last user entry instead of
            # appending — preexisting behavior, left unchanged.
            history[-1] = output

        # Pair up history entries as (user, assistant) for the Chatbot widget,
        # fencing raw HTML so it renders literally.
        chat = [
            (wrap_html_code(history[i].strip()), wrap_html_code(history[i + 1].strip()))
            for i in range(0, len(history) - 1, 2)
        ]

        yield chat, history, user_message, ""

    return chat, history, user_message, ""
162
+
163
+
164
# Seed prompts offered under the input box via gr.Examples.
examples = [
    "What are the signs and symptoms of community acquired pneumonia (CAP)?",
    "What is the treatment for recurrent otitis media?",
]
167
+
168
+
169
def clear_chat():
    """Reset both the chat display and the raw history to empty lists."""
    fresh_chat, fresh_history = [], []
    return fresh_chat, fresh_history
171
+
172
+
173
def delete_last_turn(chat, history):
    """Remove the most recent turn from the conversation.

    Drops the last (user, bot) pair from `chat` and the matching two entries
    from `history`. Mutates both lists in place and returns them (Gradio
    expects the outputs). No-op when either list is empty.
    """
    if chat and history:
        chat.pop(-1)
        # A full turn occupies two history slots (user msg, bot reply).
        # `del history[-2:]` matches the original double pop when len >= 2 and,
        # unlike it, does not raise IndexError on an odd-length history left by
        # an interrupted stream.
        del history[-2:]
    return chat, history
179
+
180
+
181
def process_example(args):
    """Drive `generate` to completion for a cached example and return the
    final [chat, history] pair.

    Fixes: `generate` yields 4-tuples (chat, history, last_user_message,
    cleared_input), so the original 2-way unpacking raised ValueError, and
    x/y were unbound when the generator yielded nothing. Currently unreached
    because `gr.Examples` is configured with `cache_examples=False`.

    NOTE(review): the single `args` parameter is forwarded as-is, matching the
    original call shape — confirm how gr.Examples invokes this before caching.
    """
    x, y = [], []
    for x, y, *_ in generate(args):
        pass
    return [x, y]
185
+
186
+
187
+ # Regenerate response
188
# Regenerate response
def retry_last_answer(
    selected_model,
    system_message,
    user_message,
    chat,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    # do_save,
):
    """Regenerate the last assistant reply.

    Drops the previous (user, bot) pair from `chat` and the bot reply from
    `history`, then re-runs `generate` with the retry flag set so a fresh
    random seed is used. Streams `generate`'s outputs through unchanged.
    """
    # This handler is always a retry. The original assigned RETRY_FLAG only
    # inside the `if` below, raising NameError at the generate() call whenever
    # chat/history were empty (e.g. Regenerate clicked on a fresh session).
    RETRY_FLAG = True

    if chat and history:
        # Removing the previous conversation from chat
        chat.pop(-1)
        # Removing bot response from the history
        history.pop(-1)
        # Getting last message from user
        user_message = history[-1]

    yield from generate(
        RETRY_FLAG,
        selected_model,
        system_message,
        user_message,
        chat,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        # do_save,
    )
225
+
226
+
227
# Page heading, rendered at the top of the app via gr.HTML.
title = """<h1 align="center">LLM Playground 💬</h1>"""
# Extra CSS: centers the (currently commented-out) banner image and gives the
# chat widget a readable font size and a minimum height.
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
239
+
240
# ---------------------------------------------------------------------------
# Gradio UI: layout first, then hidden state, then event wiring.
# NOTE(review): gr.Box and queue(concurrency_count=...) are Gradio 3.x APIs —
# confirm the pinned gradio version before upgrading.
# ---------------------------------------------------------------------------
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        # with gr.Column():
        #     gr.Image("thumbnail.png", elem_id="banner-image", show_label=False)
        with gr.Column():
            gr.Markdown(
                """
💻 This demo showcases a few smaller open source models."""
            )

    # Model picker; defaults to the second entry of model_names.
    with gr.Row():
        selected_model = gr.Radio(choices=model_names, value=model_names[1], label="Select a model")

    # Optional system prompt, collapsed by default.
    with gr.Accordion(label="System Prompt", open=False, elem_id="parameters-accordion"):
        system_message = gr.Textbox(
            elem_id="system-message",
            placeholder="Below is a conversation between a medical student and a helpful AI medical assistant.",
            show_label=False,
        )

    # Chat display area. `output` receives the gr.Examples output.
    with gr.Row():
        with gr.Box():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    # Input box, action buttons, sampling parameters, and example prompts.
    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)

                regenerate_button = gr.Button("Regenerate", elem_id="retry-btn", visible=True)

                delete_turn_button = gr.Button("Delete last turn", elem_id="delete-btn", visible=True)

                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)

            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_k = gr.Slider(
                    label="Top-k",
                    value=50,
                    minimum=0.0,
                    maximum=100,
                    step=1,
                    interactive=True,
                    info="Sample from a shortlist of top-k tokens",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=1024,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )

            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[user_message],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )

    # Hidden per-session state: raw message history and the retry flag
    # (a hidden checkbox so it can be passed as a generate() input).
    history = gr.State([])
    RETRY_FLAG = gr.Checkbox(value=False, visible=False)

    # To clear out "message" input textbox and use this to regenerate message
    last_user_message = gr.State("")

    # Pressing Enter in the textbox streams a generation.
    user_message.submit(
        generate,
        inputs=[
            RETRY_FLAG,
            selected_model,
            system_message,
            user_message,
            chatbot,
            history,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            repetition_penalty,
            # do_save,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    # The Send button mirrors the submit wiring exactly.
    send_button.click(
        generate,
        inputs=[
            RETRY_FLAG,
            selected_model,
            system_message,
            user_message,
            chatbot,
            history,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            repetition_penalty,
            # do_save,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    # Regenerate re-runs the last user message with a fresh seed.
    regenerate_button.click(
        retry_last_answer,
        inputs=[
            selected_model,
            system_message,
            user_message,
            chatbot,
            history,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            repetition_penalty,
            # do_save,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    delete_turn_button.click(delete_last_turn, [chatbot, history], [chatbot, history])
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])
    # Switching models resets the conversation so histories don't mix.
    selected_model.change(clear_chat, outputs=[chatbot, history])

demo.queue(concurrency_count=16).launch(debug=True)