"""Gradio app for chatting with a FAISS vector database built from local files.

The app offers three entry points in the left column:

* build a knowledge base from a ``.zip`` of documents,
* load a previously built knowledge base ``.zip``,
* summarise a chart image (demo).

The right column hosts the chat window. When a knowledge base is loaded, each
user message is augmented with retrieved chunks (and, if the message mentions
an image file listed in the knowledge-base metadata, with the table extracted
from that image) before it is sent to the chat model.
"""
import json
import os
import time
import uuid
from datetime import datetime

import gradio as gr
import openai
from huggingface_hub import HfApi
from langchain.document_loaders import (
    PDFPlumberLoader,
    PyMuPDFLoader,
    PyPDFium2Loader,
    PyPDFLoader,
    UnstructuredPDFLoader,
)

from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip
from knowledge.img_handler import process_image, add_markup
from llms.chatbot import OpenAIChatBot
from llms.embeddings import EMBEDDINGS_MAPPING
from utils import make_archive

UPLOAD_REPO_ID = os.getenv("UPLOAD_REPO_ID")
HF_TOKEN = os.getenv("HF_TOKEN")
openai.api_key = os.getenv("OPENAI_API_KEY")
# The original used `==` here, which only compared the values and never applied
# the override. Assign only when the variable is set so that the library's
# default endpoint is kept otherwise.
_OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
if _OPENAI_API_BASE:
    openai.api_base = _OPENAI_API_BASE

hf_api = HfApi(token=HF_TOKEN)

ALL_PDF_LOADERS = [PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader]
# Materialise the keys: UI widgets expect a concrete list of choices.
ALL_EMBEDDINGS = list(EMBEDDINGS_MAPPING.keys())
PDF_LOADER_MAPPING = {loader.__name__: loader for loader in ALL_PDF_LOADERS}

# File extensions (lower-case) treated as chart images inside a knowledge base.
IMAGE_EXTENSIONS = (".png", ".jpg")

#######################################################################################################################
# Host multiple vector database for use
#######################################################################################################################
# TODO: add this feature in the future


INSTRUCTIONS = '''# FAISS Chat: 和本地数据库聊天!

***2023-06-06更新:***

1. 支持读取图片格式的图表数据(目前支持JPG, PNG).
2. 在"总结图表(Demo)"的标签页里提供了这个模块的测试.

***2023-06-04更新:***

1. 支持更多的Embedding Model (目前支持[text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co./GanymedeNil/text2vec-large-chinese), 和[distilbert-dot-tas_b-b256-msmarco](https://huggingface.co./sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco) )
2. 支持更多的文件格式(PDF, TXT, TEX, 和MD).
3. 所有生成的数据库都可以在[这个数据集](https://huggingface.co./datasets/shaocongma/shared-faiss-vdb)里访问了!如果不希望文件被上传,可以在高级设置里关闭.
'''


def load_zip_as_db(file_from_gradio,
                   pdf_loader,
                   embedding_model,
                   chunk_size=300,
                   chunk_overlap=20,
                   upload_to_cloud=True):
    """Build a FAISS knowledge base from an uploaded zip archive.

    Args:
        file_from_gradio: Uploaded file object from the file widget (its
            ``.name`` attribute is used as the path), or ``None``.
        pdf_loader: Name of the PDF loader class; a key of ``PDF_LOADER_MAPPING``.
        embedding_model: Embedding identifier forwarded to
            ``create_faiss_index_from_zip``.
        chunk_size: Chunk size forwarded to the index builder.
        chunk_overlap: Chunk overlap forwarded to the index builder; must be
            strictly smaller than ``chunk_size``.
        upload_to_cloud: When true, upload the resulting archive to the
            dataset repository ``UPLOAD_REPO_ID``.

    Returns:
        A 4-tuple ``(status_message, index_zip_path, db, db_meta)``. On
        validation failure the last three items are ``None``. The tuple always
        has four items because the click handler is wired to four outputs.
    """
    if chunk_size <= chunk_overlap:
        return "chunk_size小于chunk_overlap. 创建失败.", None, None, None
    if file_from_gradio is None:
        return "文件为空. 创建失败.", None, None, None

    pdf_loader = PDF_LOADER_MAPPING[pdf_loader]
    zip_file_path = file_from_gradio.name
    project_name = uuid.uuid4().hex
    db, project_name, db_meta = create_faiss_index_from_zip(zip_file_path,
                                                            embeddings=embedding_model,
                                                            pdf_loader=pdf_loader,
                                                            chunk_size=chunk_size,
                                                            chunk_overlap=chunk_overlap,
                                                            project_name=project_name)
    index_name = project_name + ".zip"
    make_archive(project_name, index_name)
    date = datetime.today().strftime('%Y-%m-%d')
    if upload_to_cloud:
        # NOTE(review): index_name already ends with ".zip", so the remote name
        # ends with ".zip.zip". Left unchanged so that new uploads keep the
        # same naming scheme as existing ones - confirm whether this is wanted.
        hf_api.upload_file(path_or_fileobj=index_name,
                           path_in_repo=f"{date}/faiss_{index_name}.zip",
                           repo_id=UPLOAD_REPO_ID,
                           repo_type="dataset")
    return "成功创建知识库. 可以开始聊天了!", index_name, db, db_meta


