Darshan-BugendaiTech committed on
Commit
c1bf31e
·
1 Parent(s): 864dad3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +275 -0
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from langchain.llms import HuggingFacePipeline
3
+ import torch
4
+ import bitsandbytes as bnb
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline, BitsAndBytesConfig
6
+
7
+
8
+ from langchain.vectorstores import Chroma
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.document_loaders import TextLoader
12
+ from langchain.document_loaders import UnstructuredExcelLoader
13
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
14
+ from langchain.memory import ConversationBufferWindowMemory
15
+ from langchain.prompts import ChatPromptTemplate
16
+ from langchain.memory import ConversationBufferWindowMemory
17
+ import gradio as gr
18
+ from controller import Controller
19
+
20
# Loading Model
# Quantization settings handed to bitsandbytes through transformers.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Load model weights in 4-bit format
    # The option is spelled `bnb_4bit_compute_dtype`. The previous spelling
    # (`bnb_4bit_compute_type`) is not a recognised keyword, so the float16
    # compute dtype requested here was never applied.
    bnb_4bit_compute_dtype=torch.float16,  # To avoid slow inference as input type into Linear4bit is torch.float16
)

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

# Start from the generation defaults published with the model, then override
# the length and sampling settings used by this app.
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.max_new_tokens = 2000
generation_config.temperature = 0.7
generation_config.do_sample = True

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    generation_config=generation_config,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    # The EOS token id is reused as the padding token id.
    pad_token_id=tokenizer.eos_token_id,
)
# LangChain wrapper around the transformers pipeline; used as the LLM of the
# RetrievalQA chain built in Retriever.retrieve_text.
zephyr_llm = HuggingFacePipeline(pipeline=pipe)

# --------------------------------------------Starting UI part--------------------------------------------
# Configurations
persist_directory = "db"  # directory passed to Chroma for persistence
chunk_size = 150  # chunk_size given to RecursiveCharacterTextSplitter
chunk_overlap = 0  # chunk_overlap given to RecursiveCharacterTextSplitter
55
+
56
class Retriever:
    """Embed an uploaded spreadsheet into Chroma and answer questions over it.

    Attributes are populated lazily: ``embeddings``, ``vectordb`` and
    ``text_retriever`` stay ``None`` until ``create_and_add_embeddings`` has
    been called successfully.
    """

    def __init__(self):
        self.text_retriever = None
        self.vectordb = None
        self.embeddings = None
        # Windowed conversation memory (k=2) handed to the QA chain.
        self.memory = ConversationBufferWindowMemory(k=2, return_messages=True)

    def create_and_add_embeddings(self, file):
        """Load ``file``, split it into chunks and index the chunks in Chroma.

        Args:
            file: Path of the spreadsheet to load with ``UnstructuredExcelLoader``.
        """
        # Make sure the persistence directory exists before Chroma is asked
        # to use it. The shared constant is used so that this path and the
        # one given to Chroma below cannot drift apart.
        os.makedirs(persist_directory, exist_ok=True)

        # NOTE(review): a BGE model is loaded through the Instruct embeddings
        # wrapper — confirm this is intended.
        self.embeddings = HuggingFaceInstructEmbeddings(
            model_name="BAAI/bge-base-en-v1.5",
            model_kwargs={"device": "cuda"},
        )

        loader = UnstructuredExcelLoader(file)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        texts = text_splitter.split_documents(documents)

        self.vectordb = Chroma.from_documents(
            documents=texts,
            embedding=self.embeddings,
            persist_directory=persist_directory,
        )

        # Retrieval is configured with k=3 for each query.
        self.text_retriever = self.vectordb.as_retriever(search_kwargs={"k": 3})

    def retrieve_text(self, query):
        """Answer ``query`` with a RetrievalQA chain over the indexed document.

        Args:
            query: The user's question.

        Returns:
            The value returned by ``RetrievalQA.run`` for the query.

        Raises:
            ValueError: If no document has been embedded yet.
        """
        # Fail early with a clear message instead of letting the chain
        # constructor reject a ``None`` retriever.
        if self.text_retriever is None:
            raise ValueError(
                "No document has been embedded yet; "
                "call create_and_add_embeddings() first."
            )

        prompt_zephyr = ChatPromptTemplate.from_messages([
            ("system", "You are an helpful and harmless AI Assistant who is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user."),
            ("human", "Context: {context}\n <|user|>\n {question}\n<|assistant|>\n"),
        ])

        qa = RetrievalQA.from_chain_type(
            llm=zephyr_llm,
            chain_type="stuff",
            retriever=self.text_retriever,
            return_source_documents=False,
            verbose=False,
            chain_type_kwargs={"prompt": prompt_zephyr},
            memory=self.memory,
        )

        return qa.run(query)
99
+
100
class Controller:
    """Glue between the Gradio callbacks and a ``Retriever`` instance."""

    def __init__(self):
        # Set by embed_document(); None until a file has been processed.
        self.retriever = None
        self.query = ""

    def embed_document(self, file):
        """Create a fresh ``Retriever`` and index the given uploaded file.

        Args:
            file: Uploaded file object exposing a ``name`` attribute that is
                passed on as the path to index, or ``None`` (in which case
                nothing happens).
        """
        if file is None:
            return
        self.retriever = Retriever()
        self.retriever.create_and_add_embeddings(file.name)

    def retrieve(self, query):
        """Return the answer for ``query`` from the current retriever.

        Args:
            query: The user's question.

        Returns:
            The retriever's answer, or a short instruction to upload a
            document when nothing has been indexed yet.
        """
        # Without this guard a question asked before any upload crashed with
        # AttributeError on None.
        if self.retriever is None:
            return "Please upload a document before asking a question."
        return self.retriever.retrieve_text(query)
