sayakpaul HF staff commited on
Commit
417eb3c
·
verified ·
1 Parent(s): b50a68d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -21
app.py CHANGED
@@ -2,8 +2,6 @@ import os
2
  import json
3
  import re
4
  from sentence_transformers import SentenceTransformer, CrossEncoder
5
- from huggingface_hub import hf_hub_download
6
- from openai import OpenAI
7
  import hnswlib
8
  import numpy as np
9
  from typing import Iterator
@@ -12,8 +10,11 @@ import gradio as gr
12
  import pandas as pd
13
  import torch
14
 
 
15
  from transformers import AutoTokenizer
16
 
 
 
17
  MAX_MAX_NEW_TOKENS = 2048
18
  DEFAULT_MAX_NEW_TOKENS = 1024
19
  MAX_INPUT_TOKEN_LENGTH = 4000
@@ -30,15 +31,11 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
30
  print("Running on device:", torch_device)
31
  print("CPU threads:", torch.get_num_threads())
32
 
 
33
  biencoder = SentenceTransformer("intfloat/e5-large-v2", device=torch_device)
34
  cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device=torch_device)
35
 
36
- model_id = "HuggingFaceH4/zephyr-7b-beta"
37
- client = OpenAI(
38
- base_url=f"https://api-inference.huggingface.co/models/{model_id}/v1",
39
- api_key=os.environ["HUGGINGFACE_TOKEN"],
40
- )
41
- tokenizer = AutoTokenizer.from_pretrained(model_id)
42
 
43
 
44
  def create_qa_prompt(query, relevant_chunks):
@@ -104,17 +101,17 @@ def get_completion(
104
  if system_prompt is not None:
105
  messages.append({"role": "system", "content": system_prompt})
106
  messages.append({"role": "user", "content": prompt})
107
- response = client.chat.completions.create(
108
  model=model,
109
  messages=messages,
110
- # temperature=temperature, # this is the degree of randomness of the model's output
111
  max_tokens=max_new_tokens, # this is the number of new tokens being generated
112
- # top_p=top_p,
113
- # top_k=top_k,
114
  stream=stream,
115
- # debug=debug,
116
  )
117
- return response.choices[0].message.content if not stream else response
118
 
119
 
120
  # load the index for the Diffusers docs
@@ -190,6 +187,8 @@ DESCRIPTION = """
190
  # 🧨 Diffusers Docs QA Chatbot 🤗
191
  """
192
 
 
 
193
  LICENSE = """
194
  <p/>
195
 
@@ -198,10 +197,6 @@ As a derivate work of [Llama-2-70b-chat](https://huggingface.co/meta-llama/Llama
198
  this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-70b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-70b-chat/blob/main/USE_POLICY.md).
199
  """
200
 
201
- if not torch.cuda.is_available():
202
- DESCRIPTION += "This application is almost exactly copied from [smangrul/PEFT-Docs-QA-Chatbot](https://huggingface.co/spaces/smangrul/PEFT-Docs-QA-Chatbot).\n Related code: [pacman100/DHS-LLM-Workshop](https://github.com/pacman100/DHS-LLM-Workshop/blob/main/6_Module/)."
203
-
204
-
205
  def clear_and_save_textbox(message: str) -> tuple[str, str]:
206
  return "", message
207
 
@@ -262,7 +257,7 @@ def generate(
262
 
263
  output = ""
264
  for idx, response in enumerate(generator):
265
- token = response.choices[0].delta.content or ""
266
  output += token
267
  if idx == 0:
268
  history.append((message, output))
@@ -293,7 +288,7 @@ def check_input_token_length(message: str, chat_history: list[tuple[str, str]],
293
  )
294
 
295
 
296
- search_index = load_hnsw_index(SEARCH_INDEX) # create_hnsw_index(EMBEDDINGS_FILE)
297
  data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
298
  with gr.Blocks(css="style.css") as demo:
299
  gr.Markdown(DESCRIPTION)
@@ -467,4 +462,4 @@ with gr.Blocks(css="style.css") as demo:
467
  api_name=False,
468
  )
469
 
470
- demo.queue(max_size=20).launch(debug=True, share=False)
 
2
  import json
3
  import re
4
  from sentence_transformers import SentenceTransformer, CrossEncoder
 
 
5
  import hnswlib
6
  import numpy as np
7
  from typing import Iterator
 
10
  import pandas as pd
11
  import torch
12
 
13
+ from easyllm.clients import huggingface
14
  from transformers import AutoTokenizer
15
 
16
+ huggingface.prompt_builder = "llama2"
17
+ huggingface.api_key = os.environ["HUGGINGFACE_TOKEN"]
18
  MAX_MAX_NEW_TOKENS = 2048
19
  DEFAULT_MAX_NEW_TOKENS = 1024
20
  MAX_INPUT_TOKEN_LENGTH = 4000
 
31
  print("Running on device:", torch_device)
32
  print("CPU threads:", torch.get_num_threads())
33
 
34
+ model_id = "meta-llama/Llama-2-70b-chat-hf"
35
  biencoder = SentenceTransformer("intfloat/e5-large-v2", device=torch_device)
36
  cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device=torch_device)
37
 
38
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
 
 
 
 
 
39
 
40
 
41
  def create_qa_prompt(query, relevant_chunks):
 
101
  if system_prompt is not None:
102
  messages.append({"role": "system", "content": system_prompt})
103
  messages.append({"role": "user", "content": prompt})
104
+ response = huggingface.ChatCompletion.create(
105
  model=model,
106
  messages=messages,
107
+ temperature=temperature, # this is the degree of randomness of the model's output
108
  max_tokens=max_new_tokens, # this is the number of new tokens being generated
109
+ top_p=top_p,
110
+ top_k=top_k,
111
  stream=stream,
112
+ debug=debug,
113
  )
114
+ return response["choices"][0]["message"]["content"] if not stream else response
115
 
116
 
117
  # load the index for the Diffusers docs
 
187
  # 🧨 Diffusers Docs QA Chatbot 🤗
188
  """
189
 
190
+ DESCRIPTION += "This application is almost exactly copied from [smangrul/Diffusers-Docs-QA-Chatbot](https://huggingface.co/spaces/smangrul/Diffusers-Docs-QA-Chatbot).\n Related code: [pacman100/DHS-LLM-Workshop](https://github.com/pacman100/DHS-LLM-Workshop/blob/main/6_Module/)."
191
+
192
  LICENSE = """
193
  <p/>
194
 
 
197
  this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-70b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-70b-chat/blob/main/USE_POLICY.md).
198
  """
199
 
 
 
 
 
200
  def clear_and_save_textbox(message: str) -> tuple[str, str]:
201
  return "", message
202
 
 
257
 
258
  output = ""
259
  for idx, response in enumerate(generator):
260
+ token = response["choices"][0]["delta"].get("content", "") or ""
261
  output += token
262
  if idx == 0:
263
  history.append((message, output))
 
288
  )
289
 
290
 
291
+ search_index = load_hnsw_index(SEARCH_INDEX) # create_hnsw_index(EMBEDDINGS_FILE)
292
  data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
293
  with gr.Blocks(css="style.css") as demo:
294
  gr.Markdown(DESCRIPTION)
 
462
  api_name=False,
463
  )
464
 
465
+ demo.queue(max_size=20).launch(debug=True, share=False)