# Disclaimer: This code is not written by me. It is taken from https://github.com/imartinez/privateGPT/pull/91.
# All credit goes to `vnk8071`, as I mentioned in the video.
# As this code was still in the pull request while I was creating the video, I made some modifications so that it works for me locally.
import os
os.system('pip install ./langchain')
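# Note: this installs the locally modified copy of langchain from ./langchain; presumably it provides
# the translation support (the `dest_language` argument and `.translate()` method on PromptTemplate) used further down.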
import gradio as gr
from dotenv import load_dotenv
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import LlamaCppEmbeddings
# from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceInstructEmbeddings  # , SentenceTransformerEmbeddings
from langchain.prompts.prompt import PromptTemplate
from langchain import PromptTemplate, LLMChain
from langchain.llms import HuggingFacePipeline
from training.generate import InstructionTextGenerationPipeline, load_model_tokenizer_for_generate
# from googletrans import Translator
# translator = Translator()
load_dotenv()
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX'))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
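# Example .env contents for reference: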
# PERSIST_DIRECTORY=db
# MODEL_TYPE=dolly-v2-3b
# MODEL_PATH=/media/siiva/DataStore/LLMs/cache/dolly-v2-3b
# EMBEDDINGS_MODEL_NAME=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
# MODEL_N_CTX=1000
# TARGET_SOURCE_CHUNKS=4
from constants import CHROMA_SETTINGS
# embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# persist_directory = "db"
# model_type = "dolly-v2-3b"
# model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"
# target_source_chunks = 3
# model_n_ctx = 1000
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
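# The retriever returns the top `target_source_chunks` most similar chunks from the persisted Chroma index.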
# Prepare the LLM
# callbacks = [StreamingStdOutCallbackHandler()]
match model_type:
    case "dolly-v2-3b":
        model, tokenizer = load_model_tokenizer_for_generate(model_path)
        llm = HuggingFacePipeline(
            pipeline=InstructionTextGenerationPipeline(
                # Return the full text, because this is what the HuggingFacePipeline expects.
                model=model, tokenizer=tokenizer, return_full_text=True,
                task="text-generation", max_new_tokens=model_n_ctx))
    # case "GPT4All":
    #     llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    case _:
        print(f"Model {model_type} not supported!")
        exit()
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
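# The "stuff" chain type simply stuffs the retrieved chunks into the prompt; the source documents
# are returned alongside the generated answer.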
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
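# Clears the Gradio chat state (chatbot messages, history, and the textbox).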
def clear_history(request: gr.Request):
    state = None
    return ([], state, "")
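# Un-escapes underscores inside fenced code blocks so they render correctly in the chat window.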
def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code
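# Appends the source reference to the answer and converts newlines for HTML display in the chat window.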
def post_process_answer(answer, source):
    answer += f"<br>Source: {source}"
    # Convert newlines to HTML line breaks for the Gradio chat window.
    answer = answer.replace("\n", "<br>")
    return answer
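# Main inference callback: translates the question to English, runs the RetrievalQA chain,
# translates the answer back to Indonesian, and updates the Gradio chat history.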
def predict(
    question: str,
    # system_content: str,
    # embeddings_model_name: str,
    # persist_directory: str,
    # model_type: str,
    # model_path: str,
    # model_n_ctx: int,
    # target_source_chunks: int,
    chatbot: list = [],
    history: list = [],
):
    # try:
    # embeddings_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    # persist_directory = "db"
    # model_type = "dolly-v2-3b"
    # model_path = "/media/siiva/DataStore/LLMs/cache/dolly-v2-3b"
    # target_source_chunks = 3
    # model_n_ctx = 1000
    # embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    # db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    # retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    # # Prepare the LLM
    # callbacks = [StreamingStdOutCallbackHandler()]
    # match model_type:
    # case "dolly-v2-3b":
    # model, tokenizer = load_model_tokenizer_for_generate(model_path)
    # llm = HuggingFacePipeline(
    # pipeline=InstructionTextGenerationPipeline(
    # # Return the full text, because this is what the HuggingFacePipeline expects.
    # model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation", max_new_tokens=model_n_ctx
    # ))
    # case "GPT4All":
    # llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    # case _default:
    # print(f"Model {model_type} not supported!")
    # exit;
    # qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    # Get the answer from the chain
    # prompt = system_content + f"\n Question: {question}"
    prompt = f"{question}"
    # res = qa(prompt)
    # Translate the question to English via the locally patched PromptTemplate (dest_language / .translate()).
    no_input_prompt = PromptTemplate(input_variables=[], template=prompt, dest_language='en')  # src_language='id'
    no_input_prompt.format()
    query = no_input_prompt.translate()
    # prompt_trans = translator.translate(prompt, src='en', dest='id')
    # print(prompt_trans.text)
    # result = qa({"question": query, "chat_history": chat_history})
    llm_response = qa(query)
    answer, docs = llm_response['result'], llm_response['source_documents']
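    # Translate the English answer back to Indonesian ('id') before showing it in the chat.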
    no_input_prompt = PromptTemplate(input_variables=[], template=answer, dest_language='id')
    no_input_prompt.format()
    answer = no_input_prompt.translate()
    # answer = post_process_answer(answer, docs)
    history.append(question)
    history.append(answer)
    chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
    return chatbot, history
    # except Exception as e:
    # history.append("")
    # answer = server_error_msg + f" (error_code: 503)"
    # history.append(answer)
    # chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
    # return chatbot, history
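# Clears the question textbox after a message is submitted.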
def reset_textbox():
    return gr.update(value="")
title = """