IC4T committed on
Commit 4c2dd0f • 1 Parent(s): db57198
Files changed (3)
  1. .env +6 -0
  2. app.py +261 -4
  3. requirements.txt +17 -0
.env ADDED
@@ -0,0 +1,6 @@
+ PERSIST_DIRECTORY=db
+ MODEL_TYPE=dolly-v2-3b
+ MODEL_PATH=databricks/dolly-v2-3b
+ EMBEDDINGS_MODEL_NAME=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
+ MODEL_N_CTX=1000
+ TARGET_SOURCE_CHUNKS=4
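
These six values are read once at startup by app.py via python-dotenv. A minimal sketch of the loading pattern, assuming .env sits in the working directory (the defaults shown mirror the fallbacks app.py uses where it has them, and are otherwise illustrative assumptions):

    import os
    from dotenv import load_dotenv

    load_dotenv()  # pulls the key=value pairs above into os.environ

    model_path = os.environ.get("MODEL_PATH")  # a HF hub id such as databricks/dolly-v2-3b, or a local checkout
    model_n_ctx = int(os.environ.get("MODEL_N_CTX", 1000))  # env values arrive as strings, so cast explicitly
    target_source_chunks = int(os.environ.get("TARGET_SOURCE_CHUNKS", 4))  # k for the retriever
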
app.py CHANGED
@@ -1,7 +1,264 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ # Disclaimer: This code is not written by me. It is taken from https://github.com/imartinez/privateGPT/pull/91.
+ # All credit goes to `vnk8071`, as I mentioned in the video.
+ # The code was still in the pull request while I was creating the video, so I made some modifications to get it working locally.
+ import os
+ import sys
+
+ # Install the locally patched langchain checkout (presumably where the PromptTemplate
+ # translation extensions used below come from; they are not in stock langchain).
+ os.system('pip install -e ./langchain')
+
  import gradio as gr
+ from dotenv import load_dotenv
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.chains import RetrievalQA
+ from langchain.embeddings import LlamaCppEmbeddings
+ from langchain.llms import GPT4All  # used by the "GPT4All" branch below; leaving this commented out made that branch fail
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings  # , SentenceTransformerEmbeddings
+ from langchain import PromptTemplate, LLMChain
+ from langchain.llms import HuggingFacePipeline
+
+ from training.generate import InstructionTextGenerationPipeline, load_model_tokenizer_for_generate
+ # from googletrans import Translator
+ # translator = Translator()
+
+ load_dotenv()
+
+ embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
+ persist_directory = os.environ.get('PERSIST_DIRECTORY')
+
+ model_type = os.environ.get('MODEL_TYPE')
+ model_path = os.environ.get('MODEL_PATH')
+ model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1000))
+ target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
+
+ # Example .env values:
+ # PERSIST_DIRECTORY=db
+ # MODEL_TYPE=dolly-v2-3b
+ # MODEL_PATH=/media/siiva/DataStore/LLMs/cache/dolly-v2-3b
+ # EMBEDDINGS_MODEL_NAME=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
+ # MODEL_N_CTX=1000
+ # TARGET_SOURCE_CHUNKS=4
+
+
+ from constants import CHROMA_SETTINGS
+ # embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+ # persist_directory = "db"
+ # model_type = "dolly-v2-3b"
+ # model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"
+ # target_source_chunks = 3
+ # model_n_ctx = 1000
+
+ embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
+ db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
+ retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
+ # Prepare the LLM
+ callbacks = [StreamingStdOutCallbackHandler()]
+
+ match model_type:
+     case "dolly-v2-3b":
+         model, tokenizer = load_model_tokenizer_for_generate(model_path)
+         llm = HuggingFacePipeline(
+             pipeline=InstructionTextGenerationPipeline(
+                 # Return the full text, because this is what the HuggingFacePipeline expects.
+                 model=model, tokenizer=tokenizer, return_full_text=True,
+                 task="text-generation", max_new_tokens=model_n_ctx))
+     case "GPT4All":
+         llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
+     case _:
+         print(f"Model {model_type} not supported!")
+         sys.exit(1)
+ qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
+
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+
+ def clear_history(request: gr.Request):
+     state = None
+     return ([], state, "")
+
+ def post_process_code(code):
+     sep = "\n```"
+     if sep in code:
+         blocks = code.split(sep)
+         if len(blocks) % 2 == 1:
+             for i in range(1, len(blocks), 2):
+                 blocks[i] = blocks[i].replace("\\_", "_")
+         code = sep.join(blocks)
+     return code
+
+ def post_process_answer(answer, source):
+     answer += f"<br><br>Source: {source}"
+     answer = answer.replace("\n", "<br>")
+     return answer
+
+ def predict(
+     question: str,
+     # system_content: str,
+     # embeddings_model_name: str,
+     # persist_directory: str,
+     # model_type: str,
+     # model_path: str,
+     # model_n_ctx: int,
+     # target_source_chunks: int,
+     chatbot: list = None,
+     history: list = None,
+ ):
+     # Avoid shared mutable default arguments; Gradio passes these in normally.
+     chatbot = chatbot if chatbot is not None else []
+     history = history if history is not None else []
+     # try:
+     # embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+     # persist_directory = "db"
+     # model_type = "dolly-v2-3b"
+     # model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"
+     # target_source_chunks = 3
+     # model_n_ctx = 1000
+
+     # embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
+     # db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
+     # retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
+     # # Prepare the LLM
+     # callbacks = [StreamingStdOutCallbackHandler()]
+
+     # match model_type:
+     #     case "dolly-v2-3b":
+     #         model, tokenizer = load_model_tokenizer_for_generate(model_path)
+     #         llm = HuggingFacePipeline(
+     #             pipeline=InstructionTextGenerationPipeline(
+     #                 # Return the full text, because this is what the HuggingFacePipeline expects.
+     #                 model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation", max_new_tokens=model_n_ctx
+     #         ))
+     #     case "GPT4All":
+     #         llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
+     #     case _:
+     #         print(f"Model {model_type} not supported!")
+     #         sys.exit(1)
+     # qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
+
+     # Get the answer from the chain
+     # prompt = system_content + f"\n Question: {question}"
+     prompt = f"{question}"
+     # res = qa(prompt)
+
+     # dest_language/src_language and .translate() come from the locally patched
+     # langchain installed at the top of this file; stock PromptTemplate has neither.
+     no_input_prompt = PromptTemplate(input_variables=[], template=prompt, dest_language='en')  # src_language='id'
+     no_input_prompt.format()
+
+     query = no_input_prompt.translate()
+
+     # prompt_trans = translator.translate(prompt, src='en', dest='id')
+     # print(prompt_trans.text)
+
+     # result = qa({"question": query, "chat_history": chat_history})
+     llm_response = qa(query)
+
+     answer, docs = llm_response['result'], llm_response['source_documents']
+     no_input_prompt = PromptTemplate(input_variables=[], template=answer, dest_language='id')
+     no_input_prompt.format()
+     answer = no_input_prompt.translate()
+     # answer = post_process_answer(answer, docs)
+     history.append(question)
+     history.append(answer)
+     chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
+     return chatbot, history
+
+     # except Exception as e:
+     #     history.append("")
+     #     answer = server_error_msg + f" (error_code: 503)"
+     #     history.append(answer)
+     #     chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
+     #     return chatbot, history
+
+ def reset_textbox():
+     return gr.update(value="")
+
+ title = """<h1 align="center">Chat with QuGPT 🤖</h1>"""
+
+ # def add_text(history, text):
+ #     history = history + [(text, None)]
+ #     return history, ""
+
+ def bot(history):
+     response = "**That's cool!**"
+     history[-1][1] = response
+     return history
+
+ with gr.Blocks(
+     css="""
+     footer .svelte-1lyswbr {display: none !important;}
+     #col_container {margin-left: auto; margin-right: auto;}
+     #chatbot .wrap.svelte-13f7djk {height: 70vh; max-height: 70vh}
+     #chatbot .message.user.svelte-13f7djk.svelte-13f7djk {width:fit-content; background:orange; border-bottom-right-radius:0}
+     #chatbot .message.bot.svelte-13f7djk.svelte-13f7djk {width:fit-content; padding-left: 16px; border-bottom-left-radius:0}
+     #chatbot .pre {border:2px solid white;}
+     pre {
+         white-space: pre-wrap;       /* Since CSS 2.1 */
+         white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
+         white-space: -pre-wrap;      /* Opera 4-6 */
+         white-space: -o-pre-wrap;    /* Opera 7 */
+         word-wrap: break-word;       /* Internet Explorer 5.5+ */
+     }
+     """
+ ) as demo:
+     gr.HTML(title)
+     with gr.Row():
+         # with gr.Column(elem_id="col_container", scale=0.3):
+         #     with gr.Accordion("Prompt", open=True):
+         #         system_content = gr.Textbox(value="You are QuGPT which built with LangChain and dolly-v2 and sentence-transformer.", show_label=False)
+         #     with gr.Accordion("Config", open=True):
+         #         embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # gr.Textbox(value="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", label="embeddings_model_name")
+         #         persist_directory = "db"  # gr.Textbox(value="db", label="persist_directory")
+         #         model_type = "dolly-v2-3b"  # gr.Textbox(value="dolly-v2-3b", label="model_type")
+         #         model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"  # gr.Textbox(value="/media/siiva/DataStore/LLMs/cache/dolly-v2-3b", label="model_path")
+         #         target_source_chunks = 3
+         #         # gr.Slider(
+         #         #     minimum=1,
+         #         #     maximum=5,
+         #         #     value=2,
+         #         #     step=1,
+         #         #     interactive=True,
+         #         #     label="target_source_chunks",
+         #         # )
+
+         #         model_n_ctx = 1000
+         #         # gr.Slider(
+         #         #     minimum=32,
+         #         #     maximum=4096,
+         #         #     value=1000,
+         #         #     step=32,
+         #         #     interactive=True,
+         #         #     label="model_n_ctx",
+         #         # )
+
+         with gr.Column(elem_id="col_container"):
+             chatbot = gr.Chatbot(elem_id="chatbot", label="QuGPT")
+             question = gr.Textbox(placeholder="Ask something", show_label=False, value="")
+             state = gr.State([])
+             with gr.Row():
+                 with gr.Column():
+                     submit_btn = gr.Button(value="🚀 Send")
+                 with gr.Column():
+                     clear_btn = gr.Button(value="🗑️ Clear history")
+
+     question.submit(
+         predict,
+         # [question, system_content, embeddings_model_name, persist_directory, model_type, model_path, model_n_ctx, target_source_chunks, chatbot, state],
+         [question, chatbot, state],
+         [chatbot, state],
+     )
+     submit_btn.click(
+         predict,
+         # [question, system_content, embeddings_model_name, persist_directory, model_type, model_path, model_n_ctx, target_source_chunks, chatbot, state],
+         [question, chatbot, state],
+         [chatbot, state],
+     )
+     submit_btn.click(reset_textbox, [], [question])
+     clear_btn.click(clear_history, None, [chatbot, state, question])
+     question.submit(reset_textbox, [], [question])
+     # demo.queue(concurrency_count=10, status_update_rate="auto")
+
+     # question.submit(predict, [question, system_content, embeddings_model_name, persist_directory, model_type, model_path, model_n_ctx, target_source_chunks, chatbot, state], [chatbot, state]).then(
+     #     predict, chatbot
+     # )
+
+
+ # demo.launch(server_name=args.server_name, server_port=args.server_port, share=args.share, debug=args.debug)
+ demo.launch()
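
As an aside, the qa chain built above can also be exercised without the Gradio UI. A minimal headless sketch, using only names already defined in app.py (the question string is a placeholder):

    res = qa("What is this document about?")  # placeholder question
    print(res["result"])  # the generated answer
    for doc in res["source_documents"]:  # present because return_source_documents=True
        print(doc.metadata.get("source"))  # the ingested file each chunk came from
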
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ chromadb==0.3.21
+ duckdb==0.7.1
+ googletrans==3.1.0a0
+ gradio==3.28.3
+ gradio_client==0.2.0
+ huggingface-hub==0.13.4
+ pypdf==3.8.1
+ python-dotenv==1.0.0
+ sentence-transformers==2.2.2
+ tiktoken==0.3.3
+ tokenizers==0.13.3
+ torch==2.0.0
+ transformers @ git+https://github.com/huggingface/transformers@ef42c2c487260c2a0111fa9d17f2507d84ddedea
+ unstructured==0.6.2
+ xformers==0.0.19
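
Note that transformers is pinned to an exact upstream commit via git, so installing these requirements (for example with pip install -r requirements.txt) needs git available on the machine. The patched ./langchain checkout that app.py installs with pip install -e is a separate local dependency that this file does not cover.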