frt03 committed
Commit 55204fd · 1 Parent(s): 200affd

update app.py

Files changed (4)
  1. QuALITY.v1.0.1.htmlstripped.dev +0 -0
  2. app.py +441 -47
  3. example.py +63 -0
  4. requirements.txt +3 -1
QuALITY.v1.0.1.htmlstripped.dev ADDED
The diff for this file is too large to render.
 
app.py CHANGED
@@ -1,63 +1,457 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

  """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content

-         response += token
-         yield response

  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
  """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )


- if __name__ == "__main__":
-     demo.launch()
 
+ import copy
+ import datetime
+ import json
+ import os
+ import re
+ import string
+ import time
+
  import gradio as gr
+ import openai
+ import google.generativeai as genai
+
+
+ openai_key = os.environ.get('OPEN_AI_KEY')
+ gpt_client = openai.OpenAI(api_key=openai_key)
+
+ gemini_key = os.environ.get('GEMINI_API_KEY')
+ genai.configure(api_key=gemini_key)
+
+
+ def query_gpt_model(
+     prompt: str,
+     llm: str = 'gpt-3.5-turbo-1106',
+     temperature: float = 0.0,
+     max_decode_steps: int = 512,
+     seconds_to_reset_tokens: float = 30.0,
+ ) -> str:
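+     # Retry loop: rate-limit errors wait out the token window
+     # (seconds_to_reset_tokens); other API errors retry after 5 seconds.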
+     while True:
+         try:
+             raw_response = gpt_client.chat.completions.with_raw_response.create(
+                 model=llm,
+                 max_tokens=max_decode_steps,
+                 temperature=temperature,
+                 messages=[
+                     {'role': 'user', 'content': prompt},
+                 ]
+             )
+             completion = raw_response.parse()
+             return completion.choices[0].message.content
+         except openai.RateLimitError as e:
+             print(f'{datetime.datetime.now()}: query_gpt_model: RateLimitError {e.message}: {e}')
+             time.sleep(seconds_to_reset_tokens)
+         except openai.APIError as e:
+             print(f'{datetime.datetime.now()}: query_gpt_model: APIError {e.message}: {e}')
+             print(f'{datetime.datetime.now()}: query_gpt_model: Retrying after 5 seconds...')
+             time.sleep(5)
+
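+ # Gemini safety settings: relax the default filters to BLOCK_ONLY_HIGH,
+ # presumably so long fiction passages are less likely to trip a refusal.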
+ safety_settings = [
+     {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
+     {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
+     {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
+     {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
+ ]
+
+ def query_gemini_model(
+     prompt: str,
+     llm: str = 'gemini-pro',
+     retries: int = 10,
+ ) -> str:
+     model = genai.GenerativeModel(llm)
+     while retries > 0:
+         try:
+             response = model.generate_content(prompt, safety_settings=safety_settings)
+             text_response = response.text.replace("**", "")
+             return text_response
+         except Exception as e:
+             print(f'{datetime.datetime.now()}: query_gemini_model: Error: {e}')
+             print(f'{datetime.datetime.now()}: query_gemini_model: Retrying after 5 seconds...')
+             retries -= 1
+             time.sleep(5)
+
+
+ def query_model(
+     prompt: str,
+     model_name: str = 'gemini-pro',
+ ) -> str:
+     model_type = model_name.split('-')[0]
+     if model_type == "gpt":
+         return query_gpt_model(prompt, llm=model_name)
+     elif model_type == "gemini":
+         return query_gemini_model(prompt, llm=model_name)
+     else:
+         raise ValueError(f'Unexpected model_name: {model_name}')
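+ # Usage sketch: the prefix before the first '-' picks the backend, e.g.
+ #   query_model("Hello", model_name='gpt-3.5-turbo-1106')  # -> OpenAI
+ #   query_model("Hello")                                   # -> Gemini (default)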
+
+ # Load QuALITY dataset
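+ # The dev split is JSON Lines: each line is one article record whose questions,
+ # options, and labels are flattened below into parallel per-article lists.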
+
+ _ONE2ONE_FIELDS = (
+     'article',
+     'article_id',
+     'set_unique_id',
+     'writer_id',
+     'source',
+     'title',
+     'topic',
+     'url',
+     'author',
+ )
+
+ quality_dev = []
+ with open('QuALITY.v1.0.1.htmlstripped.dev', 'r') as f:
+     for line in f.readlines():
+         j = json.loads(line)
+         fields = {k: j[k] for k in _ONE2ONE_FIELDS}
+         fields.update({
+             'questions': [q['question'] for q in j['questions']],
+             'question_ids': [q['question_unique_id'] for q in j['questions']],
+             'difficults': [q['difficult'] for q in j['questions']],
+             'options': [q['options'] for q in j['questions']],
+         })
+
+         fields.update({
+             'gold_labels': [q['gold_label'] for q in j['questions']],
+             'writer_labels': [q['writer_label'] for q in j['questions']],
+         })
+
+         quality_dev.append(fields)
+
+
+ # Helper functions
+ all_lowercase_letters = string.ascii_lowercase  # "abcd...xyz"
+ bracketed_lowercase_letters_set = set(
+     [f"({l})" for l in all_lowercase_letters]
+ )  # {"(a)", ...}
+ bracketed_uppercase_letters_set = set(
+     [f"({l.upper()})" for l in all_lowercase_letters]
+ )  # {"(A)", ...}
+
+ choices = ['(A)', '(B)', '(C)', '(D)']
+
+ def get_index_from_symbol(answer):
+     """Get the index from the letter symbols A, B, C, D, to extract answer texts.
+
+     Args:
+         answer (str): the string of answer like "(B)".
+
+     Returns:
+         index (int): how far the given choice is from "a", like 1 for answer "(B)".
+     """
+     answer = str(answer).lower()
+     # extract the choice letter from within brackets
+     if answer in bracketed_lowercase_letters_set:
+         answer = re.findall(r"\(.*?\)", answer)[0][1]
+     index = ord(answer) - ord("a")
+     return index
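+ # e.g. get_index_from_symbol("(B)") -> 1, i.e. one choice past "(a)".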
+
+ def count_words(text):
+     """Simple word counting."""
+     return len(text.split())
+
+ def quality_gutenberg_parser(raw_article):
+     """Parse Gutenberg articles in the QuALITY dataset."""
+     lines = []
+     previous_line = None
+     for i, line in enumerate(raw_article.split('\n')):
+         line = line.strip()
+         original_line = line
+         if line == '':
+             if previous_line == '':
+                 line = '\n'
+             else:
+                 previous_line = original_line
+                 continue
+         previous_line = original_line
+         lines.append(line)
+     return ' '.join(lines)
+
+
+ # ReadAgent (1) Episode Pagination
+ prompt_pagination_template = """
+ You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.
+ Numbered labels are in angle brackets. For example, if the label number is 19, it shows as <19> in text.
+ Please choose one label where it is natural to break reading.
+ Such a point can be a scene transition, the end of a dialogue, the end of an argument, a narrative transition, etc.
+ Please answer with the break point label and explain.
+ For example, if <57> is a good point to break, answer with \"Break point: <57>\n Because ...\"
+
+ Passage:
+
+ {0}
+ {1}
+ {2}

  """
+
+ def parse_pause_point(text):
+     text = text.strip("Break point: ")
+     if not text.startswith('<'):
+         return None
+     for i, c in enumerate(text):
+         if c == '>':
+             if text[1:i].isnumeric():
+                 return int(text[1:i])
+             else:
+                 return None
+     return None
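+ # e.g. parse_pause_point("Break point: <57> Because ...") -> 57; anything malformed -> None.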
+
+
+ def quality_pagination(example,
+                        model_name='gemini-pro',
+                        word_limit=600,
+                        start_threshold=280,
+                        max_retries=10,
+                        verbose=True,
+                        allow_fallback_to_last=True):
+     article = example['article']
+     title = example['title']
+     text_output = f"[Pagination][Article {title}]" + '\n\n'
+     paragraphs = quality_gutenberg_parser(article).split('\n')
+
+     i = 0
+     pages = []
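+     # Grow a window of paragraphs up to word_limit words, tagging candidate
+     # break points past start_threshold as <j>, then ask the model for a natural pause.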
+     while i < len(paragraphs):
+         preceding = "" if i == 0 else "...\n" + '\n'.join(pages[-1])
+         passage = [paragraphs[i]]
+         wcount = count_words(paragraphs[i])
+         j = i + 1
+         while wcount < word_limit and j < len(paragraphs):
+             wcount += count_words(paragraphs[j])
+             if wcount >= start_threshold:
+                 passage.append(f"<{j}>")
+             passage.append(paragraphs[j])
+             j += 1
+         passage.append(f"<{j}>")
+         end_tag = "" if j == len(paragraphs) else paragraphs[j] + "\n..."
+
+         pause_point = None
+         if wcount < 350:
+             pause_point = len(paragraphs)
+         else:
+             prompt = prompt_pagination_template.format(preceding, '\n'.join(passage), end_tag)
+             response = query_model(prompt=prompt, model_name=model_name).strip()
+             pause_point = parse_pause_point(response)
+             if pause_point and (pause_point <= i or pause_point > j):
+                 # process += f"prompt:\n{prompt},\nresponse:\n{response}\n"
+                 # process += f"i:{i} j:{j} pause_point:{pause_point}" + '\n'
+                 pause_point = None
+         if pause_point is None:
+             if allow_fallback_to_last:
+                 pause_point = j
+             else:
+                 raise ValueError(f"prompt:\n{prompt},\nresponse:\n{response}\n")
+
+         page = paragraphs[i:pause_point]
+         pages.append(page)
+         text_output += f"Paragraph {i}-{pause_point-1}: {page}\n\n"
+         i = pause_point
+     text_output += f"\n\n[Pagination] Done with {len(pages)} pages"
+     return pages, text_output
+
+ # pages = quality_pagination(example)
+
+
+ # ReadAgent (2) Memory Gisting
+ prompt_shorten_template = """
+ Please shorten the following passage.
+ Just give me a shortened version. DO NOT explain your reason.
+
+ Passage:
+ {}
+
  """

+ def quality_gisting(example, pages, model_name, word_limit=600, start_threshold=280, verbose=True):
+     article = example['article']
+     title = example['title']
+     word_count = count_words(article)
+     text_output = f"[Gisting][Article {title}], {word_count} words\n\n"
+
+     shortened_pages = []
+     for i, page in enumerate(pages):
+         prompt = prompt_shorten_template.format('\n'.join(page))
+         response = query_model(prompt, model_name)
+         shortened_text = response.strip()
+         shortened_pages.append(shortened_text)
+         text_output += "[gist] page {}: {}\n\n".format(i, shortened_text)
+     shortened_article = '\n'.join(shortened_pages)
+     gist_word_count = count_words(shortened_article)
+     text_output += '\n\n' + f"Shortened article:\n{shortened_article}\n\n"
+     output = copy.deepcopy(example)
+     output.update({'title': title, 'word_count': word_count, 'gist_word_count': gist_word_count, 'shortened_pages': shortened_pages, 'pages': pages})
+     text_output += f"\n\ncompression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})"
+     return output, text_output

+ # example_with_gists = quality_gisting(example, pages)


+ # ReadAgent (3) Look-Up
+ prompt_lookup_template = """
+ The following text is what you remembered from reading an article and a multiple choice question related to it.
+ You may read 1 to 6 page(s) of the article again to refresh your memory and prepare yourself for the question.
+ Please respond with which page(s) you would like to read.
+ For example, if you only need to read Page 8, respond with \"I want to look up Page [8] to ...\";
+ if you would like to read Pages 7 and 12, respond with \"I want to look up Page [7, 12] to ...\";
+ if you would like to read Pages 2, 3, 7, 15 and 18, respond with \"I want to look up Page [2, 3, 7, 15, 18] to ...\";
+ if you would like to read Pages 3, 4, 5, 12, 13 and 16, respond with \"I want to look up Page [3, 4, 5, 12, 13, 16] to ...\".
+ DO NOT select more pages if you don't need to.
+ DO NOT answer the question yet.
+
+ Text:
+ {}
+
+ Question:
+ {}
+ {}
+
+ Take a deep breath and tell me: Which page(s) would you like to read again?
  """
+
+ prompt_answer_template = """
+ Read the following article and answer a multiple choice question.
+ For example, if (C) is correct, answer with \"Answer: (C) ...\"
+
+ Article:
+ {}
+
+ Question:
+ {}
+ {}
+
  """
+
+ def quality_parallel_lookup(example, model_name='gemini-pro', verbose=True):
+     preprocessed_pages = example['pages']
+     article = example['article']
+     title = example['title']
+     word_count = example['word_count']
+     gist_word_count = example['gist_word_count']
+     pages = example['pages']
+     shortened_pages = example['shortened_pages']
+     questions = example['questions']
+     options = example['options']
+     gold_labels = example['gold_labels']  # numerical [1, 2, 3, 4]
+
+     text_outputs = [f"[Look-Up][Article {title}] {word_count} words"]
+
+     model_choices = []
+     lookup_page_ids = []
+
+     shortened_pages_pidx = []
+     for i, shortened_text in enumerate(shortened_pages):
+         # NOTE: the exact page-label format is an assumption; the gists need
+         # page numbers so the model can cite pages to look up.
+         shortened_pages_pidx.append("\nPage {}:\n".format(i) + shortened_text)
+     shortened_article = '\n'.join(shortened_pages_pidx)
+
+     expanded_gist_word_counts = []
+
+     for i, label in enumerate(gold_labels):
+         # only test one question (index 1) for the demo
+         if i != 1:
+             continue
+         q = questions[i]
+         text_output = f"question {i}: {q}" + '\n\n'
+         options_i = [f"{ol} {o}" for ol, o in zip(choices, options[i])]
+         text_output += "options: " + "\n".join(options_i)
+         text_output += '\n\n'
+         prompt_lookup = prompt_lookup_template.format(shortened_article, q, '\n'.join(options_i))
+
+         page_ids = []
+
+         response = query_model(prompt=prompt_lookup, model_name=model_name).strip()
+
+         try: start = response.index('[')
+         except ValueError: start = len(response)
+         try: end = response.index(']')
+         except ValueError: end = 0
+         if start < end:
+             page_ids_str = response[start+1:end].split(',')
+             page_ids = []
+             for p in page_ids_str:
+                 if p.strip().isnumeric():
+                     page_id = int(p)
+                     if page_id < 0 or page_id >= len(pages):
+                         text_output += f"Skip invalid page number: {page_id}\n\n"
+                     else:
+                         page_ids.append(page_id)
+
+         text_output += "Model chose to look up page {}\n\n".format(page_ids)
+
+         # Memory expansion after look-up, replacing the target shortened page with the original page
+         expanded_shortened_pages = shortened_pages[:]
+         if len(page_ids) > 0:
+             for page_id in page_ids:
+                 expanded_shortened_pages[page_id] = '\n'.join(pages[page_id])
+
+         expanded_shortened_article = '\n'.join(expanded_shortened_pages)
+         expanded_gist_word_count = count_words(expanded_shortened_article)
+         text_output += "Expanded shortened article:\n" + expanded_shortened_article + '\n\n'
+         prompt_answer = prompt_answer_template.format(expanded_shortened_article, q, '\n'.join(options_i))
+
+         model_choice = None
+         response = query_model(prompt=prompt_answer, model_name=model_name)
+         response = response.strip()
+         for j, choice in enumerate(choices):
+             if response.startswith(f"Answer: {choice}") or response.startswith(f"Answer: {choice[1]}"):
+                 model_choice = j + 1
+                 break
+         is_correct = 1 if model_choice == label else 0
+         model_prediction = choices[model_choice - 1] if model_choice else 'N/A'
+         text_output += f"reference answer: {choices[label - 1]}, model prediction: {model_prediction}, is_correct: {is_correct}" + '\n\n'
+         text_output += f"compression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})" + '\n\n'
+         text_output += f"compression rate after look-up {round(100.0 - expanded_gist_word_count/word_count*100, 2)}% ({expanded_gist_word_count}/{word_count})" + '\n\n'
+         text_output += '\n\n'
+         text_outputs.append(text_output)
+     return text_outputs
+
+
+ def query_model_with_quality(
+     index: int,
+     model_name: str = 'gemini-pro'
+ ):
+     example = quality_dev[index]
+     pages, pagination = quality_pagination(example, model_name)
+     print('Finish Pagination.')
+     example_with_gists, gisting = quality_gisting(example, pages, model_name)
+     print('Finish Gisting.')
+     answers = quality_parallel_lookup(example_with_gists, model_name)
+     return prompt_pagination_template, pagination, prompt_shorten_template, gisting, prompt_lookup_template, '\n\n'.join(answers)
+
+
+ llm_api_options = ['gemini-pro', 'gemini-1.5-flash', 'gpt-3.5-turbo-1106']
+
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # A Human-Inspired Reading Agent with Gist Memory of Very Long Contexts
+         """)
+     with gr.Tab('ReadAgent (QuALITY)'):
+         llm_options = gr.Radio(llm_api_options, label="Backend LLM API", value='gemini-pro')
+         with gr.Row():
+             with gr.Column():
+                 index = gr.Dropdown(list(range(len(quality_dev))), value=13, label="QuALITY Index")
+                 button = gr.Button("Execute")
+         prompt_pagination = gr.Textbox(label="Episode Pagination Prompt Template", lines=5)
+         pagination_results = gr.Textbox(label="Episode Pagination", lines=20)
+         prompt_gisting = gr.Textbox(label="Memory Gisting Prompt Template", lines=5)
+         gisting_results = gr.Textbox(label="Memory Gisting", lines=20)
+         prompt_lookup = gr.Textbox(label="Parallel Lookup Prompt Template", lines=5)
+         lookup_qa_results = gr.Textbox(label="Parallel Lookup and QA", lines=20)
+
+         button.click(
+             fn=query_model_with_quality,
+             inputs=[
+                 index,
+                 llm_options
+             ],
+             outputs=[
+                 prompt_pagination, pagination_results,
+                 prompt_gisting, gisting_results,
+                 prompt_lookup, lookup_qa_results,
+             ]
+         )


+ if __name__ == '__main__':
+     demo.launch()
example.py ADDED
@@ -0,0 +1,63 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ """
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+ """
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     messages = [{"role": "system", "content": system_message}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
+
+     messages.append({"role": "user", "content": message})
+
+     response = ""
+
+     for message in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
+         temperature=temperature,
+         top_p=top_p,
+     ):
+         token = message.choices[0].delta.content
+
+         response += token
+         yield response
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt CHANGED
@@ -1 +1,3 @@
- huggingface_hub==0.22.2
+ huggingface_hub==0.22.2
+ openai==1.37.0
+ google-generativeai==0.7.2
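
To recreate the environment locally (a minimal sketch, assuming a standard pip setup):

pip install -r requirements.txt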