import time

import gradio as gr
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language

import vector_db as vdb
from llm_model import LLMModel

chunk_size = 2000
chunk_overlap = 200
uploaded_docs = []
uploaded_df = gr.Dataframe(headers=["file_name", "content_length"])
upload_files_section = gr.Files(
    file_types=[".md", ".mdx", ".rst", ".txt"],
)
chatbot_stream = gr.Chatbot(bubble_full_width=False, show_copy_button=True)


def load_docs(files):
    all_docs = []
    all_qa = []
    for file in files:
        if file.name is not None:
            with open(file.name, "r") as f:
                file_content = f.read()
            file_name = file.name.split("/")[-1]
            # Wrap the raw file content in a Document, tagging the source file.
            doc = Document(page_content=file_content, metadata={"source": file_name})
            # Split the text into chunks of `chunk_size` characters with a
            # `chunk_overlap`-character overlap, using separators suited to the
            # file's markup language.
            language = get_language(file_name)
            text_splitter = RecursiveCharacterTextSplitter.from_language(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap, language=language
            )
            doc_chunks = text_splitter.split_documents([doc])
            print(f"Number of chunks: {len(doc_chunks)}")
            # For each chunk, ask the LLM for potential question/answer pairs
            # and index those (rather than the raw chunk) in the vector store.
            for doc_chunk in doc_chunks:
                gr.Info("Analysing document...")
                potential_qa_from_doc = llm_model.get_potential_question_answer(doc_chunk.page_content)
                all_qa += [Document(page_content=potential_qa_from_doc, metadata=doc_chunk.metadata)]
            all_docs += doc_chunks
            uploaded_docs.append(file.name)
    vector_db.load_docs_into_vector_db(all_qa)
    gr.Info("Loaded document(s) into vector db.")
    return uploaded_docs


def get_language(file_name: str) -> Language:
    # .rst files get RST-aware separators; everything else (.md, .mdx, .txt)
    # falls back to Markdown.
    if file_name.endswith(".rst"):
        return Language.RST
    return Language.MARKDOWN


def get_vector_db():
    return vdb.VectorDB()


def get_llm_model(_db: vdb.VectorDB):
    # Retrieve the two most similar documents for each query.
    retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    return LLMModel(retriever=retriever)


def predict(message, history):
    # Stream the answer back character by character for a typing effect.
    # (A non-streaming alternative is to run the retrieval QA chain directly
    # and format the result with its source documents and the time taken.)
    resp = llm_model.answer_question_inference_text_gen(message)
    for i in range(len(resp)):
        time.sleep(0.005)
        yield resp[: i + 1]


def vote(data: gr.LikeData):
    if data.liked:
        gr.Info("You upvoted this response 😊")
    else:
        gr.Warning("You downvoted this response 👀")
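
# The two local modules are not shown in this file. From the calls above, their
# expected surface is roughly the following (a sketch inferred from usage, not
# the actual implementation):
#
#   vdb.VectorDB
#       .docs_db                         -- a LangChain vector store exposing .as_retriever()
#       .load_docs_into_vector_db(docs)  -- indexes a list of Documents
#
#   LLMModel(retriever=...)
#       .get_potential_question_answer(text)       -- returns Q&A pairs as text
#       .answer_question_inference_text_gen(query) -- returns the answer text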

vector_db = get_vector_db()
llm_model = get_llm_model(vector_db)

chat_interface_stream = gr.ChatInterface(
    predict,
    title="👀 Document answering bot",
    description="📚🔦 Upload some documents on the side and ask questions!",
    textbox=gr.Textbox(container=False, scale=7),
    chatbot=chatbot_stream,
    examples=["What is Data Caterer?"],
).queue(default_concurrency_limit=1)

with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column(scale=1, min_width=100) as upload_col:
            gr.Interface(
                load_docs,
                title="📖 Upload documents",
                inputs=upload_files_section,
                outputs=gr.Files(),
                allow_flagging="never",
            )
        with gr.Column(scale=4, min_width=600) as chat_col:
            chatbot_stream.like(vote, None, None)
            chat_interface_stream.render()

blocks.queue().launch()
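
# To try this locally (assuming vector_db.py and llm_model.py sit next to this
# script and their dependencies are installed):
#   python app.py   (the file name here is assumed)
# Gradio serves the UI on http://127.0.0.1:7860 by default.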