def load_local_db(file_from_gradio):
    """Load a previously built knowledge base from an uploaded zip archive.

    Args:
        file_from_gradio: Uploaded file object (its ``.name`` attribute is
            used as the path), or ``None``.

    Returns:
        A 2-tuple ``(status_message, db)``; ``db`` is ``None`` when no file
        was provided.
    """
    if file_from_gradio is None:
        return "文件为空. 创建失败.", None
    zip_file_path = file_from_gradio.name
    db = load_faiss_index_from_zip(zip_file_path)
    return "成功读取知识库. 可以开始聊天了!", db


def extract_image(image_path):
    """Extract the table from the chart image stored at ``image_path``.

    Returns:
        A 2-tuple ``(table, add_markup(table))`` where ``table`` is the result
        of ``process_image`` on the opened image.
    """
    from PIL import Image

    print("Image Path:", image_path)
    # Use a context manager so the underlying file handle is always released.
    with Image.open(image_path) as im:
        table = process_image(im)
    print(f"Success in processing the image. Table: {table}")
    return table, add_markup(table)


def describe(image):
    """Ask the chat model for a Chinese summary of the chart in ``image``.

    Args:
        image: Image delivered by the image widget (configured with
            ``type="pil"``), or ``None`` when nothing was provided.

    Returns:
        The text content of the first completion choice, or a short failure
        message when ``image`` is ``None``.
    """
    if image is None:
        return "图片为空. 总结失败."

    # Apply the markup exactly once. The original marked up the table here and
    # then a second time while assembling the prompt.
    marked_table = add_markup(process_image(image))
    _INSTRUCTION = 'Read the table below to answer the following questions.'
    question = "Please refer to the above table, and write a summary of no less than 200 words based on it in Chinese, ensuring that your response is detailed and precise. "
    prompt_0shot = _INSTRUCTION + "\n" + marked_table + "\n" + "Q: " + question + "\n" + "A:"

    # NOTE(review): the prompt is sent with role "assistant"; "user" may be the
    # intended role - kept as in the original, confirm before changing.
    messages = [{"role": "assistant", "content": prompt_0shot}]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.7,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    ret = response.choices[0].message['content']
    return ret


def _find_mentioned_image(message, files):
    """Return the first image file in ``files`` whose base name (without
    extension) occurs in ``message``, or ``None`` if there is none."""
    for file in files:
        stem, ext = os.path.splitext(file)
        if ext.lower() in IMAGE_EXTENSIONS and stem in message:
            return file
    return None


def get_augmented_message(message, local_db, query_count, preprocessing, meta):
    """Prepend retrieved references (and optional image data) to ``message``.

    Args:
        message: The raw user input.
        local_db: Vector store; its ``similarity_search(message, k=...)``
            method is called when ``query_count > 0``.
        query_count: Number of chunks to request from ``local_db``.
        preprocessing: Currently unused; kept for interface compatibility.
        meta: Knowledge-base metadata mapping with the keys ``"files"`` and
            ``"source_path"``. May be ``None`` (for example when the database
            was loaded from an existing archive and no metadata was set), in
            which case image lookup is skipped.

    Returns:
        The augmented prompt: image content (if any) and the retrieved chunks
        separated by ``---`` rules, followed by the tagged user input.
    """
    print(f"Receiving message: {message}")

    print("Detecting if the user need to read image from the local database...")
    # Tolerate missing metadata instead of failing on `None["files"]`.
    meta = meta or {}
    files = meta.get("files", [])
    source_path = meta.get("source_path", "")

    # Scan the user's input to see if it contains the name of a known image.
    target_file = _find_mentioned_image(message, files)

    # Convert the mentioned image (if any) into table text.
    image_content = None
    if target_file is not None:
        print("The user needs to read image from the local database. Extract image ... ")
        target_path = os.path.join(source_path, target_file)
        _, image_info = extract_image(target_path)
        if image_info:
            image_content = {"content": image_info, "source": os.path.basename(target_path)}

    print("Querying references from the local database...")
    contents = []
    if query_count > 0:
        try:
            docs = local_db.similarity_search(message, k=int(query_count))
        except Exception as err:  # best effort: fall back to no references
            print(f"Failed to query from the local database: {err}")
        else:
            # Iterate over what was actually returned; the store may hold
            # fewer than `query_count` chunks.
            contents = [doc.page_content.replace('\n', ' ') for doc in docs]
            print("Success in querying references: {}".format(contents))

    # Generate the augmented message.
    if image_content is not None:
        augmented_message = f"{image_content}\n\n---\n\n" + "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
    else:
        augmented_message = "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
    return augmented_message + "\n\n" + f"'user_input': {message}"