113
+
114
+
115
# Gradio Demo for trying out the Application
import os
# NOTE(review): this import rebinds the name `Controller`, replacing the
# Controller class defined earlier in this file, and requires a separate
# `controller` module to be importable — confirm which one is intended.
from controller import Controller
import gradio as gr

# NOTE(review): presumably set to keep the Hugging Face tokenizers library
# from using parallelism — confirm the reason.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Colour values; not referenced in the visible part of this file.
colors = ["#64A087", "green", "black"]
122
+
123
# Custom stylesheet passed to gr.Blocks(css=...). Selectors target element ids
# such as #app-title, #select-a-file and #submit-button.
CSS = """
#question input {
    font-size: 16px;
}
#app-title {
    width: 100%;
    margin: auto;
}
#url-textbox {
    padding: 0 !important;
}
#short-upload-box .w-full {
    min-height: 10rem !important;
}

#select-a-file {
    display: block;
    width: 100%;
}
#file-clear {
    padding-top: 2px !important;
    padding-bottom: 2px !important;
    padding-left: 8px !important;
    padding-right: 8px !important;
    margin-top: 10px;
}
.gradio-container .gr-button-primary {
    background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
    border: 1px solid #B0DCCC;
    border-radius: 8px;
    color: #1B8700;
}
.gradio-container.dark button#submit-button {
    background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
    border: 1px solid #B0DCCC;
    border-radius: 8px;
    color: #1B8700
}
table.gr-samples-table tr td {
    border: none;
    outline: none;
}
table.gr-samples-table tr td:first-of-type {
    width: 0%;
}
div#short-upload-box div.absolute {
    display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
    gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
    gap: 0px;
}
gradio-app h2, .gradio-app h2 {
    padding-top: 10px;
}
#answer {
    overflow-y: scroll;
    color: white;
    background: #666;
    border-color: #666;
    font-size: 20px;
    font-weight: bold;
}
#answer span {
    color: white;
}
#answer textarea {
    color:white;
    background: #777;
    border-color: #777;
    font-size: 18px;
}
#url-error input {
    color: red;
}
"""

# Single module-level controller instance used by the callbacks below.
controller = Controller()
203
+
204
+
205
def process_pdf(file):
    """Index the uploaded file (if any) and reveal the chat widgets.

    Args:
        file: The value of the upload component, or ``None``.

    Returns:
        A 4-tuple of ``gr.update(visible=True)`` objects, one per output
        component wired to this callback.
    """
    if file is not None:
        controller.embed_document(file)
    # One visibility update per output component.
    reveal = [gr.update(visible=True) for _ in range(4)]
    return tuple(reveal)
214
+
215
+
216
def respond(message, history):
    """Answer ``message`` and record the exchange in the chat history.

    Args:
        message: The question typed by the user.
        history: The chatbot history list; it is extended in place.

    Returns:
        A pair ``("", history)``: an empty string for the textbox and the
        updated history for the chatbot.
    """
    answer = controller.retrieve(message)
    exchange = (message, answer)
    history.append(exchange)
    return "", history
220
+
221
+
222
def clear_everything():
    """Return three ``None`` values, one per output component to reset."""
    return (None,) * 3
224
+
225
+
226
# Gradio layout and event wiring for the demo.
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# copy of the file that had lost its indentation — confirm against the
# intended layout.
with gr.Blocks(css=CSS, title="") as demo:
    gr.Markdown("# Marketing Email Generator ", elem_id="app-title")
    # The backend loads uploads with UnstructuredExcelLoader, so the UI text
    # asks for an Excel file (it previously mentioned both CSV and PDF).
    gr.Markdown("## Upload an Excel file and ask your query!", elem_id="select-a-file")
    gr.Markdown(
        "Drop your file here 👇",
        elem_id="select-a-file",
    )
    with gr.Row():
        with gr.Column(scale=3):
            upload = gr.File(label="Upload Excel file", type="file")
            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=6):
            chatbot = gr.Chatbot()
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=8):
                    question = gr.Textbox(
                        show_label=False,
                        placeholder="e.g. What is the document about?",
                        lines=1,
                        max_lines=1,
                    ).style(container=False)
                with gr.Column(scale=1, min_width=60):
                    submit_button = gr.Button(
                        "Send your Request 🤖", variant="primary", elem_id="submit-button"
                    )

    # Index the file whenever the upload component changes, then make the
    # chat widgets visible.
    upload.change(
        fn=process_pdf,
        inputs=[upload],
        outputs=[
            question,
            clear_button,
            submit_button,
            chatbot,
        ],
        api_name="upload",
    )
    # Pressing Enter in the textbox and clicking the button both send the question.
    question.submit(respond, [question, chatbot], [question, chatbot])
    submit_button.click(respond, [question, chatbot], [question, chatbot])
    # Reset the upload, the textbox and the chat history.
    clear_button.click(
        fn=clear_everything,
        inputs=[],
        outputs=[upload, question, chatbot],
        api_name="clear",
    )

if __name__ == "__main__":
    demo.launch(enable_queue=False, debug=True, share=False)