Manoj21k committed
Commit 7bf90b9 · 1 Parent(s): 66fa80f

Create app.py

Files changed (1): app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
+ import os
+ import json
+ import re
+
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer
+
+ # Hugging Face retrieval models and the ANN index library
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+ import hnswlib
+
+ from easyllm.clients import huggingface
+
+ # Configure easyllm's Hugging Face client: Llama 2 prompt format and API token
+ huggingface.prompt_builder = "llama2"
+ huggingface.api_key = os.environ["HUGGINGFACE_TOKEN"]
+
+ # Constants
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = 4000
+ EMBED_DIM = 1024
+ K = 10  # number of nearest neighbours to retrieve
+ EF = 100  # HNSW query-time ef parameter
+ SEARCH_INDEX = "search_index.bin"
+ EMBEDDINGS_FILE = "embeddings.npy"
+ DOCUMENT_DATASET = "chunked_data.parquet"
+ COSINE_THRESHOLD = 0.7  # minimum cosine similarity for a retrieved chunk to be kept
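+
+ # DEFAULT_SYSTEM_PROMPT is used by the UI below but never defined in this commit;
+ # the value here is an assumed placeholder, not the author's original prompt.
+ DEFAULT_SYSTEM_PROMPT = (
+     "You are a helpful assistant that answers questions about the Hugging Face "
+     "PEFT documentation using only the provided context."
+ )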
+
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Running on device:", torch_device)
+ print("CPU threads:", torch.get_num_threads())
+
+ model_id = "meta-llama/Llama-2-70b-chat-hf"
+ biencoder = SentenceTransformer("intfloat/e5-large-v2", device=torch_device)
+ cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device=torch_device)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
+
+ # Initialize Streamlit app
+ st.title("PEFT Docs QA Chatbot")
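+
+ # get_completion() is called throughout this file but never defined in the commit.
+ # This is a hedged sketch assuming easyllm's OpenAI-style ChatCompletion interface:
+ # with stream=True it returns (index, chunk) pairs whose chunks carry a "delta" dict,
+ # otherwise it returns the completed message text.
+ def get_completion(prompt, system_prompt=None, stream=False,
+                    max_new_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=0.2,
+                    top_p=0.95, top_k=50):
+     # top_k is accepted for parity with the callers; whether easyllm forwards it
+     # to the backend is not assumed here, so this sketch leaves it unused.
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.append({"role": "user", "content": prompt})
+     response = huggingface.ChatCompletion.create(
+         model=model_id,
+         messages=messages,
+         temperature=max(temperature, 1e-2),  # the backend rejects temperature == 0
+         top_p=top_p,
+         max_tokens=max_new_tokens,
+         stream=stream,
+     )
+     if stream:
+         return enumerate(response)
+     return response["choices"][0]["message"]["content"]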
+
+ # Build the final QA prompt by stuffing the retrieved chunks into the context
+ def create_qa_prompt(query, relevant_chunks):
+     stuffed_context = " ".join(relevant_chunks)
+     return f"""\
+ Use the following pieces of context to answer the question at the end. \
+ If you don't know the answer, just say that you don't know; don't try to make up an answer. \
+ Keep the answer short and succinct.
+
+ Context: {stuffed_context}
+ Question: {query}
+ Helpful Answer: \
+ """
+
+ # Run the full RAG pipeline for one user turn and return the updated history
+ def generate_response(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
+     if max_new_tokens > MAX_MAX_NEW_TOKENS:
+         raise ValueError(f"max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}")
+     history = history_with_input[:-1]
+     if len(history) > 0:
+         condensed_query = generate_condensed_query(message, history)
+         print(f"{condensed_query=}")
+     else:
+         condensed_query = message
+     query_embedding = create_query_embedding(condensed_query)
+     relevant_chunks = find_nearest_neighbors(query_embedding)
+     reranked_relevant_chunks = rerank_chunks_with_cross_encoder(condensed_query, relevant_chunks)
+     qa_prompt = create_qa_prompt(condensed_query, reranked_relevant_chunks)
+     print(f"{qa_prompt=}")
+     generator = get_completion(
+         qa_prompt,
+         system_prompt=system_prompt,
+         stream=True,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+     )
+
+     # Accumulate streamed tokens into the last history entry
+     output = ""
+     for idx, response in generator:
+         token = response["choices"][0]["delta"].get("content", "") or ""
+         output += token
+         if idx == 0:
+             history.append((message, output))
+         else:
+             history[-1] = (message, output)
+
+     history = [
+         (wrap_html_code(user.strip()), wrap_html_code(bot.strip()))
+         for user, bot in history
+     ]
+     return history
+
+ # Count prompt tokens so the app can enforce MAX_INPUT_TOKEN_LENGTH
+ def get_input_token_length(message, chat_history, system_prompt):
+     prompt = get_prompt(message, chat_history, system_prompt)
+     input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
+     return input_ids.shape[-1]
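+
+ # get_prompt() is referenced above but missing from this commit. This sketch assumes
+ # a simple Llama-2-style flattening of the conversation; the author's original helper
+ # may have formatted the prompt differently.
+ def get_prompt(message, chat_history, system_prompt):
+     texts = [f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
+     for user_input, response in chat_history:
+         texts.append(f"{user_input} [/INST] {response} [INST] ")
+     texts.append(message)
+     return "".join(texts)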
+
+ # Rewrite a follow-up question into a standalone query using the chat history
+ def generate_condensed_query(query, history):
+     chat_history = ""
+     for turn in history:
+         chat_history += f"Human: {turn[0]}\n"
+         chat_history += f"Assistant: {turn[1]}\n"
+
+     condense_question_prompt = create_condense_question_prompt(query, chat_history)
+     condensed_question = json.loads(get_completion(condense_question_prompt, max_new_tokens=64, temperature=0))
+     return condensed_question["question"]
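+
+ # create_condense_question_prompt() is called above but missing from this commit.
+ # Assumed sketch: the caller does json.loads(...)["question"], so the prompt must
+ # ask the model to reply with JSON containing a "question" key.
+ def create_condense_question_prompt(question, chat_history):
+     return f"""\
+ Given the following conversation and a follow-up question, rephrase the follow-up \
+ question to be a standalone question. Respond only with JSON in the form \
+ {{"question": "<standalone question>"}}.
+
+ Chat History:
+ {chat_history}
+ Follow-up Input: {question}
+ """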
+
+ # Load a prebuilt HNSW index from disk
+ def load_hnsw_index(index_file):
+     index = hnswlib.Index(space="ip", dim=EMBED_DIM)
+     index.load_index(index_file)
+     return index
+
+ # Build an HNSW index over the precomputed chunk embeddings
+ def create_hnsw_index(embeddings_file, M=16, efC=100):
+     embeddings = np.load(embeddings_file)
+     num_dim = embeddings.shape[1]
+     ids = np.arange(embeddings.shape[0])
+     index = hnswlib.Index(space="ip", dim=num_dim)
+     index.init_index(max_elements=embeddings.shape[0], ef_construction=efC, M=M)
+     index.add_items(embeddings, ids)
+     return index
+
+ # Embed the query with the bi-encoder; embeddings are normalized, so inner
+ # product equals cosine similarity
+ def create_query_embedding(query):
+     embedding = biencoder.encode([query], normalize_embeddings=True)[0]
+     return embedding
+
+ # Query the HNSW index and keep only chunks above the similarity threshold
+ def find_nearest_neighbors(query_embedding):
+     search_index.set_ef(EF)
+     labels, distances = search_index.knn_query(query_embedding, k=K)
+     # hnswlib's "ip" space reports distance = 1 - inner product, so
+     # (1 - distance) is the cosine similarity for normalized embeddings
+     labels = [label for label, distance in zip(labels[0], distances[0]) if (1 - distance) >= COSINE_THRESHOLD]
+     relevant_chunks = data_df.iloc[labels]["chunk_content"].tolist()
+     return relevant_chunks
+
+ # Rerank retrieved chunks by cross-encoder relevance score, best first
+ def rerank_chunks_with_cross_encoder(query, chunks):
+     pairs = [(query, chunk) for chunk in chunks]
+     scores = cross_encoder.predict(pairs)
+     sorted_chunks = [chunk for _, chunk in sorted(zip(scores, chunks), reverse=True)]
+     return sorted_chunks
+
+ # Fence text containing HTML-like tags so the chat display shows them verbatim
+ def wrap_html_code(text):
+     pattern = r"<.*?>"
+     matches = re.findall(pattern, text)
+     if len(matches) > 0:
+         return f"```{text}```"
+     else:
+         return text
+
+ # Build the HNSW index for the PEFT docs from the saved embeddings
+ # (swap in load_hnsw_index(SEARCH_INDEX) to reuse a saved index instead)
+ search_index = create_hnsw_index(EMBEDDINGS_FILE)
+ data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
+
+ # Streamlit UI
+ st.markdown("Welcome to the PEFT Docs QA Chatbot.")
+ message = st.text_input("You:", "")
+
+ # Streamlit reruns the script on every interaction, so the conversation history
+ # must live in session state to survive across turns
+ if "history_with_input" not in st.session_state:
+     st.session_state.history_with_input = []
+
+ system_prompt = st.text_area("System prompt", DEFAULT_SYSTEM_PROMPT)
+ max_new_tokens = st.slider("Max new tokens", 1, MAX_MAX_NEW_TOKENS, DEFAULT_MAX_NEW_TOKENS)
+ temperature = st.slider("Temperature", 0.1, 4.0, 0.2, 0.1)
+ top_p = st.slider("Top-p (nucleus sampling)", 0.05, 1.0, 0.95, 0.05)
+ top_k = st.slider("Top-k", 1, 1000, 50)
+
+ if st.button("Submit"):
+     if message:
+         try:
+             # generate_response drops the last history entry via [:-1], so the
+             # pending turn is appended before the call; it returns the updated
+             # history as a single list
+             st.session_state.history_with_input = generate_response(
+                 message,
+                 st.session_state.history_with_input + [(message, "")],
+                 system_prompt, max_new_tokens, temperature, top_p, top_k,
+             )
+             st.write("Chatbot:", st.session_state.history_with_input[-1][1])
+         except Exception as e:
+             st.error(f"An error occurred: {e}")
+     else:
+         st.warning("Please enter a message.")
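+
+ # Note: get_input_token_length() and MAX_INPUT_TOKEN_LENGTH are defined above but
+ # never enforced in this commit. A minimal guard, assuming the sketched get_prompt(),
+ # could run before Submit calls generate_response:
+ #
+ #     if get_input_token_length(message, st.session_state.history_with_input,
+ #                               system_prompt) > MAX_INPUT_TOKEN_LENGTH:
+ #         st.error("Input is too long; clear the history or shorten the message.")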
+
+ if st.button("Retry"):
+     if st.session_state.history_with_input:
+         # Regenerate the answer for the most recent user message
+         last_user_message = st.session_state.history_with_input[-1][0]
+         st.session_state.history_with_input = generate_response(
+             last_user_message,
+             st.session_state.history_with_input,
+             system_prompt, max_new_tokens, temperature, top_p, top_k,
+         )
+         st.write("Chatbot:", st.session_state.history_with_input[-1][1])
+     else:
+         st.warning("No previous message to retry.")
+
+ if st.button("Undo"):
+     if st.session_state.history_with_input:
+         # Drop the last turn and surface its user message for editing
+         last_message, _ = st.session_state.history_with_input.pop()
+         st.text_area("You:", last_message, height=50)
+     else:
+         st.warning("No previous message to undo.")
+
+ if st.button("Clear"):
+     # Widget values persist on their own across reruns; reassigning the local
+     # variables would have no effect, so only the conversation history is reset
+     st.session_state.history_with_input = []
+
+ st.sidebar.markdown(
+     "This is a Streamlit app for the PEFT Docs QA Chatbot. Enter your message, "
+     "configure advanced options, and interact with the chatbot."
+ )