frt03 committed
Commit a0656ca · Parent: 55204fd

fix api key

Files changed (1): app.py (+37, -21)
app.py CHANGED
@@ -5,22 +5,17 @@ import os
 import re
 import string
 import time
+from typing import Optional, Any
 
 import gradio as gr
 import openai
 import google.generativeai as genai
 
 
-openai_key = os.environ.get('OPEN_AI_KEY')
-gpt_client = openai.OpenAI(api_key=openai_key)
-
-gemini_key = os.environ.get('GEMINI_API_KEY')
-genai.configure(api_key=gemini_key)
-
-
 def query_gpt_model(
     prompt: str,
     llm: str = 'gpt-3.5-turbo-1106',
+    client: Optional[Any] = None,
     temperature: float = 0.0,
     max_decode_steps: int = 512,
     seconds_to_reset_tokens: float = 30.0,
@@ -28,7 +23,7 @@ def query_gpt_model(
 
   while True:
     try:
-      raw_response = gpt_client.chat.completions.with_raw_response.create(
+      raw_response = client.chat.completions.with_raw_response.create(
         model=llm,
         max_tokens=max_decode_steps,
         temperature=temperature,
@@ -56,8 +51,10 @@ safety_settings=[
 def query_gemini_model(
     prompt: str,
     llm: str = 'gemini-pro',
+    client: Optional[Any] = None,
     retries: int = 10,
 ) -> str:
+  del client
   model = genai.GenerativeModel(llm)
   while True and retries > 0:
     try:
@@ -74,12 +71,13 @@ def query_gemini_model(
 
 def query_model(
     prompt: str,
     model_name: str = 'gemini-pro',
+    client: Optional[Any] = None,
 ) -> str:
   model_type = model_name.split('-')[0]
   if model_type == "gpt":
-    return query_gpt_model(prompt, llm=model_name)
+    return query_gpt_model(prompt, llm=model_name, client=client)
   elif model_type == "gemini":
-    return query_gemini_model(prompt, llm=model_name)
+    return query_gemini_model(prompt, llm=model_name, client=client)
   else:
     raise ValueError('Unexpected model_name: ', model_name)
 
@@ -201,6 +199,7 @@ def parse_pause_point(text):
 
 def quality_pagination(example,
                        model_name='gemini-pro',
+                       client=None,
                        word_limit=600,
                        start_threshold=280,
                        max_retires=10,
@@ -232,7 +231,7 @@ def quality_pagination(example,
       pause_point = len(paragraphs)
     else:
       prompt = prompt_pagination_template.format(preceding, '\n'.join(passage), end_tag)
-      response = query_model(prompt=prompt, model_name=model_name).strip()
+      response = query_model(prompt=prompt, model_name=model_name, client=client).strip()
       pause_point = parse_pause_point(response)
       if pause_point and (pause_point <= i or pause_point > j):
         # process += f"prompt:\n{prompt},\nresponse:\n{response}\n"
@@ -264,7 +263,7 @@ Passage:
 
 """
 
-def quality_gisting(example, pages, model_name, word_limit=600, start_threshold=280, verbose=True):
+def quality_gisting(example, pages, model_name, client=None, word_limit=600, start_threshold=280, verbose=True):
   article = example['article']
   title = example['title']
   word_count = count_words(article)
@@ -273,7 +272,7 @@ def quality_gisting(example, pages, model_name, word_limit=600, start_threshold=
   shortened_pages = []
   for i, page in enumerate(pages):
     prompt = prompt_shorten_template.format('\n'.join(page))
-    response = query_model(prompt, model_name)
+    response = query_model(prompt, model_name, client)
     shortened_text = response.strip()
     shortened_pages.append(shortened_text)
     text_output += "[gist] page {}: {}\n\n".format(i, shortened_text)
@@ -323,7 +322,7 @@ Question:
 
 """
 
-def quality_parallel_lookup(example, verbose=True):
+def quality_parallel_lookup(example, model_name, client, verbose=True):
   preprocessed_pages = example['pages']
   article = example['article']
   title = example['title']
@@ -360,7 +359,7 @@ def quality_parallel_lookup(example, verbose=True):
 
   page_ids = []
 
-  response = query_model(prompt=prompt_lookup).strip()
+  response = query_model(prompt=prompt_lookup, model_name=model_name, client=client).strip()
 
   try: start = response.index('[')
   except ValueError: start = len(response)
@@ -391,7 +390,7 @@ def quality_parallel_lookup(example, verbose=True):
     prompt_answer = prompt_answer_template.format(expanded_shortened_article, q, '\n'.join(options_i))
 
     model_choice = None
-    response = query_model(prompt=prompt_answer)
+    response = query_model(prompt=prompt_answer, model_name=model_name, client=client)
     response = response.strip()
     for j, choice in enumerate(choices):
       if response.startswith(f"Answer: {choice}") or response.startswith(f"Answer: {choice[1]}"):
@@ -408,14 +407,25 @@ def quality_parallel_lookup(example, verbose=True):
 
 def query_model_with_quality(
     index: int,
-    model_name: str = 'gemini-pro'
+    model_name: str = 'gemini-pro',
+    api_key: Optional[str] = None,
 ):
+  # setup api key first
+  client = None
+  model_type = model_name.split('-')[0]
+  if model_type == "gpt":
+    # api_key = os.environ.get('OPEN_AI_KEY')
+    client = openai.OpenAI(api_key=api_key)
+  elif model_type == "gemini":
+    # api_key = os.environ.get('GEMINI_API_KEY')
+    genai.configure(api_key=api_key)
+
   example = quality_dev[index]
-  pages, pagination = quality_pagination(example, model_name)
+  pages, pagination = quality_pagination(example, model_name, client)
   print('Finish Pagination.')
-  example_with_gists, gisting = quality_gisting(example, pages, model_name)
+  example_with_gists, gisting = quality_gisting(example, pages, model_name, client)
   print('Finish Gisting.')
-  answers = quality_parallel_lookup(example_with_gists)
+  answers = quality_parallel_lookup(example_with_gists, model_name, client)
   return prompt_pagination_template, pagination, prompt_shorten_template, gisting, prompt_lookup_template, '\n\n'.join(answers)
 
 
@@ -428,6 +438,11 @@ with gr.Blocks() as demo:
   """)
   with gr.Tab('ReadAgent (QuALITY)'):
     llm_options = gr.Radio(llm_api_options, label="Backend LLM API", value='gemini-pro')
+    llm_api_key = gr.Textbox(
+        label="Paste your OpenAI API key (sk-...) or Gemini API key",
+        lines=1,
+        type="password",
+    )
     with gr.Row():
       with gr.Column():
         index = gr.Dropdown(list(range(len(quality_dev))), value=13, label="QuALITY Index",)
@@ -443,7 +458,8 @@ with gr.Blocks() as demo:
         fn=query_model_with_quality,
         inputs=[
             index,
-            llm_options
+            llm_options,
+            llm_api_key,
         ],
         outputs=[
            prompt_pagination, pagination_results,
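
For context, the pattern this commit adopts: instead of reading OPEN_AI_KEY and GEMINI_API_KEY from the environment at import time, the app now builds a client per request from the key the user pastes into the new password textbox, and threads it through query_model down to the backend-specific query functions. query_gemini_model accepts client only to keep the two backends call-compatible and discards it immediately (del client), because the Gemini SDK holds its key in process-wide state set by genai.configure. Below is a minimal sketch of the same dispatch in isolation; the make_client helper is illustrative and not part of the commit:

from typing import Any, Optional

import google.generativeai as genai
import openai


def make_client(model_name: str, api_key: Optional[str]) -> Optional[Any]:
  # Illustrative helper (not in the commit): the same branching that
  # query_model_with_quality now performs inline.
  model_type = model_name.split('-')[0]
  if model_type == "gpt":
    # OpenAI carries the key in an explicit client object.
    return openai.OpenAI(api_key=api_key)
  elif model_type == "gemini":
    # Gemini configures the SDK process-wide and returns no client,
    # which is why query_gemini_model simply does `del client`.
    genai.configure(api_key=api_key)
    return None
  raise ValueError('Unexpected model_name: ', model_name)

One wiring detail worth noting: Gradio passes inputs to fn positionally, so the order [index, llm_options, llm_api_key] must line up with the parameters (index, model_name, api_key) of query_model_with_quality; type="password" masks the pasted key in the UI.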