ShaoXia committed
Commit 069157b · 1 Parent(s): 7abd762

Initialize code

README.md CHANGED
@@ -1,12 +1,33 @@
  ---
- title: Semi-Annual Security Exam
- emoji: 👁
- colorFrom: blue
- colorTo: red
- sdk: gradio
- sdk_version: 3.35.2
+ title: Faiss Chat
+ emoji: 🐠
+ colorFrom: indigo
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.32.0
  app_file: app.py
  pinned: false
+ license: mit
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # FAISS Chat: Chat with FAISS database
+
+ A web UI version of Langchain-Chat. Two features are currently supported:
+ * Upload a zip archive of local PDF and TXT files to build a FAISS vector database.
+ * Upload an existing local FAISS vector database directly.
+
+
+ ## Changelog
+
+
+ * 2023-06-04:
+     * Support reading chart data from image files (currently JPG and PNG).
+
+ * 2023-06-04:
+     * Support more file formats (currently PDF, TXT, MD, and TEX).
+     * Support more embedding models (currently [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), and [distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco)).
+     * Improved the file structure of the local knowledge base.
+
+ ## Demo
+ [Huggingface Space](https://huggingface.co/spaces/shaocongma/faiss_chat)
+
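The first feature expects a single .zip of documents. A minimal sketch of packaging a local folder with the standard library follows. `pack_documents` and `SUPPORTED_EXTS` are hypothetical names that are not part of this commit, and the extension list is taken from the changelog above.

```python
import os
import zipfile

# Extensions the changelog lists as supported (an assumption; adjust as needed).
SUPPORTED_EXTS = {".pdf", ".txt", ".md", ".tex"}


def pack_documents(src_dir, zip_path):
    """Zip every supported document under src_dir (hypothetical helper)."""
    packed = []
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(src_dir):
            for name in sorted(files):
                if os.path.splitext(name)[1].lower() in SUPPORTED_EXTS:
                    full = os.path.join(root, name)
                    # store paths relative to src_dir so the archive holds no absolute paths
                    arcname = os.path.relpath(full, src_dir)
                    zf.write(full, arcname)
                    packed.append(arcname)
    return packed
```

Files with other extensions are skipped, so the archive contains only what the loaders in this commit are set up to read.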
app.py ADDED
@@ -0,0 +1,264 @@
+ import json
+ import os
+ import time
+ import uuid
+ from datetime import datetime
+
+ import gradio as gr
+ import openai
+ from huggingface_hub import HfApi
+ from langchain.document_loaders import PyPDFLoader, \
+     UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader
+
+ from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip
+ from knowledge.img_handler import process_image, add_markup
+ from llms.chatbot import OpenAIChatBot
+ from llms.embeddings import EMBEDDINGS_MAPPING
+ from utils import make_archive
+
+ UPLOAD_REPO_ID = os.getenv("UPLOAD_REPO_ID")
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+ # use a custom API base only when OPENAI_API_BASE is set
+ openai.api_base = os.getenv("OPENAI_API_BASE", openai.api_base)
+ hf_api = HfApi(token=HF_TOKEN)
+
+ ALL_PDF_LOADERS = [PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader]
+ ALL_EMBEDDINGS = list(EMBEDDINGS_MAPPING.keys())
+ PDF_LOADER_MAPPING = {loader.__name__: loader for loader in ALL_PDF_LOADERS}
+
+
+ #######################################################################################################################
+ # Host multiple vector databases for use
+ #######################################################################################################################
+ # TODO: add this feature in the future
+
+
+
+ INSTRUCTIONS = '''# FAISS Chat: Chat with your local database!
+
+ ***Update 2023-06-06:***
+ 1. Support reading chart data from image files (currently JPG and PNG).
+ 2. The "Summarize chart (Demo)" tab provides a test of this module.
+
+ ***Update 2023-06-04:***
+ 1. Support more embedding models (currently [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), and [distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco)).
+ 2. Support more file formats (PDF, TXT, TEX, and MD).
+ 3. All generated databases are now available in [this dataset](https://huggingface.co/datasets/shaocongma/shared-faiss-vdb)! If you do not want your files uploaded, turn this off in the advanced settings.
+ '''
+
+
+ def load_zip_as_db(file_from_gradio,
+                    pdf_loader,
+                    embedding_model,
+                    chunk_size=300,
+                    chunk_overlap=20,
+                    upload_to_cloud=True):
+     # every return has four values to match the four Gradio outputs:
+     # status, output file, database, and metadata
+     if chunk_size <= chunk_overlap:
+         return "chunk_size must be greater than chunk_overlap. Creation failed.", None, None, None
+     if file_from_gradio is None:
+         return "No file provided. Creation failed.", None, None, None
+     pdf_loader = PDF_LOADER_MAPPING[pdf_loader]
+     zip_file_path = file_from_gradio.name
+     project_name = uuid.uuid4().hex
+     db, project_name, db_meta = create_faiss_index_from_zip(zip_file_path, embeddings=embedding_model,
+                                                             pdf_loader=pdf_loader, chunk_size=chunk_size,
+                                                             chunk_overlap=chunk_overlap, project_name=project_name)
+     index_name = project_name + ".zip"
+     make_archive(project_name, index_name)
+     date = datetime.today().strftime('%Y-%m-%d')
+     if upload_to_cloud:
+         # index_name already ends with ".zip"
+         hf_api.upload_file(path_or_fileobj=index_name,
+                            path_in_repo=f"{date}/faiss_{index_name}",
+                            repo_id=UPLOAD_REPO_ID,
+                            repo_type="dataset")
+     return "Knowledge base created. You can start chatting!", index_name, db, db_meta
+
+
+ def load_local_db(file_from_gradio):
+     if file_from_gradio is None:
+         return "No file provided. Loading failed.", None
+     zip_file_path = file_from_gradio.name
+     db = load_faiss_index_from_zip(zip_file_path)
+
+     return "Knowledge base loaded. You can start chatting!", db
+
+
+ def extract_image(image_path):
+     from PIL import Image
+     print("Image Path:", image_path)
+     im = Image.open(image_path)
+     table = process_image(im)
+     print(f"Success in processing the image. Table: {table}")
+     return table, add_markup(table)
+
+
+ def describe(image):
+     # add_markup is applied once; the marked-up table goes straight into the prompt
+     table = add_markup(process_image(image))
+     _INSTRUCTION = 'Read the table below to answer the following questions.'
+     question = "Please refer to the above table, and write a summary of no less than 200 words based on it in Chinese, ensuring that your response is detailed and precise. "
+     prompt_0shot = _INSTRUCTION + "\n" + table + "\n" + "Q: " + question + "\n" + "A:"
+
+     # the prompt is the user's turn, so it is sent with the "user" role
+     messages = [{"role": "user", "content": prompt_0shot}]
+     response = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo",
+         messages=messages,
+         temperature=0.7,
+         top_p=1,
+         frequency_penalty=0,
+         presence_penalty=0,
+     )
+     ret = response.choices[0].message['content']
+     return ret
+
+
+ with gr.Blocks() as demo:
+     local_db = gr.State(None)
+
+     def get_augmented_message(message, local_db, query_count, preprocessing, meta):
+         print(f"Receiving message: {message}")
+
+         print("Detecting if the user needs to read an image from the local database...")
+         # `meta` holds the db_meta contents; it is None when no metadata has been
+         # loaded, in which case the image lookup is skipped
+         files = meta["files"] if meta else []
+         source_path = meta["source_path"] if meta else ""
+         # with open(meta.name, "r", encoding="utf-8") as f:
+         #     files = json.load(f)["files"]
+         img_files = []
+         for file in files:
+             if os.path.splitext(file)[1] in [".png", ".jpg"]:
+                 img_files.append(file)
+
+         # scan the user's input to see if it contains an image's name
+         do_extract_image = False
+         target_file = None
+         for file in img_files:
+             img = os.path.splitext(file)[0]
+             if img in message:
+                 do_extract_image = True
+                 target_file = file
+                 break
+
+         # extract image to tables; image_content stays None when no image is referenced
+         image_content = None
+         if do_extract_image:
+             print("The user needs to read image from the local database. Extract image ... ")
+             target_file = os.path.join(source_path, target_file)
+             _, image_info = extract_image(target_file)
+             if len(image_info) > 0:
+                 image_content = {"content": image_info, "source": os.path.basename(target_file)}
+
+         print("Querying references from the local database...")
+         contents = []
+         try:
+             if query_count > 0:
+                 docs = local_db.similarity_search(message, k=query_count)
+                 # the search may return fewer than query_count documents
+                 for doc in docs:
+                     # pre-processing each chunk
+                     content = doc.page_content.replace('\n', ' ')
+                     contents.append(content)
+         except Exception as e:
+             print(f"Failed to query from the local database: {e}")
+         # generate augmented_message
+         print("Success in querying references: {}".format(contents))
+         if image_content is not None:
+             augmented_message = f"{image_content}\n\n---\n\n" + "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
+         else:
+             augmented_message = "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
+         return augmented_message + "\n\n" + f"'user_input': {message}"
+
+
+     def respond(message, local_db, chat_history, meta, query_count=5, test_mode=False, response_delay=5, preprocessing=False):
+         gpt_chatbot = OpenAIChatBot()
+         print("Chat History: ", chat_history)
+         print("Local DB is None: ", local_db is None)
+         for chat in chat_history:
+             gpt_chatbot.load_chat(chat)
+         if local_db is None or query_count == 0:
+             bot_message = gpt_chatbot(message)
+             print(bot_message)
+             print(message)
+             chat_history.append((message, bot_message))
+             return "", chat_history
+         else:
+             augmented_message = get_augmented_message(message, local_db, query_count, preprocessing, meta)
+             bot_message = gpt_chatbot(augmented_message, original_message=message)
+             print(message)
+             print(augmented_message)
+             print(bot_message)
+             if test_mode:
+                 chat_history.append((augmented_message, bot_message))
+             else:
+                 chat_history.append((message, bot_message))
+             time.sleep(response_delay)  # pause (5 seconds by default) to stay under the API rate limit
+             return "", chat_history
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown(INSTRUCTIONS)
+
+             with gr.Row():
+                 with gr.Tab("Create knowledge base from local PDF files"):
+                     zip_file = gr.File(file_types=[".zip"], label="Local PDF files (.zip)")
+                     create_db = gr.Button("Create knowledge base", variant="primary")
+                     with gr.Accordion("Advanced settings", open=False):
+                         embedding_selector = gr.Dropdown(ALL_EMBEDDINGS,
+                                                          value="distilbert-dot-tas_b-b256-msmarco",
+                                                          label="Embedding Models")
+                         pdf_loader_selector = gr.Dropdown([loader.__name__ for loader in ALL_PDF_LOADERS],
+                                                           value=PyPDFLoader.__name__, label="PDF Loader")
+                         chunk_size_slider = gr.Slider(minimum=50, maximum=2000, step=50, value=500,
+                                                       label="Chunk size (tokens)")
+                         chunk_overlap_slider = gr.Slider(minimum=0, maximum=500, step=1, value=50,
+                                                          label="Chunk overlap (tokens)")
+                         save_to_cloud_checkbox = gr.Checkbox(value=False, label="Upload the database to the cloud")
+
+
+                     file_dp_output = gr.File(file_types=[".zip"], label="(Output) Knowledge base file (.zip)")
+                 with gr.Tab("Load local knowledge base file"):
+                     file_local = gr.File(file_types=[".zip"], label="Local knowledge base file (.zip)")
+                     load_db = gr.Button("Load existing knowledge base", variant="primary")
+
+                 with gr.Tab("Summarize chart (Demo)"):
+                     gr.Markdown(r"Code source: https://huggingface.co/spaces/fl399/deplot_plus_llm")
+                     input_image = gr.Image(label="Input Image", type="pil", interactive=True)
+                     extract = gr.Button("Summarize", variant="primary")
+
+                     output_text = gr.Textbox(lines=8, label="Output")
+
+
+
+
+         with gr.Column():
+             status = gr.Textbox(label="Program status")
+             chatbot = gr.Chatbot()
+
+             msg = gr.Textbox()
+             submit = gr.Button("Submit", variant="primary")
+             with gr.Accordion("Advanced settings", open=False):
+                 json_output = gr.JSON()
+                 with gr.Row():
+                     query_count_slider = gr.Slider(minimum=0, maximum=10, step=1, value=3,
+                                                    label="Query counts")
+                     test_mode_checkbox = gr.Checkbox(label="Test mode")
+
+
+     # def load_pdf_as_db(file_from_gradio,
+     #                    pdf_loader,
+     #                    embedding_model,
+     #                    chunk_size=300,
+     #                    chunk_overlap=20,
+     #                    upload_to_cloud=True):
+     msg.submit(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
+     submit.click(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
+
+     create_db.click(load_zip_as_db, [zip_file, pdf_loader_selector, embedding_selector, chunk_size_slider, chunk_overlap_slider, save_to_cloud_checkbox],
+                     [status, file_dp_output, local_db, json_output])
+     load_db.click(load_local_db, [file_local], [status, local_db])
+
+     extract.click(describe, [input_image], [output_text])
+
+ demo.launch(show_api=False)
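The retrieval step in `get_augmented_message` amounts to: fetch up to `k` chunks, flatten their newlines, and place them above the user's question. Below is a minimal sketch of that assembly in isolation. `FakeDB`, `Doc`, and `build_augmented_message` are hypothetical stand-ins so the example runs without FAISS. The real vector store ranks by similarity, while this stub simply returns the first `k` texts.

```python
from dataclasses import dataclass


@dataclass
class Doc:
    page_content: str


class FakeDB:
    """Stand-in for a vector store; returns the first k stored texts (no ranking)."""

    def __init__(self, texts):
        self.texts = texts

    def similarity_search(self, query, k=4):
        return [Doc(t) for t in self.texts[:k]]


def build_augmented_message(message, db, k=3, sep="\n\n---\n\n"):
    """Join up to k retrieved chunks above the user's question."""
    docs = db.similarity_search(message, k=k) if k > 0 else []
    # iterate over what was actually returned, which may be fewer than k
    contents = [d.page_content.replace("\n", " ") for d in docs]
    return sep.join(contents) + "\n\n-----\n\n" + "\n\n" + f"'user_input': {message}"
```

Iterating over the returned documents rather than over `range(k)` keeps the function safe when the store holds fewer than `k` chunks.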
knowledge/__init__.py ADDED
File without changes
knowledge/faiss_handler.py ADDED
@@ -0,0 +1,203 @@
+ import json
+ import uuid
+ from langchain.vectorstores import FAISS
+ import os
+ from tqdm.auto import tqdm
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.document_loaders import DirectoryLoader, TextLoader
+ from llms.embeddings import EMBEDDINGS_MAPPING
+ import tiktoken
+ import zipfile
+ import pickle
+
+ tokenizer_name = tiktoken.encoding_for_model('gpt-4')
+ tokenizer = tiktoken.get_encoding(tokenizer_name.name)
+ EMBED_MODEL = "text-embedding-ada-002"
+ EMBED_DIM = 1536
+ METRIC = 'cosine'
+
+ #######################################################################################################################
+ # Files handler
+ #######################################################################################################################
+ def check_existence(path):
+     return os.path.isfile(path) or os.path.isdir(path)
+
+
+ def list_files(directory, ext=".pdf"):
+     # List all files in the directory
+     files_in_directory = os.listdir(directory)
+     # Keep only the files whose name ends with `ext`
+     files_list = [file for file in files_in_directory if file.endswith(ext)]
+     return files_list
+
+
+ def list_pdf_files(directory):
+     # List all files in the directory
+     files_in_directory = os.listdir(directory)
+     # Filter the list to only include PDF files
+     pdf_files = [file for file in files_in_directory if file.endswith(".pdf")]
+     return pdf_files
+
+
+
+ def tiktoken_len(text):
+     # count how many tokens the given text encodes to
+     tokens = tokenizer.encode(text, disallowed_special=())
+     return len(tokens)
+
+
+ def get_chunks(docs, chunk_size=500, chunk_overlap=20, length_function=tiktoken_len):
+     # docs should be the output of `loader.load()`
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
+                                                    chunk_overlap=chunk_overlap,
+                                                    length_function=length_function,
+                                                    separators=["\n\n", "\n", " ", ""])
+     chunks = []
+     for idx, page in enumerate(tqdm(docs)):
+         source = page.metadata.get('source')
+         content = page.page_content
+         # pages with 50 characters or fewer are skipped
+         if len(content) > 50:
+             texts = text_splitter.split_text(content)
+             chunks.extend([str({'content': texts[i], 'chunk': i, 'source': os.path.basename(source)}) for i in
+                            range(len(texts))])
+     return chunks
+
+
+ #######################################################################################################################
+ # Create FAISS object
+ #######################################################################################################################
+
+ # ["text-embedding-ada-002", "distilbert-dot-tas_b-b256-msmarco"]
+
+ def create_faiss_index_from_zip(path_to_zip_file, embeddings=None, pdf_loader=None,
+                                 chunk_size=500, chunk_overlap=20,
+                                 project_name="Very_Cool_Project_Name"):
+     # initialize the file structure
+     # structure: project_name
+     #               - source data
+     #               - embeddings
+     #               - faiss_index
+     if isinstance(embeddings, str):
+         embeddings_str = embeddings
+     else:
+         embeddings_str = "other-embedding-model"
+
+     if isinstance(embeddings, str):
+         embeddings = EMBEDDINGS_MAPPING[embeddings]
+     else:
+         # NOTE: None or any non-string value falls back to text-embedding-ada-002,
+         # so an embeddings object passed by the caller is not used
+         embeddings = EMBEDDINGS_MAPPING["text-embedding-ada-002"]
+     # STEP 1:
+     # Create a folder f"{project_name}" in the current directory.
+     current_directory = os.getcwd()
+     if not os.path.exists(project_name):
+         os.makedirs(project_name)
+         project_path = os.path.join(current_directory, project_name)
+         source_data = os.path.join(project_path, "source_data")
+         embeddings_data = os.path.join(project_path, "embeddings")
+         index_data = os.path.join(project_path, "faiss_index")
+         os.makedirs(source_data)  # ./project/source_data
+         os.makedirs(embeddings_data)  # ./project/embeddings
+         os.makedirs(index_data)  # ./project/faiss_index
+     else:
+         raise ValueError(f"The project {project_name} exists.")
+     with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
+         # extract everything to "source_data"
+         zip_ref.extractall(source_data)
+
+
+     db_meta = {"project_name": project_name,
+                "pdf_loader": pdf_loader.__name__, "chunk_size": chunk_size,
+                "chunk_overlap": chunk_overlap,
+                "embedding_model": embeddings_str,
+                "files": os.listdir(source_data),
+                "source_path": source_data}
+     with open(os.path.join(project_path, "db_meta.json"), "w", encoding="utf-8") as f:
+         # save db_meta.json to folder
+         json.dump(db_meta, f)
+
+
+     all_docs = []
+     for ext in [".txt", ".tex", ".md", ".pdf"]:
+         if ext in [".txt", ".tex", ".md"]:
+             loader = DirectoryLoader(source_data, glob=f"**/*{ext}", loader_cls=TextLoader,
+                                      loader_kwargs={'autodetect_encoding': True})
+         elif ext in [".pdf"]:
+             loader = DirectoryLoader(source_data, glob=f"**/*{ext}", loader_cls=pdf_loader)
+         else:
+             continue
+         docs = loader.load()
+         all_docs = all_docs + docs
+
+     # split all loaded documents into chunks and compute their embeddings; save the pairs under embeddings/
+     chunks = get_chunks(all_docs, chunk_size, chunk_overlap)
+     text_embeddings = embeddings.embed_documents(chunks)
+     text_embedding_pairs = list(zip(chunks, text_embeddings))
+     embeddings_save_to = os.path.join(embeddings_data, 'text_embedding_pairs.pickle')
+     with open(embeddings_save_to, 'wb') as handle:
+         pickle.dump(text_embedding_pairs, handle, protocol=pickle.HIGHEST_PROTOCOL)
+     db = FAISS.from_embeddings(text_embedding_pairs, embeddings)
+
+     db.save_local(index_data)
+     print(db_meta)
+     print("Success!")
+     return db, project_name, db_meta
+
+
+ def find_file(file_name, directory):
+     for root, dirs, files in os.walk(directory):
+         if file_name in files:
+             return os.path.join(root, file_name)
+     return None  # If the file was not found
+
+ def find_file_dir(file_name, directory):
+     for root, dirs, files in os.walk(directory):
+         if file_name in files:
+             return root  # return the directory instead of the full path
+     return None  # If the file was not found
+
+
+ def load_faiss_index_from_zip(path_to_zip_file):
+     # Extract the zip file. Read the db_meta
+     # base_name = os.path.basename(path_to_zip_file)
+     path_to_extract = os.path.join(os.getcwd())
+     with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
+         zip_ref.extractall(path_to_extract)
+
+     db_meta_json = find_file("db_meta.json", path_to_extract)
+     if db_meta_json is not None:
+         with open(db_meta_json, "r", encoding="utf-8") as f:
+             db_meta_dict = json.load(f)
+     else:
+         raise ValueError("Cannot find `db_meta.json` in the .zip file. ")
+
+     try:
+         embeddings = EMBEDDINGS_MAPPING[db_meta_dict["embedding_model"]]
+     except KeyError:
+         # unknown or missing model name: fall back to the OpenAI embedding model
+         from langchain.embeddings.openai import OpenAIEmbeddings
+         embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
+
+     # locate index.faiss
+     index_path = find_file_dir("index.faiss", path_to_extract)
+     if index_path is not None:
+         db = FAISS.load_local(index_path, embeddings)
+         return db
+     else:
+         raise ValueError("Failed to find `index.faiss` in the .zip file.")
+
+
+ if __name__ == "__main__":
+     from langchain.document_loaders import PyPDFLoader
+     from langchain.embeddings.openai import OpenAIEmbeddings
+     from langchain.embeddings import HuggingFaceEmbeddings
+
+     model_name = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco"
+     model_kwargs = {'device': 'cpu'}
+     encode_kwargs = {'normalize_embeddings': False}
+     embeddings = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs)
+     create_faiss_index_from_zip(path_to_zip_file="document.zip", pdf_loader=PyPDFLoader, embeddings=embeddings)
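The roles of `chunk_size` and `chunk_overlap` are easiest to see on a plain fixed-window split. The sketch below is a simplification and not the algorithm `RecursiveCharacterTextSplitter` uses, since that class splits on separators first. `split_with_overlap` is a hypothetical helper that works on any token list.

```python
def split_with_overlap(tokens, chunk_size=500, chunk_overlap=20):
    """Fixed-window split: each chunk repeats the last chunk_overlap tokens of the previous one."""
    if chunk_size <= chunk_overlap:
        raise ValueError("chunk_size must be greater than chunk_overlap")
    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the window already reached the end of the input
    return chunks
```

The guard mirrors the `chunk_size <= chunk_overlap` check in `load_zip_as_db`: without it the window would not advance.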
knowledge/img_handler.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+ from PIL import Image
+
+ if torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+
+ model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16)
+ if device == "cuda":
+     model_deplot = model_deplot.to(0)
+ processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
+
+
+
+ def add_markup(table):
+     try:
+         parts = [p.strip() for p in table.splitlines(keepends=False)]
+         if parts[0].startswith('TITLE'):
+             result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
+             rows = parts[1:]
+         else:
+             result = ''
+             rows = parts
+         prefixes = ['Header: '] + [f'Row {i+1}: ' for i in range(len(rows) - 1)]
+         return result + '\n'.join(prefix + row for prefix, row in zip(prefixes, rows))
+     except Exception:
+         # just use the raw table if parsing fails
+         return table
+
+ def process_image(image):
+     inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:",
+                               return_tensors="pt").to(torch.bfloat16)
+     if device == "cuda":
+         inputs = inputs.to(0)
+     predictions = model_deplot.generate(**inputs, max_new_tokens=512)
+     table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
+     return table
+
+
+ if __name__ == "__main__":
+     im = Image.open(r"meat-image.png")
+     process_image(im)
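To make the effect of `add_markup` concrete, here is the same labelling logic run on a small table. The function body is copied from the diff without the `try`/`except` fallback, and the sample table is made up for illustration, not real DePlot output.

```python
def add_markup(table):
    """Prefix each line of a ' | '-separated table with Title/Header/Row labels."""
    parts = [p.strip() for p in table.splitlines()]
    if parts[0].startswith("TITLE"):
        result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
        rows = parts[1:]
    else:
        result = ""
        rows = parts
    prefixes = ["Header: "] + [f"Row {i + 1}: " for i in range(len(rows) - 1)]
    return result + "\n".join(prefix + row for prefix, row in zip(prefixes, rows))


# Made-up sample table, not real DePlot output
sample = "TITLE | Sales\nYear | Units\n2021 | 10\n2022 | 15"
print(add_markup(sample))
# Title: Sales
# Header: Year | Units
# Row 1: 2021 | 10
# Row 2: 2022 | 15
```

Applying the function a second time to its own output would label the `Title:` line as a header row, so the markup should be added exactly once before the table goes into a prompt.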
llms/__init__.py ADDED
File without changes
llms/chatbot.py ADDED
@@ -0,0 +1,37 @@
+ import openai
+ import copy
+
+ class OpenAIChatBot:
+     def __init__(self, model="gpt-3.5-turbo"):
+         self.system = "You are a Q&A bot, a highly intelligent system that answers user questions based on the information provided by the user's local database. " \
+                       "The user's question will include some reference information above the question. " \
+                       "You need to answer the user's question based on the provided references and inform the user what the source of each reference is. " \
+                       "If you cannot find the answer in the provided references, you still need to answer the user's question, but you also need to notify the user that your response is not based on the provided references."
+         self.model = model
+         self.message = [{"role": "system", "content": self.system}]
+         self.raw_message = [{"role": "system", "content": self.system}]
+
+     def load_message(self, message, role, original_message=None):
+         if original_message is None:
+             original_message = message
+         msg = {"role": role, "content": message}
+         self.message.append(msg)
+         msg = {"role": role, "content": original_message}
+         self.raw_message.append(msg)
+
+     def load_chat(self, chat):
+         # chat is a (user_message, assistant_message) pair; both turns go into both histories
+         self.load_message(chat[0], "user")
+         self.load_message(chat[1], "assistant")
+
+
+     def __call__(self, message, original_message=None):
+         self.load_message(message, "user", original_message)
+         augmented_message = copy.deepcopy(self.message)
+         completion = openai.ChatCompletion.create(
+             model=self.model,
+             messages=augmented_message
+         )
+         assistant_message = completion.choices[0].message
+         return assistant_message["content"]
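Replaying a Gradio chat history into the chat-completion message format comes down to expanding each `(user, assistant)` pair into two role-tagged entries. A minimal sketch, assuming a hypothetical standalone `history_to_messages` helper rather than the class above:

```python
def history_to_messages(system_prompt, chat_history):
    """Expand (user, assistant) pairs into a role-tagged message list."""
    messages = [{"role": "system", "content": system_prompt}]
    for user_text, bot_text in chat_history:
        # both turns are kept so the model sees its own earlier answers
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": bot_text})
    return messages
```

If only the user turns were replayed, the model would see a run of questions with no answers and lose the thread of the conversation.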
llms/embeddings.py ADDED
@@ -0,0 +1,30 @@
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ model_name = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco"
+ model_kwargs = {'device': 'cpu'}
+ encode_kwargs = {'normalize_embeddings': False}
+ hf_embeddings_1 = HuggingFaceEmbeddings(
+     model_name=model_name,
+     model_kwargs=model_kwargs,
+     encode_kwargs=encode_kwargs)
+
+ openai_embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
+
+
+ model_name = "GanymedeNil/text2vec-large-chinese"
+ hf_embeddings_2 = HuggingFaceEmbeddings(
+     model_name=model_name,
+     model_kwargs=model_kwargs,
+     encode_kwargs=encode_kwargs)
+
+
+ EMBEDDINGS_MAPPING = {"text-embedding-ada-002": openai_embedding,
+                       "distilbert-dot-tas_b-b256-msmarco": hf_embeddings_1,
+                       "text2vec-large-chinese": hf_embeddings_2}
+
+ def main():
+     pass
+
+ if __name__ == "__main__":
+     main()
llms/tools.py ADDED
@@ -0,0 +1,45 @@
+ import openai
+
+
+ class BaseTool:
+     def __init__(self, model="gpt-3.5-turbo"):
+         self.system = ""
+         self.model = model
+         self.message = [{"role": "system", "content": self.system}]
+
+     def __call__(self, message):
+         user_message = {"role": "user", "content": message}
+         messages = self.message + [user_message]
+         completion = openai.ChatCompletion.create(
+             model=self.model,
+             messages=messages
+         )
+         assistant_message = completion.choices[0].message
+         return assistant_message["content"].replace("\n", " ")
+
+
+
+
+ class PreprocessingBot(BaseTool):
+     def __init__(self, model="gpt-3.5-turbo"):
+         super().__init__(model)
+         self.system = r"""You are an AI assistant for raw data pre-processing. The user will input multiple raw references which may include unicode characters or ASCII code such as '\u001e'. Your task is to make them more readable by doing the following:
+ - Change all unicode characters or ASCII code such as '\u001e' to LaTeX format and put them in formula environment $...$ or $$...$$.
+ - Re-write formulas or mathematical notations to LaTeX format in formula environment $...$ or $$...$$.
+ - Remove meaningless contents.
+ - Respond in the following format: {pdf-name-1: main contents from pdf-name-1, pdf-name-2: main contents from pdf-name-2, ...}.
+ """
+         self.message = [{"role": "system", "content": self.system}]
+
+ class ToolBot(BaseTool):
+     def __init__(self, model="gpt-3.5-turbo"):
+         super().__init__(model)
+         self.system = r"""You need to pretend to be a Python function. You receive a string that is the user's question to a QA bot. You need to analyze the user's goal and decide if the QA bot needs to use the search engine to generate the response to the user.
+ Respond with 1 if you think the QA bot needs to use the search engine for the user's input, and respond with 0 if the QA bot doesn't need that.
+ """
+         self.message = [{"role": "system", "content": self.system}]
+
+ if __name__ == "__main__":
+     bot = ToolBot()
+     rsp = bot("Hello!")
+     print(rsp)
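`ToolBot` is prompted to answer with `1` or `0`, but a chat model may wrap that digit in extra words. The sketch below shows one way to turn the reply into a boolean defensively. `parse_binary_reply` is a hypothetical helper that is not part of this commit, and treating an unparseable reply as `default` is an assumption.

```python
import re


def parse_binary_reply(reply, default=False):
    """Return True for a standalone '1', False for a standalone '0', else default."""
    match = re.search(r"\b([01])\b", reply)
    if match is None:
        return default  # the model did not answer with a lone 0 or 1
    return match.group(1) == "1"
```

The word boundaries keep multi-digit numbers such as `10` from being read as a yes or a no.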
requirements.txt ADDED
Binary file (4.19 kB).
utils.py ADDED
@@ -0,0 +1,12 @@
+ import os
+ import shutil
+
+ def make_archive(source, destination):
+     # destination is "<name>.<format>", e.g. "project.zip"
+     base = os.path.basename(destination)
+     name, ext = os.path.splitext(base)
+     fmt = ext.lstrip('.')
+     # fall back to the current directory when source has no parent directory
+     archive_from = os.path.dirname(source) or os.curdir
+     archive_to = os.path.basename(source.strip(os.sep))
+     shutil.make_archive(name, fmt, archive_from, archive_to)
+     shutil.move('%s.%s' % (name, fmt), destination)
+     return destination
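The `root_dir` and `base_dir` arguments of `shutil.make_archive` decide which paths end up inside the archive. Here is a minimal sketch of zipping a project folder so that its top-level name is preserved. `zip_folder` is a hypothetical variant of the helper above that writes the archive directly to the destination instead of moving it afterwards.

```python
import os
import shutil


def zip_folder(folder, destination):
    """Zip `folder` (keeping its top-level name) into `destination`, a .zip path."""
    folder = os.path.abspath(folder)
    root_dir = os.path.dirname(folder)    # directory that contains the folder
    base_dir = os.path.basename(folder)   # folder name stored inside the archive
    # make_archive appends the extension itself, so it is stripped here
    base_name = os.path.splitext(os.path.abspath(destination))[0]
    return shutil.make_archive(base_name, "zip", root_dir, base_dir)
```

With this layout, entries in the archive start with the folder name (for example `project/db_meta.json`), which is the structure `find_file` later walks when an index is loaded from a zip.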