# quGPT / app.py (IC4T, commit c244053)
# Disclaimer: This code was not written by me. It's taken from https://github.com/imartinez/privateGPT/pull/91.
# All credit goes to `vnk8071`, as I mentioned in the video.
# As this code was still an open pull request while I was creating the video, I made some modifications so that it works for me locally.
import os
# Install the local ./langchain package into the interpreter running this
# script, before langchain is imported below. NOTE(review): presumably this is
# the fork that provides PromptTemplate(dest_language=...).translate() used in
# predict -- confirm. `sys.executable -m pip` targets this exact environment,
# unlike os.system('pip ...'), which used whatever `pip` came first on PATH.
import subprocess
import sys
_pip_status = subprocess.run(
    [sys.executable, "-m", "pip", "install", "./langchain"], check=False
).returncode
if _pip_status != 0:
    # Stay best-effort as before (a failed install did not stop the app), but
    # make the failure visible instead of silently ignoring the exit status.
    print(f"WARNING: 'pip install ./langchain' failed with exit code {_pip_status}")
import gradio as gr
from dotenv import load_dotenv
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import LlamaCppEmbeddings
# from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings#, SentenceTransformerEmbeddings
from langchain.prompts.prompt import PromptTemplate
from langchain import PromptTemplate, LLMChain
from langchain.llms import HuggingFacePipeline
from instruct_pipeline import InstructionTextGenerationPipeline
from training.generate import load_model_tokenizer_for_generate
from ctransformers import AutoModelForCausalLM
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import CTransformers
# from training.generate import InstructionTextGenerationPipeline, load_model_tokenizer_for_generate
# from googletrans import Translator
# translator = Translator()
# Load settings from a .env file (if present) into os.environ.
load_dotenv()
# Expected settings, for example in .env:
#   PERSIST_DIRECTORY=db
#   MODEL_TYPE=dolly-v2-3b
#   MODEL_PATH=/media/siiva/DataStore/LLMs/cache/dolly-v2-3b
#   EMBEDDINGS_MODEL_NAME=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
#   MODEL_N_CTX=1000
#   TARGET_SOURCE_CHUNKS=4
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
# int(None) raised TypeError when MODEL_N_CTX was unset; fall back to the
# documented example value instead (the active code path does not read it).
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1000))
# Number of chunks the retriever returns per query (its `k`).
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
import psutil
from constants import CHROMA_SETTINGS
# Vector store: open the Chroma collection persisted at PERSIST_DIRECTORY,
# embedding queries with the model named by EMBEDDINGS_MODEL_NAME.
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(
    embedding_function=embeddings,
    persist_directory=persist_directory,
    client_settings=CHROMA_SETTINGS,
)
# Retriever handing `target_source_chunks` chunks (search k) to the QA chain.
retriever = db.as_retriever(search_kwargs=dict(k=target_source_chunks))
# Prepare the LLM
# callbacks = [StreamingStdOutCallbackHandler()]
match model_type:
case "dolly-v2-3b":
# model, tokenizer = load_model_tokenizer_for_generate(model_path)
# generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)
# llm = HuggingFacePipeline(pipeline=generate_text)
# llm = AutoModelForCausalLM.from_pretrained(model_path, model_type='dolly-v2')
# llm = CTransformers(model_path, callbacks=[StreamingStdOutCallbackHandler()])
llm = CTransformers(model=model_path, model_type="dolly-v2", callbacks=[StreamingStdOutCallbackHandler()])
# llm = HuggingFacePipeline(
# pipeline=InstructionTextGenerationPipeline(
# # Return the full text, because this is what the HuggingFacePipeline expects.
# model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation", max_new_tokens=model_n_ctx))#, max_new_tokens=model_n_ctx
# #))
# case "GPT4All":
# llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
case _default:
print(f"Model {model_type} not supported!")
exit;
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
def clear_history(request: gr.Request):
    """Reset the conversation.

    Returns values for the ``[chatbot, state, question]`` outputs: an empty
    Chatbot, an empty history list and a blank textbox. ``request`` (the
    Gradio request) is accepted but not used.
    """
    # The state used to be reset to None, which made the next predict() call
    # fail on ``history.append``; an empty list keeps the state's type stable.
    return [], [], ""
def post_process_code(code):
    r"""Un-escape ``\_`` inside fenced code of a Markdown reply.

    The text is split on newline-prefixed triple backticks. When that yields
    an odd number of segments (every fence is closed), each odd-indexed
    segment (the fenced code) has ``\_`` replaced by ``_``. Text without
    such a fence is returned unchanged.
    """
    fence = "\n```"
    if fence not in code:
        return code
    parts = code.split(fence)
    if len(parts) % 2:
        parts = [
            seg.replace("\\_", "_") if idx % 2 else seg
            for idx, seg in enumerate(parts)
        ]
    return fence.join(parts)
def post_process_answer(answer, source):
    """Append a ``Source:`` note to ``answer`` and render newlines as ``<br>``.

    Newline conversion happens after the note is appended, so newlines in
    ``source``'s string form are converted too.
    """
    combined = "".join((answer, "<br><br>Source: ", f"{source}"))
    return "<br>".join(combined.split("\n"))
