|
|
|
|
|
|
|
import os |
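
# Gradio app for chatting with local documents (PDF, web pages, CSV/Excel) using a
# FAISS vector store and locally hosted models (Flan Alpaca or Mistral Open Orca).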
|
|
|
|
|
|
|
# Pin the Gradio version at runtime so it is installed before the gradio import below
os.system("pip install gradio==3.42.0")
|
|
|
from typing import TypeVar |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.vectorstores import FAISS |
|
import gradio as gr |
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
from ctransformers import AutoModelForCausalLM |
|
|
|
PandasDataFrame = TypeVar('pd.core.frame.DataFrame') |
|
import chatfuncs.ingest as ing |
|
|
|
|
|
|
|
embeddings_name = "BAAI/bge-base-en-v1.5" |
|
|
|
def load_embeddings(embeddings_name = "BAAI/bge-base-en-v1.5"): |
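    """Load the Hugging Face sentence-embedding model and store it in the module-level `embeddings` variable."""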
|
embeddings_func = HuggingFaceEmbeddings(model_name=embeddings_name) |
|
|
|
global embeddings |
|
|
|
embeddings = embeddings_func |
|
|
|
return embeddings |
|
|
|
def get_faiss_store(faiss_vstore_folder,embeddings): |
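    """Unzip a pre-built FAISS store from `faiss_vstore_folder`, load it with the supplied embeddings, then remove the extracted index files."""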
|
import zipfile |
|
with zipfile.ZipFile(faiss_vstore_folder + '/' + faiss_vstore_folder + '.zip', 'r') as zip_ref: |
|
zip_ref.extractall(faiss_vstore_folder) |
|
|
|
faiss_vstore = FAISS.load_local(folder_path=faiss_vstore_folder, embeddings=embeddings) |
|
os.remove(faiss_vstore_folder + "/index.faiss") |
|
os.remove(faiss_vstore_folder + "/index.pkl") |
|
|
|
global vectorstore |
|
|
|
vectorstore = faiss_vstore |
|
|
|
return vectorstore |
|
|
|
import chatfuncs.chatfuncs as chatf |
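
# Load the default embeddings and the bundled FAISS vector store at start-up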
|
|
|
chatf.embeddings = load_embeddings(embeddings_name) |
|
chatf.vectorstore = get_faiss_store(faiss_vstore_folder="faiss_embedding", embeddings=chatf.embeddings)
|
|
|
def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_device=None): |
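    """Load the selected chat model (Mistral Open Orca as a GGUF file via ctransformers, or the Flan Alpaca checkpoint via transformers) and register it with chatfuncs."""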
|
print("Loading model") |
|
|
|
|
|
if gpu_config is None: |
|
gpu_config = chatf.gpu_config |
|
if cpu_config is None: |
|
cpu_config = chatf.cpu_config |
|
if torch_device is None: |
|
torch_device = chatf.torch_device |
|
|
|
if model_type == "Mistral Open Orca (larger, slow)": |
|
if torch_device == "cuda": |
|
gpu_config.update_gpu(gpu_layers) |
|
else: |
|
gpu_config.update_gpu(gpu_layers) |
|
cpu_config.update_gpu(gpu_layers) |
|
|
|
print("Loading with", cpu_config.gpu_layers, "model layers sent to GPU.") |
|
|
|
print(vars(gpu_config)) |
|
print(vars(cpu_config)) |
|
|
|
        try:
            # Try loading the GGUF model with the GPU configuration first
            model = AutoModelForCausalLM.from_pretrained('TheBloke/Mistral-7B-OpenOrca-GGUF', model_type='mistral', model_file='mistral-7b-openorca.Q4_K_M.gguf', **vars(gpu_config))
        except Exception:
            # Fall back to the CPU configuration if loading with the GPU settings fails
            model = AutoModelForCausalLM.from_pretrained('TheBloke/Mistral-7B-OpenOrca-GGUF', model_type='mistral', model_file='mistral-7b-openorca.Q4_K_M.gguf', **vars(cpu_config))

        # ctransformers handles tokenisation internally, so no separate tokenizer is needed here
        tokenizer = []
|
|
|
if model_type == "Flan Alpaca (small, fast)": |
|
|
|
hf_checkpoint = 'declare-lab/flan-alpaca-large' |
|
|
|
def create_hf_model(model_name): |
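            """Load a Hugging Face checkpoint (Seq2Seq for 'flan' models, causal LM otherwise) together with its tokenizer, capped at chatf.context_length."""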
|
|
|
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM |
|
|
|
if torch_device == "cuda": |
|
if "flan" in model_name: |
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto") |
|
else: |
|
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") |
|
else: |
|
if "flan" in model_name: |
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name) |
|
else: |
|
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = chatf.context_length) |
|
|
|
return model, tokenizer, model_type |
|
|
|
model, tokenizer, model_type = create_hf_model(model_name = hf_checkpoint) |
|
|
|
chatf.model = model |
|
chatf.tokenizer = tokenizer |
|
chatf.model_type = model_type |
|
|
|
load_confirmation = "Finished loading model: " + model_type |
|
|
|
print(load_confirmation) |
|
return model_type, load_confirmation, model_type |
|
model_type = "Flan Alpaca (small, fast)" |
|
load_model(model_type, 0, chatf.gpu_config, chatf.cpu_config, chatf.torch_device) |
|
|
|
def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings): |
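    """Embed the ingested documents into a new FAISS vector store and make it the active store for the chat functions."""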
|
|
|
print(f"> Total split documents: {len(docs_out)}") |
|
|
|
print(docs_out) |
|
|
|
vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings) |
|
|
|
|
|
chatf.vectorstore = vectorstore_func |
|
|
|
out_message = "Document processing complete" |
|
|
|
return out_message, vectorstore_func |
|
|
|
|
|
|
|
# Build the Gradio interface
block = gr.Blocks(theme=gr.themes.Base())
|
|
|
with block: |
|
ingest_text = gr.State() |
|
ingest_metadata = gr.State() |
|
ingest_docs = gr.State() |
|
|
|
model_type_state = gr.State(model_type) |
|
    embeddings_state = gr.State(chatf.embeddings)
    vectorstore_state = gr.State(chatf.vectorstore)
|
|
|
model_state = gr.State() |
|
tokenizer_state = gr.State() |
|
|
|
chat_history_state = gr.State() |
|
instruction_prompt_out = gr.State() |
|
|
|
gr.Markdown("<h1><center>Chat with Misbahuddin Classroom</center></h1>") |
|
|
|
gr.Markdown("Ask any question of pharmacology of 10 drugs. However, there are some limitations.") |
|
|
|
with gr.Row(): |
|
current_source = gr.Textbox(label="Current data source(s)", value="Lambeth_2030-Our_Future_Our_Lambeth.pdf", scale = 10) |
|
current_model = gr.Textbox(label="Current model", value=model_type, scale = 3) |
|
|
|
with gr.Tab("Chatbot"): |
|
|
|
with gr.Row(): |
|
|
|
chatbot = gr.Chatbot(avatar_images=('user.jfif', 'bot.jpg'),bubble_full_width = False, scale = 1) |
|
with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = False): |
|
sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here", scale = 1) |
|
|
|
with gr.Row(): |
|
message = gr.Textbox( |
|
label="Enter your question here", |
|
lines=1, |
|
) |
|
with gr.Row(): |
|
submit = gr.Button(value="Send message", variant="secondary", scale = 1) |
|
clear = gr.Button(value="Clear chat", variant="secondary", scale=0) |
|
stop = gr.Button(value="Stop generating", variant="secondary", scale=0) |
|
|
|
examples_set = gr.Radio(label="Examples for the Lambeth Borough Plan", |
|
|
|
choices=["What were the five pillars of the previous borough plan?", |
|
"What is the vision statement for Lambeth?", |
|
"What are the commitments for Lambeth?", |
|
"What are the 2030 outcomes for Lambeth?"]) |
|
|
|
|
|
current_topic = gr.Textbox(label="Feature currently disabled - Keywords related to current conversation topic.", placeholder="Keywords related to the conversation topic will appear here") |
|
|
|
|
|
|
|
with gr.Tab("Load in a different file to chat with"): |
|
with gr.Accordion("PDF file", open = False): |
|
in_pdf = gr.File(label="Upload pdf", file_count="multiple", file_types=['.pdf']) |
|
load_pdf = gr.Button(value="Load in file", variant="secondary", scale=0) |
|
|
|
with gr.Accordion("Web page", open = False): |
|
with gr.Row(): |
|
in_web = gr.Textbox(label="Enter web page url") |
|
in_div = gr.Textbox(label="(Advanced) Web page div for text extraction", value="p", placeholder="p") |
|
load_web = gr.Button(value="Load in webpage", variant="secondary", scale=0) |
|
|
|
with gr.Accordion("CSV/Excel file", open = False): |
|
in_csv = gr.File(label="Upload CSV/Excel file", file_count="multiple", file_types=['.csv', '.xlsx']) |
|
in_text_column = gr.Textbox(label="Enter column name where text is stored") |
|
load_csv = gr.Button(value="Load in CSV/Excel file", variant="secondary", scale=0) |
|
|
|
ingest_embed_out = gr.Textbox(label="File/web page preparation progress") |
|
|
|
with gr.Tab("Advanced features"): |
|
out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.") |
|
temp_slide = gr.Slider(minimum=0.1, value = 0.1, maximum=1, step=0.1, label="Choose temperature setting for response generation.") |
|
with gr.Row(): |
|
model_choice = gr.Radio(label="Choose a chat model", value="Flan Alpaca (small, fast)", choices = ["Flan Alpaca (small, fast)", "Mistral Open Orca (larger, slow)"]) |
|
change_model_button = gr.Button(value="Load model", scale=0) |
|
with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False): |
|
gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=5, step = 1, visible=True) |
|
|
|
load_text = gr.Text(label="Load status") |
|
|
|
|
|
gr.HTML( |
|
"<center>This app is based on the models Flan Alpaca and Mistral Open Orca. It powered by Gradio, Transformers, Ctransformers, and Langchain.</a></center>" |
|
) |
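
    # Wire up the UI events: example selection, model switching, document loading, and chat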
|
|
|
examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message]) |
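
    # Switching model: lock the input box, load the chosen model, then restore the UI and clear the chat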
|
|
|
    change_model_button.click(fn=chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
        then(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs=[model_type_state, load_text, current_model]).\
        then(lambda: chatf.restore_interactivity(), None, [message], queue=False).\
        then(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
        then(lambda: None, None, chatbot, queue=False)
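
    # Loading a new document source: parse the file/page, split it into documents, and embed them into a fresh FAISS store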
|
|
|
|
|
    load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
        then(ing.text_to_docs, inputs=[ingest_text], outputs=[ingest_docs]).\
        then(docs_to_faiss_save, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state]).\
        then(chatf.hide_block, outputs=[examples_set])
|
|
|
|
|
    load_web_click = load_web.click(ing.parse_html, inputs=[in_web, in_div], outputs=[ingest_text, ingest_metadata, current_source]).\
        then(ing.html_text_to_docs, inputs=[ingest_text, ingest_metadata], outputs=[ingest_docs]).\
        then(docs_to_faiss_save, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state]).\
        then(chatf.hide_block, outputs=[examples_set])
|
|
|
|
|
    load_csv_click = load_csv.click(ing.parse_csv_or_excel, inputs=[in_csv, in_text_column], outputs=[ingest_text, current_source]).\
        then(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
        then(docs_to_faiss_save, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state]).\
        then(chatf.hide_block, outputs=[examples_set])
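
    # Asking a question: build the full prompt with retrieved passages, then stream the model's answer into the chatbot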
|
|
|
|
|
|
|
|
|
    response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False, api_name="retrieval").\
        then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
        then(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide], outputs=chatbot)
|
    response_click.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
        then(chatf.add_inputs_answer_to_history, [message, chatbot, current_topic], [chat_history_state, current_topic]).\
        then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
|
|
    response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False).\
        then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
        then(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide], chatbot)
|
    response_enter.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
        then(chatf.add_inputs_answer_to_history, [message, chatbot, current_topic], [chat_history_state, current_topic]).\
        then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
|
|
|
|
stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter]) |
|
|
|
|
|
clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]) |
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
|
|
chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None) |
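
# Queue requests one at a time and launch the app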
|
|
|
block.queue(concurrency_count=1).launch(debug=True) |
|
|
|
|
|
|