def respond(message, local_db, chat_history, meta, query_count=5, test_mode=False, response_delay=5,
            preprocessing=False):
    """Produce the bot reply for ``message`` and append it to the chat history.

    Args:
        message: The user's input text.
        local_db: Loaded vector store, or ``None`` to chat without retrieval.
        chat_history: Sequence of previous ``(user, bot)`` turns; each turn is
            replayed into a fresh ``OpenAIChatBot`` via ``load_chat``.
        meta: Knowledge-base metadata forwarded to ``get_augmented_message``.
        query_count: Number of chunks to retrieve; ``0`` disables retrieval.
        test_mode: When true, the augmented prompt (instead of the raw
            message) is recorded in the chat history.
        response_delay: Seconds to sleep after a retrieval-augmented reply.
        preprocessing: Forwarded to ``get_augmented_message``.

    Returns:
        A 2-tuple ``("", chat_history)``: the empty string clears the input
        box, the list is the updated history.
    """
    gpt_chatbot = OpenAIChatBot()
    # The chat widget may hand over `None` before the first turn.
    chat_history = chat_history or []
    print("Chat History: ", chat_history)
    print("Local DB: ", local_db is None)
    for chat in chat_history:
        gpt_chatbot.load_chat(chat)

    if local_db is None or query_count == 0:
        bot_message = gpt_chatbot(message)
        print(bot_message)
        print(message)
        chat_history.append((message, bot_message))
        return "", chat_history

    augmented_message = get_augmented_message(message, local_db, query_count, preprocessing, meta)
    bot_message = gpt_chatbot(augmented_message, original_message=message)
    print(message)
    print(augmented_message)
    print(bot_message)
    shown_message = augmented_message if test_mode else message
    chat_history.append((shown_message, bot_message))
    time.sleep(response_delay)  # sleep (5 seconds by default) to avoid the rate limit
    return "", chat_history


with gr.Blocks() as demo:
    local_db = gr.State(None)

    with gr.Row():
        # Left column: instructions and knowledge-base management.
        with gr.Column():
            gr.Markdown(INSTRUCTIONS)
            with gr.Row():
                with gr.Tab("从本地PDF文件创建知识库"):
                    zip_file = gr.File(file_types=[".zip"], label="本地PDF文件(.zip)")
                    create_db = gr.Button("创建知识库", variant="primary")
                    with gr.Accordion("高级设置", open=False):
                        embedding_selector = gr.Dropdown(ALL_EMBEDDINGS,
                                                         value="distilbert-dot-tas_b-b256-msmarco",
                                                         label="Embedding Models")
                        pdf_loader_selector = gr.Dropdown([loader.__name__ for loader in ALL_PDF_LOADERS],
                                                          value=PyPDFLoader.__name__, label="PDF Loader")
                        chunk_size_slider = gr.Slider(minimum=50, maximum=2000, step=50, value=500,
                                                      label="Chunk size (tokens)")
                        chunk_overlap_slider = gr.Slider(minimum=0, maximum=500, step=1, value=50,
                                                         label="Chunk overlap (tokens)")
                        save_to_cloud_checkbox = gr.Checkbox(value=False, label="把数据库上传到云端")
                    file_dp_output = gr.File(file_types=[".zip"], label="(输出)知识库文件(.zip)")
                with gr.Tab("读取本地知识库文件"):
                    file_local = gr.File(file_types=[".zip"], label="本地知识库文件(.zip)")
                    load_db = gr.Button("读取已创建知识库", variant="primary")
                with gr.Tab("总结图表(Demo)"):
                    gr.Markdown(r"代码来源于: https://huggingface.co./spaces/fl399/deplot_plus_llm")
                    input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                    extract = gr.Button("总结", variant="primary")
                    output_text = gr.Textbox(lines=8, label="Output")

        # Right column: status line, chat window and chat settings.
        with gr.Column():
            status = gr.Textbox(label="用来显示程序运行状态的Textbox")
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            submit = gr.Button("Submit", variant="primary")
            with gr.Accordion("高级设置", open=False):
                json_output = gr.JSON()
                with gr.Row():
                    query_count_slider = gr.Slider(minimum=0, maximum=10, step=1, value=3, label="Query counts")
                    test_mode_checkbox = gr.Checkbox(label="Test mode")

    # Event wiring. The number of outputs must match the number of values
    # returned by each handler on every code path.
    msg.submit(respond,
               [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox],
               [msg, chatbot])
    submit.click(respond,
                 [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox],
                 [msg, chatbot])
    create_db.click(load_zip_as_db,
                    [zip_file, pdf_loader_selector, embedding_selector, chunk_size_slider,
                     chunk_overlap_slider, save_to_cloud_checkbox],
                    [status, file_dp_output, local_db, json_output])
    load_db.click(load_local_db, [file_local], [status, local_db])
    extract.click(describe, [input_image], [output_text])

demo.launch(show_api=False)