def predict(
question: str,
# system_content: str,
# embeddings_model_name: str,
# persist_directory: str,
# model_type: str,
# model_path: str,
# model_n_ctx: int,
# target_source_chunks: int,
chatbot: list = [],
history: list = [],
):
# try:
# embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# persist_directory = "db"
# model_type = "dolly-v2-3b"
# model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"
# target_source_chunks = 3
# model_n_ctx = 1000
# embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
# db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
# retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
# # Prepare the LLM
# callbacks = [StreamingStdOutCallbackHandler()]
# match model_type:
# case "dolly-v2-3b":
# model, tokenizer = load_model_tokenizer_for_generate(model_path)
# llm = HuggingFacePipeline(
# pipeline=InstructionTextGenerationPipeline(
# # Return the full text, because this is what the HuggingFacePipeline expects.
# model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation", max_new_tokens=model_n_ctx
# ))
# case "GPT4All":
# llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
# case _default:
# print(f"Model {model_type} not supported!")
# exit;
# qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
# Get the answer from the chain
# prompt = system_content + f"\n Question: {question}"
prompt = f"{question}"
# res = qa(prompt)
no_input_prompt = PromptTemplate(input_variables=[], template=prompt, dest_language='en')#src_language='id',
no_input_prompt.format()
query = no_input_prompt.translate()
# prompt_trans = translator.translate(prompt, src='en', dest='id')
# print(prompt_trans.text)
# result = qa({"question": query, "chat_history": chat_history})
llm_response = qa(query)
answer, docs = llm_response['result'], llm_response['source_documents']
no_input_prompt = PromptTemplate(input_variables=[], template=answer, dest_language='id')
no_input_prompt.format()
answer = no_input_prompt.translate()
# answer = post_process_answer(answer, docs)
history.append(question)
history.append(answer)
chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
return chatbot, history
# except Exception as e:
# history.append("")
# answer = server_error_msg + f" (error_code: 503)"
# history.append(answer)
# chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
# return chatbot, history
def get_system_memory():
    """Return a snapshot of host RAM usage as display-ready strings.

    Keys: ``percent`` (psutil's percentage followed by ``%``), and ``used`` /
    ``total`` converted to GiB (1024**3 bytes), three decimals, suffixed
    ``GB``.
    """
    gib = 1024.0 ** 3
    vm = psutil.virtual_memory()
    used_gib = vm.used / gib
    total_gib = vm.total / gib
    return {
        "percent": f"{vm.percent}%",
        "used": f"{used_gib:.3f}GB",
        "total": f"{total_gib:.3f}GB",
    }
def reset_textbox():
    """Return a Gradio update that clears a Textbox's value."""
    cleared = gr.update(value="")
    return cleared
# Page heading rendered with gr.HTML at the top of the UI. The emoji was
# mojibake ("πŸ€–", UTF-8 bytes mis-decoded); restored to U+1F916.
title = """<h1 align="center">Chat with QuGPT 🤖</h1>"""
# def add_text(history, text):
# history = history + [(text, None)]
# return history, ""
def bot(history):
    """Placeholder responder: put a canned reply into the latest chat turn.

    ``history`` is a list of mutable ``[user_message, bot_reply]`` pairs; the
    last pair's reply slot is overwritten in place and the same list is
    returned. An empty history is returned unchanged instead of raising
    IndexError.
    """
    if history:
        history[-1][1] = "**That's cool!**"
    return history
# Build the chat UI. The CSS targets Svelte-generated class names
# (e.g. .svelte-13f7djk), which are tied to the Gradio version in use.
with gr.Blocks(
    css="""
footer .svelte-1lyswbr {display: none !important;}
#col_container {margin-left: auto; margin-right: auto;}
#chatbot .wrap.svelte-13f7djk {height: 70vh; max-height: 70vh}
#chatbot .message.user.svelte-13f7djk.svelte-13f7djk {width:fit-content; background:orange; border-bottom-right-radius:0}
#chatbot .message.bot.svelte-13f7djk.svelte-13f7djk {width:fit-content; padding-left: 16px; border-bottom-left-radius:0}
#chatbot .pre {border:2px solid white;}
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
"""
) as demo:
    gr.HTML(title)
    with gr.Row():
        # A commented-out settings column (system prompt, model and retrieval
        # parameters) used to sit here; those settings are now read from
        # environment variables at module import time.
        with gr.Column(elem_id="col_container"):
            chatbot = gr.Chatbot(elem_id="chatbot", label="QuGPT")
            question = gr.Textbox(placeholder="Ask something", show_label=False, value="")
            # Flat conversation history [q1, a1, q2, a2, ...] built by predict.
            state = gr.State([])
            with gr.Row():
                # Button labels were mojibake ("πŸš€", "πŸ—‘οΈ"); restored emoji.
                with gr.Column():
                    submit_btn = gr.Button(value="🚀 Send")
                with gr.Column():
                    clear_btn = gr.Button(value="🗑️ Clear history")
        # Optional live memory panel:
        # with gr.Column():
        #     gr.JSON(get_system_memory, every=1)

    # Enter in the textbox or the Send button runs predict(question, chatbot,
    # state) and writes back (chatbot, state); separate listeners clear the
    # textbox.
    question.submit(
        predict,
        [question, chatbot, state],
        [chatbot, state],
    )
    submit_btn.click(
        predict,
        [question, chatbot, state],
        [chatbot, state],
    )
    submit_btn.click(reset_textbox, [], [question])
    clear_btn.click(clear_history, None, [chatbot, state, question])
    question.submit(reset_textbox, [], [question])
# Launch the Gradio server only when this file is executed as a script, so
# importing the module builds `demo` without blocking inside launch().
# Request queueing is disabled here; earlier variants enabled it, e.g.
#   demo.queue(concurrency_count=1, max_size=100).launch(max_threads=5, debug=True)
if __name__ == "__main__":
    demo.launch(show_api=False, enable_queue=False)