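"""Gradio app for a document question-answering bot.

Uploaded Markdown/RST/text files are split into chunks, each chunk is sent to
an LLM to generate candidate question-answer pairs, and those pairs are
indexed in a vector database that backs a streaming chat interface.
"""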
import time

import gradio as gr
from langchain.docstore.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter

import vector_db as vdb
from llm_model import LLMModel
chunk_size = 2000
chunk_overlap = 200
uploaded_docs = []
uploaded_df = gr.Dataframe(headers=["file_name", "content_length"])
upload_files_section = gr.Files(
    file_types=[".md", ".mdx", ".rst", ".txt"],
)
chatbot_stream = gr.Chatbot(bubble_full_width=False, show_copy_button=True)
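# The components above are created eagerly so they can be wired into both the
# ChatInterface and the Blocks layout at the bottom of this file.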
def load_docs(files):
    all_docs = []
    all_qa = []
    for file in files:
        if file.name is not None:
            with open(file.name, "r") as f:
                file_content = f.read()
            file_name = file.name.split("/")[-1]
            # Create a document with the file name as source metadata.
            doc = Document(page_content=file_content, metadata={"source": file_name})
            # Split the text into chunks of `chunk_size` characters with a
            # `chunk_overlap`-character overlap, using separators suited to
            # the file's markup language.
            language = get_language(file_name)
            text_splitter = RecursiveCharacterTextSplitter.from_language(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                language=language
            )
            doc_chunks = text_splitter.split_documents([doc])
            print(f"Number of chunks: {len(doc_chunks)}")
            # For each chunk, ask the LLM for potential questions and answers.
            gr.Info("Analysing document...")
            for doc_chunk in doc_chunks:
                potential_qa_from_doc = llm_model.get_potential_question_answer(doc_chunk.page_content)
                all_qa += [Document(page_content=potential_qa_from_doc, metadata=doc_chunk.metadata)]
            all_docs += doc_chunks
            uploaded_docs.append(file.name)
    vector_db.load_docs_into_vector_db(all_qa)
    gr.Info("Loaded document(s) into vector db.")
    return uploaded_docs
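# Note: the vector index stores the generated Q&A text rather than the raw
# chunks, so retrieval matches a user's question against questions the LLM
# predicted for each chunk; each Q&A entry keeps its chunk's source metadata.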
def get_language(file_name: str):
    if file_name.endswith(".md") or file_name.endswith(".mdx"):
        return Language.MARKDOWN
    elif file_name.endswith(".rst"):
        return Language.RST
    else:
        # Fall back to Markdown separators for plain-text files.
        return Language.MARKDOWN
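# e.g. get_language("README.md") -> Language.MARKDOWN
#      get_language("index.rst") -> Language.RST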
def get_vector_db():
    return vdb.VectorDB()


def get_llm_model(_db: vdb.VectorDB):
    # Retrieve the 2 most similar entries from the vector store per query.
    retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    # Alternative: return LLMModel(retriever=retriever).create_qa_chain()
    return LLMModel(retriever=retriever)
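# A minimal sketch of what the retriever does, assuming docs_db is a standard
# LangChain vector store:
#   docs = retriever.get_relevant_documents("What is Data Caterer?")
#   # -> the 2 Document objects whose embeddings best match the query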
def predict(message, history):
    # Non-streaming alternative:
    #   return llm_model.answer_question_inference(message).get("answer")
    resp = llm_model.answer_question_inference_text_gen(message)
    # Yield progressively longer prefixes so the UI shows a typing effect.
    for i in range(len(resp)):
        time.sleep(0.005)
        yield resp[:i + 1]
    # An earlier RetrievalQA-based version of this function formatted the
    # result together with its source documents and the time taken.
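# Because predict is a generator, gr.ChatInterface streams each yielded value
# to the chatbot as a partial response.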
def vote(data: gr.LikeData):
    if data.liked:
        gr.Info("You upvoted this response 👍")
    else:
        gr.Warning("You downvoted this response 👎")
vector_db = get_vector_db()
llm_model = get_llm_model(vector_db)

chat_interface_stream = gr.ChatInterface(
    predict,
    title="📄 Document answering bot",
    description="📄🦜 Upload some documents on the side and ask questions!",
    textbox=gr.Textbox(container=False, scale=7),
    chatbot=chatbot_stream,
    examples=["What is Data Caterer?"],
).queue(default_concurrency_limit=1)
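# default_concurrency_limit=1 processes one chat request at a time, which
# keeps load on the LLM backend predictable.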
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column(scale=1, min_width=100) as upload_col:
            gr.Interface(
                load_docs,
                title="📁 Upload documents",
                inputs=upload_files_section,
                outputs=gr.Files(),
                allow_flagging="never"
            )
            # Alternative wiring:
            #   upload_files_section.upload(load_docs, inputs=upload_files_section)
        with gr.Column(scale=4, min_width=600) as chat_col:
            chatbot_stream.like(vote, None, None)
            chat_interface_stream.render()

blocks.queue().launch()
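# Running this file (e.g. `python app.py`) serves the app on Gradio's default
# port, http://localhost:7860.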