MohamedAAK committed on
Commit
51c7ae7
1 Parent(s): 521d025

Update app.py

Files changed (1)
  1. app.py +264 -17
app.py CHANGED
@@ -1,25 +1,272 @@
  import gradio as gr
- import openai
- import os
-
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
-
  def ask(text):

-     tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
-     model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.bfloat16)
-
-     prompt = f"<human>: {text}\n<bot>:"
-     inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
-
-     input_length = inputs.input_ids.shape[1]
-     outputs = model.generate(**inputs, max_new_tokens=48, temperature=0.7,
-                              return_dict_in_generate=True)
-
-     tokens = outputs.sequences[0, input_length:]
-     return tokenizer.decode(tokens)
-
  with gr.Blocks() as server:
      with gr.Tab("LLM Inferencing"):
  import gradio as gr
+ from langchain.document_loaders import PDFPlumberLoader
+ from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
+ from transformers import pipeline
+ from langchain.prompts import PromptTemplate
+ from langchain.chat_models import ChatOpenAI
+ from langchain.vectorstores import Chroma
+ from langchain.chains import RetrievalQA
+ from langchain import HuggingFacePipeline
+ from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.llms import OpenAI
+ from constants import *
  import torch
+ import os
+ import re
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import pipeline
+
+ EMB_INSTRUCTOR_XL = "hkunlp/instructor-xl"
+ EMB_SBERT_MPNET_BASE = "sentence-transformers/all-mpnet-base-v2"
+ LLM_FLAN_T5_XXL = "google/flan-t5-xxl"
+ LLM_FLAN_T5_XL = "google/flan-t5-xl"
+ LLM_FASTCHAT_T5_XL = "lmsys/fastchat-t5-3b-v1.0"
+ LLM_FLAN_T5_SMALL = "google/flan-t5-small"
+ LLM_FLAN_T5_BASE = "google/flan-t5-base"
+ LLM_FLAN_T5_LARGE = "google/flan-t5-large"
+ LLM_FALCON_SMALL = "tiiuae/falcon-7b-instruct"
+
+
+ class PdfQA:
+     def __init__(self, config: dict = {}):
+         self.config = config
+         self.embedding = None
+         self.vectordb = None
+         self.llm = None
+         self.qa = None
+         self.retriever = None
+
+     # The following class methods are useful to create global GPU model instances
+     # This way we don't need to reload models in an interactive app,
+     # and the same model instance can be used across multiple user sessions
+     @classmethod
+     def create_instructor_xl(cls):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         return HuggingFaceInstructEmbeddings(model_name=EMB_INSTRUCTOR_XL, model_kwargs={"device": device})
+
+     @classmethod
+     def create_sbert_mpnet(cls):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         return HuggingFaceEmbeddings(model_name=EMB_SBERT_MPNET_BASE, model_kwargs={"device": device})
+
+     @classmethod
+     def create_flan_t5_xxl(cls, load_in_8bit=False):
+         # Local flan-t5-xxl with 8-bit quantization for inference
+         # Wrap it in HF pipeline for use with LangChain
+         return pipeline(
+             task="text2text-generation",
+             model="google/flan-t5-xxl",
+             max_new_tokens=200,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_flan_t5_xl(cls, load_in_8bit=False):
+         return pipeline(
+             task="text2text-generation",
+             model="google/flan-t5-xl",
+             max_new_tokens=200,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_flan_t5_small(cls, load_in_8bit=False):
+         # Local flan-t5-small for inference
+         # Wrap it in HF pipeline for use with LangChain
+         model = "google/flan-t5-small"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_flan_t5_base(cls, load_in_8bit=False):
+         # Wrap it in HF pipeline for use with LangChain
+         model = "google/flan-t5-base"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_flan_t5_large(cls, load_in_8bit=False):
+         # Wrap it in HF pipeline for use with LangChain
+         model = "google/flan-t5-large"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         return pipeline(
+             task="text2text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_fastchat_t5_xl(cls, load_in_8bit=False):
+         return pipeline(
+             task="text2text-generation",
+             model="lmsys/fastchat-t5-3b-v1.0",
+             max_new_tokens=100,
+             model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
+         )
+
+     @classmethod
+     def create_falcon_instruct_small(cls, load_in_8bit=False):
+         model = "tiiuae/falcon-7b-instruct"
+         tokenizer = AutoTokenizer.from_pretrained(model)
+         hf_pipeline = pipeline(
+             task="text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             trust_remote_code=True,
+             max_new_tokens=100,
+             model_kwargs={
+                 "device_map": "auto",
+                 "load_in_8bit": load_in_8bit,
+                 "max_length": 512,
+                 "temperature": 0.01,
+                 "torch_dtype": torch.bfloat16,
+             }
+         )
+         return hf_pipeline
+
+     def init_embeddings(self) -> None:
+         if self.config["embedding"] == EMB_INSTRUCTOR_XL:
+             # Local INSTRUCTOR-XL embeddings
+             if self.embedding is None:
+                 self.embedding = PdfQA.create_instructor_xl()
+         elif self.config["embedding"] == EMB_SBERT_MPNET_BASE:
+             ## this is for SBERT
+             if self.embedding is None:
+                 self.embedding = PdfQA.create_sbert_mpnet()
+         else:
+             self.embedding = None  ## DuckDb uses sbert embeddings
+             # raise ValueError("Invalid config")
+
+     def init_models(self) -> None:
+         """ Initialize LLM models based on config """
+         load_in_8bit = self.config.get("load_in_8bit", False)
+         if self.config["llm"] == LLM_FLAN_T5_SMALL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_small(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_BASE:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_base(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_LARGE:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_large(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_XL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_xl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FLAN_T5_XXL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_flan_t5_xxl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FASTCHAT_T5_XL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_fastchat_t5_xl(load_in_8bit=load_in_8bit)
+         elif self.config["llm"] == LLM_FALCON_SMALL:
+             if self.llm is None:
+                 self.llm = PdfQA.create_falcon_instruct_small(load_in_8bit=load_in_8bit)
+         else:
+             raise ValueError("Invalid config")
+     def vector_db_pdf(self) -> None:
+         """
+         Creates a vector db for the embeddings and persists it, or loads an existing vector db from the persist directory
+         """
+         pdf_path = self.config.get("pdf_path", None)
+         persist_directory = self.config.get("persist_directory", None)
+         if persist_directory and os.path.exists(persist_directory):
+             ## Load from the persisted db
+             self.vectordb = Chroma(persist_directory=persist_directory, embedding_function=self.embedding)
+         elif pdf_path and os.path.exists(pdf_path):
+             ## 1. Extract the documents
+             loader = PDFPlumberLoader(pdf_path)
+             documents = loader.load()
+             ## 2. Split the texts
+             text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
+             texts = text_splitter.split_documents(documents)
+             # text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10, encoding_name="cl100k_base")  # This is the encoding for text-embedding-ada-002
+             text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10)
+             texts = text_splitter.split_documents(texts)
+
+             ## 3. Create embeddings and add them to the Chroma store
+             ## TODO: Validate that self.embedding is not None
+             self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding, persist_directory=persist_directory)
+         else:
+             raise ValueError("No PDF found")
+
+     def retreival_qa_chain(self):
+         """
+         Creates a retrieval QA chain using the vectordb as retriever and the LLM to complete the prompt
+         """
+         ## TODO: Use custom prompt
+         self.retriever = self.vectordb.as_retriever(search_kwargs={"k": 3})
+
+         hf_llm = HuggingFacePipeline(pipeline=self.llm, model_id=self.config["llm"])
+
+         self.qa = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff", retriever=self.retriever)
+         if self.config["llm"] == LLM_FLAN_T5_SMALL or self.config["llm"] == LLM_FLAN_T5_BASE or self.config["llm"] == LLM_FLAN_T5_LARGE:
+             question_t5_template = """
+             context: {context}
+             question: {question}
+             answer:
+             """
+             QUESTION_T5_PROMPT = PromptTemplate(
+                 template=question_t5_template, input_variables=["context", "question"]
+             )
+             self.qa.combine_documents_chain.llm_chain.prompt = QUESTION_T5_PROMPT
+         self.qa.combine_documents_chain.verbose = True
+         self.qa.return_source_documents = True
+
+     def answer_query(self, question: str) -> str:
+         """
+         Answer the question
+         """
+         answer_dict = self.qa({"query": question})
+         print(answer_dict)
+         answer = answer_dict["result"]
+         if self.config["llm"] == LLM_FASTCHAT_T5_XL:
+             answer = self._clean_fastchat_t5_output(answer)
+         return answer
+
+     def _clean_fastchat_t5_output(self, answer: str) -> str:
+         # Remove <pad> tags, collapse double spaces, strip the trailing newline
+         answer = re.sub(r"<pad>\s+", "", answer)
+         answer = re.sub(r"  ", " ", answer)
+         answer = re.sub(r"\n$", "", answer)
+         return answer
+
+ # Configuration for PdfQA
+ config = {
+     "persist_directory": None,
+     "load_in_8bit": False,
+     "embedding": EMB_SBERT_MPNET_BASE,
+     "llm": LLM_FLAN_T5_BASE,
+     "pdf_path": "48lawsofpower.pdf"
+ }
+ pdfqa = PdfQA(config=config)
+ pdfqa.init_embeddings()
+ pdfqa.init_models()
+
+ # Create Vector DB
+ pdfqa.vector_db_pdf()
+
+ # Set up Retrieval QA Chain
+ pdfqa.retreival_qa_chain()
+
  def ask(text):

+     question = text + ", tell me in details"
+     answer = pdfqa.answer_query(question)
+     return answer
+
  with gr.Blocks() as server:
      with gr.Tab("LLM Inferencing"):
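
The hunk ends inside the "LLM Inferencing" tab, so the Gradio layout that calls ask() is unchanged by this commit and not shown above. For context, here is a minimal sketch of how ask() could be wired into that tab; the component names (model_input, ask_button, model_output) and layout are illustrative assumptions, not the actual remainder of app.py.

# Hypothetical continuation of the unchanged UI code (not part of this commit)
with gr.Blocks() as server:
    with gr.Tab("LLM Inferencing"):
        model_input = gr.Textbox(label="Your question:", interactive=True)
        ask_button = gr.Button("Ask")
        model_output = gr.Textbox(label="The answer:", interactive=False)

        # On click, send the question through the PdfQA retrieval chain via ask()
        ask_button.click(ask, inputs=[model_input], outputs=[model_output])

server.launch()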