# Disclaimer: This code is not written by me. It is taken from https://github.com/imartinez/privateGPT/pull/91.
# All credit goes to `vnk8071`, as mentioned in the video.
# As this code was still in the pull request while I was creating the video, I made some modifications so that it works for me locally.
import os

os.system('pip install ./langchain')
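# NOTE: this installs the local copy of langchain bundled with the Space. It is presumably a
# patched build, since the stock PromptTemplate has no `dest_language` argument or
# `translate()` method (both are used further down in `predict`).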
import gradio as gr
from dotenv import load_dotenv
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import LlamaCppEmbeddings
# from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings  # , SentenceTransformerEmbeddings
from langchain.prompts.prompt import PromptTemplate
from langchain import PromptTemplate, LLMChain
from langchain.llms import HuggingFacePipeline
from training.generate import InstructionTextGenerationPipeline, load_model_tokenizer_for_generate
# from googletrans import Translator
# translator = Translator()

load_dotenv()

embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX'))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
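# Example .env values used locally (loaded above via load_dotenv):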
# PERSIST_DIRECTORY=db
# MODEL_TYPE=dolly-v2-3b
# MODEL_PATH=/media/siiva/DataStore/LLMs/cache/dolly-v2-3b
# EMBEDDINGS_MODEL_NAME=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
# MODEL_N_CTX=1000
# TARGET_SOURCE_CHUNKS=4
from constants import CHROMA_SETTINGS

# embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# persist_directory = "db"
# model_type = "dolly-v2-3b"
# model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"
# target_source_chunks = 3
# model_n_ctx = 1000
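# Build the embedding function, load the persisted Chroma index, and expose it as a
# retriever that returns the top `target_source_chunks` most similar chunks per query.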
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})

# Prepare the LLM
# callbacks = [StreamingStdOutCallbackHandler()]
match model_type:
    case "dolly-v2-3b":
        model, tokenizer = load_model_tokenizer_for_generate(model_path)
        llm = HuggingFacePipeline(
            pipeline=InstructionTextGenerationPipeline(
                # Return the full text, because this is what the HuggingFacePipeline expects.
                model=model, tokenizer=tokenizer, return_full_text=True,
                task="text-generation", max_new_tokens=model_n_ctx))
    # case "GPT4All":
    #     llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    case _:
        print(f"Model {model_type} not supported!")
        exit()
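# RetrievalQA with the "stuff" chain type: all retrieved chunks are stuffed into a single
# prompt for the LLM, and the source documents are returned alongside the answer.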
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

def clear_history(request: gr.Request):
    state = None
    return ([], state, "")

def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code

def post_process_answer(answer, source):
    answer += f"<br><br>Source: {source}"
    answer = answer.replace("\n", "<br>")
    return answer

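# Main chat callback: takes the user's question, runs it through the RetrievalQA chain, and
# returns the updated (question, answer) pairs for the Chatbot plus the flat history list.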
def predict(
    question: str,
    # system_content: str,
    # embeddings_model_name: str,
    # persist_directory: str,
    # model_type: str,
    # model_path: str,
    # model_n_ctx: int,
    # target_source_chunks: int,
    chatbot: list = [],
    history: list = [],
):
    # try:
    #     embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    #     persist_directory = "db"
    #     model_type = "dolly-v2-3b"
    #     model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"
    #     target_source_chunks = 3
    #     model_n_ctx = 1000
    #     embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    #     db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    #     retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    #     # Prepare the LLM
    #     callbacks = [StreamingStdOutCallbackHandler()]
    #     match model_type:
    #         case "dolly-v2-3b":
    #             model, tokenizer = load_model_tokenizer_for_generate(model_path)
    #             llm = HuggingFacePipeline(
    #                 pipeline=InstructionTextGenerationPipeline(
    #                     # Return the full text, because this is what the HuggingFacePipeline expects.
    #                     model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation", max_new_tokens=model_n_ctx
    #                 ))
    #         case "GPT4All":
    #             llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    #         case _:
    #             print(f"Model {model_type} not supported!")
    #             exit()
    #     qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    # Get the answer from the chain
    # prompt = system_content + f"\n Question: {question}"
    prompt = f"{question}"
    # res = qa(prompt)
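    # Translate the question to English before retrieval. NOTE: `dest_language` and
    # `translate()` are not part of the stock LangChain PromptTemplate; they presumably come
    # from the locally patched ./langchain installed at the top of this file.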
    no_input_prompt = PromptTemplate(input_variables=[], template=prompt, dest_language='en')  # src_language='id'
    no_input_prompt.format()
    query = no_input_prompt.translate()
    # prompt_trans = translator.translate(prompt, src='en', dest='id')
    # print(prompt_trans.text)
    # result = qa({"question": query, "chat_history": chat_history})
    llm_response = qa(query)
    answer, docs = llm_response['result'], llm_response['source_documents']
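    # Translate the answer back to Indonesian ('id') before showing it in the chat.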
    no_input_prompt = PromptTemplate(input_variables=[], template=answer, dest_language='id')
    no_input_prompt.format()
    answer = no_input_prompt.translate()
    # answer = post_process_answer(answer, docs)
    history.append(question)
    history.append(answer)
    chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
    return chatbot, history
    # except Exception as e:
    #     history.append("")
    #     answer = server_error_msg + f" (error_code: 503)"
    #     history.append(answer)
    #     chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
    #     return chatbot, history

def reset_textbox():
    return gr.update(value="")

title = """<h1 align="center">Chat with QuGPT 🤗</h1>"""
# def add_text(history, text):
#     history = history + [(text, None)]
#     return history, ""


def bot(history):
    response = "**That's cool!**"
    history[-1][1] = response
    return history

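# Build the Gradio UI. The custom CSS hides the default footer, centers the chat column,
# and makes the chat window fill roughly 70% of the viewport height.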
with gr.Blocks(
    css="""
    footer .svelte-1lyswbr {display: none !important;}
    #col_container {margin-left: auto; margin-right: auto;}
    #chatbot .wrap.svelte-13f7djk {height: 70vh; max-height: 70vh}
    #chatbot .message.user.svelte-13f7djk.svelte-13f7djk {width: fit-content; background: orange; border-bottom-right-radius: 0}
    #chatbot .message.bot.svelte-13f7djk.svelte-13f7djk {width: fit-content; padding-left: 16px; border-bottom-left-radius: 0}
    #chatbot .pre {border: 2px solid white;}
    pre {
        white-space: pre-wrap;       /* Since CSS 2.1 */
        white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
        white-space: -pre-wrap;      /* Opera 4-6 */
        white-space: -o-pre-wrap;    /* Opera 7 */
        word-wrap: break-word;       /* Internet Explorer 5.5+ */
    }
    """
) as demo:
    gr.HTML(title)
    with gr.Row():
        # with gr.Column(elem_id="col_container", scale=0.3):
        #     with gr.Accordion("Prompt", open=True):
        #         system_content = gr.Textbox(value="You are QuGPT which built with LangChain and dolly-v2 and sentence-transformer.", show_label=False)
        #     with gr.Accordion("Config", open=True):
        #         embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # gr.Textbox(value="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", label="embeddings_model_name")
        #         persist_directory = "db"  # gr.Textbox(value="db", label="persist_directory")
        #         model_type = "dolly-v2-3b"  # gr.Textbox(value="dolly-v2-3b", label="model_type")
        #         model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"  # gr.Textbox(value="/media/siiva/DataStore/LLMs/cache/dolly-v2-3b", label="model_path")
        #         target_source_chunks = 3
        #         # gr.Slider(
        #         #     minimum=1,
        #         #     maximum=5,
        #         #     value=2,
        #         #     step=1,
        #         #     interactive=True,
        #         #     label="target_source_chunks",
        #         # )
        #         model_n_ctx = 1000
        #         gr.Slider(
        #             minimum=32,
        #             maximum=4096,
        #             value=1000,
        #             step=32,
        #             interactive=True,
        #             label="model_n_ctx",
        #         )
        with gr.Column(elem_id="col_container"):
            chatbot = gr.Chatbot(elem_id="chatbot", label="QuGPT")
            question = gr.Textbox(placeholder="Ask something", show_label=False, value="")
            state = gr.State([])
            with gr.Row():
                with gr.Column():
                    submit_btn = gr.Button(value="🚀 Send")
                with gr.Column():
                    clear_btn = gr.Button(value="🗑️ Clear history")
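    # Wire up the events: submitting the textbox or clicking Send runs predict(), then the
    # textbox is cleared; the Clear history button resets the chat, state, and textbox.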
    question.submit(
        predict,
        # [question, system_content, embeddings_model_name, persist_directory, model_type, model_path, model_n_ctx, target_source_chunks, chatbot, state],
        [question, chatbot, state],
        [chatbot, state],
    )
    submit_btn.click(
        predict,
        # [question, system_content, embeddings_model_name, persist_directory, model_type, model_path, model_n_ctx, target_source_chunks, chatbot, state],
        [question, chatbot, state],
        [chatbot, state],
    )
    submit_btn.click(reset_textbox, [], [question])
    clear_btn.click(clear_history, None, [chatbot, state, question])
    question.submit(reset_textbox, [], [question])
    # demo.queue(concurrency_count=10, status_update_rate="auto")
    # question.submit(predict, [question, system_content, embeddings_model_name, persist_directory, model_type, model_path, model_n_ctx, target_source_chunks, chatbot, state], [chatbot, state]).then(
    #     predict, chatbot
    # )

# demo.launch(server_name=args.server_name, server_port=args.server_port, share=args.share, debug=args.debug)
demo.launch()