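"""Landslide-awareness chatbot: indexes a reference PDF into a Chroma vector store
and answers user questions with a TinyLlama model through a Gradio interface."""
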
import spaces
import gradio as gr
import os
import re
from pathlib import Path
from unidecode import unidecode
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import chromadb
import torch
from concurrent.futures import ThreadPoolExecutor
# Environment configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Predefined values
predefined_pdf = "t6.pdf"
predefined_llm = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Use a smaller model for faster responses
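
# Split the source PDF(s) into overlapping chunks so they can be embedded.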
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
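
# Embed the chunks and build an in-memory Chroma vector store; the EphemeralClient
# keeps the collection in memory only, so nothing is persisted to disk.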
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb
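
# Reconnect to an existing Chroma store (unused in the predefined flow below).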
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb
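
# Derive a valid Chroma collection name from the PDF file name: ASCII only,
# 3-50 characters, starting and ending with an alphanumeric character.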
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    print('Filepath: ', filepath)
    print('Collection name: ', collection_name)
    return collection_name
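
# Load the chat model in 4-bit precision (requires CUDA and the bitsandbytes package),
# wrap it in a HuggingFace text-generation pipeline, and build a
# ConversationalRetrievalChain with buffer memory over the vector store.
# Initialization runs in a worker thread.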
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    if not torch.cuda.is_available():
        print("CUDA is not available. This demo does not work on CPU.")
        return None

    def init_llm():
        print("Initializing HF model and tokenizer...")
        model = AutoModelForCausalLM.from_pretrained(llm_model, device_map="auto", load_in_4bit=True)
        tokenizer = AutoTokenizer.from_pretrained(llm_model)
        tokenizer.use_default_system_prompt = False
        print("Initializing HF pipeline...")
        hf_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map='auto',
            max_new_tokens=max_tokens,
            do_sample=True,
            top_k=top_k,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id
        )
        llm = HuggingFacePipeline(pipeline=hf_pipeline, model_kwargs={'temperature': temperature})
        print("Defining buffer memory...")
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key='answer',
            return_messages=True
        )
        retriever = vector_db.as_retriever()
        print("Defining retrieval chain...")
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False,
        )
        return qa_chain

    with ThreadPoolExecutor() as executor:
        future = executor.submit(init_llm)
        qa_chain = future.result()
    print("Initialization complete!")
    return qa_chain
# Conversation function: prepend the instruction prompt, run the retrieval chain,
# and split the model output into the main answer and the retrieved references.
@spaces.GPU(duration=60)
def chat(message):
    global qa_chain
    prompt_template = (
        "Instruction: You are an expert landslide assistant. Please provide a well-written, "
        "detailed, and helpful answer to the following user query, as an expert, using only "
        "the given references. User Query:\n"
    )
    full_input = prompt_template + message
    response = qa_chain({"question": full_input})
    full_answer = response["answer"]
    answer_parts = full_answer.split("Helpful Answer:")
    qa_chain.memory.clear()
    if len(answer_parts) > 1:
        main_answer = answer_parts[-1].strip()  # Extract the main answer
        references = answer_parts[0].strip()  # Keep the references
        answer = f"Helpful Answer: {main_answer}\n\nReferences:\n{references}"
    else:
        answer = full_answer  # In case there is no "Helpful Answer" part
    return answer, full_answer
interface = gr.Interface(
    fn=chat,
    inputs="textbox",  # Single input textbox
    outputs=["textbox", "textbox"],  # Two outputs: the formatted answer and the raw model output
    title="LANDSLIDE AWARENESS CHATBOT",
    description="Ask me anything related to landslides!",
    elem_id="my-interface",
)
# Load the PDF document and create the vector database
pdf_filepath = predefined_pdf
doc_splits = load_doc([pdf_filepath], chunk_size=400, chunk_overlap=40)
collection_name = create_collection_name(pdf_filepath)
vector_db = create_db(doc_splits, collection_name)

# Initialize the LLM chain with threading
qa_chain = initialize_llmchain(predefined_llm, temperature=0.6, max_tokens=512, top_k=7, vector_db=vector_db)

# Check that qa_chain was properly initialized before launching the UI
if qa_chain is None:
    print("Failed to initialize the QA chain. Please check the CUDA availability and model paths.")
else:
    # Launch the Gradio interface with the share option enabled
    interface.launch(share=True)