# Hugging Face Spaces page header captured with this file — Space status: Runtime error
#import os | |
#os.system("bash setup.sh") | |
import logging
import re

import streamlit as st
import torch
import transformers
# import fitz  # PyMuPDF for extracting text from PDFs
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import DirectoryLoader
from torch import bfloat16
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
# Initialize embeddings and ChromaDB
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _build_books_db(
    pdf_dir="./pdf",
    persist_directory="./pdf_db",
    embedding_model_name="sentence-transformers/all-mpnet-base-v2",
):
    """Load, chunk and embed every PDF under *pdf_dir* into a Chroma store.

    Streamlit re-executes this script on every user interaction. Caching the
    result with ``st.cache_resource`` means the embedding model is loaded and
    the PDFs are embedded once per server process instead of once per
    question. If this function raises, nothing is cached and the build is
    retried on the next rerun.

    Args:
        pdf_dir: Directory searched recursively for ``*.pdf`` files.
        persist_directory: Directory where Chroma persists the collection.
        embedding_model_name: Hugging Face model used for the embeddings.

    Returns:
        The Chroma vector store holding the freshly embedded chunks.

    Raises:
        RuntimeError: If no text chunks could be extracted from *pdf_dir*.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_name, model_kwargs={"device": device}
    )
    loader = DirectoryLoader(
        pdf_dir, glob="**/*.pdf", recursive=True, use_multithreading=True
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(loader.load())
    if not all_splits:
        # An empty index makes every answer useless (and Chroma may reject an
        # empty batch), so fail early with an actionable message.
        raise RuntimeError(f"No PDF text found under {pdf_dir!r}; nothing to index.")

    # Chroma.from_documents assigns fresh random ids and only ever adds. Without
    # this, re-indexing into the same persisted directory (e.g. after a restart)
    # would store every chunk again, and the retriever would return duplicates.
    # The PDFs are re-embedded below anyway, so drop the previous collection.
    Chroma(
        persist_directory=persist_directory, embedding_function=embeddings
    ).delete_collection()
    return Chroma.from_documents(
        documents=all_splits,
        embedding=embeddings,
        persist_directory=persist_directory,
    )


# Reuse the store returned above rather than opening a second handle on the
# same directory.
books_db = _build_books_db()
books_db_client = books_db.as_retriever()
# Initialize the model and tokenizer
model_name = "unsloth/Llama-3.2-3B-Instruct"


@st.cache_resource
def _load_llm(name):
    """Load *name* as a LangChain LLM backed by a transformers text-generation pipeline.

    Cached with ``st.cache_resource`` so the multi-GB checkpoint is loaded once
    per server process instead of on every Streamlit rerun (i.e. every question).

    ``trust_remote_code`` is intentionally not enabled. Llama is implemented
    natively in transformers, and enabling it would execute any Python code
    shipped in the hub repository.

    4-bit loading (requires bitsandbytes) can be enabled by passing
    ``quantization_config=transformers.BitsAndBytesConfig(load_in_4bit=True,
    bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16)`` to ``from_pretrained`` below.
    """
    model = AutoModelForCausalLM.from_pretrained(
        name,
        # The dtype must be chosen when the weights are loaded. Passing it to
        # pipeline() alongside an already-built model does not convert the
        # weights, so they would stay in transformers' default dtype.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(name)
    query_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # Echo the prompt, so the answer is located later by searching the
        # full text for the chain's "Helpful Answer:" marker.
        return_full_text=True,
        # temperature/top_p/top_k only take effect when sampling is enabled;
        # make that explicit instead of relying on the checkpoint's defaults.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        # Single source of truth for the generation length (the previous
        # config-level max_new_tokens=1024 contradicted this value).
        max_new_tokens=128,
    )
    return HuggingFacePipeline(pipeline=query_pipeline)


llm = _load_llm(model_name)
# Building the chain itself is cheap, so it does not need caching.
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True,
)
st.title("RAG System with ChromaDB")

# Seed the chat history once per browser session; st.session_state survives
# Streamlit's rerun of this script on every interaction.
if "messages" not in st.session_state:
    # The app answers from PDFs already in the server-side ./pdf folder and has
    # no upload widget, so the greeting must not ask users to upload files.
    st.session_state.messages = [
        {
            "role": "assistant",
            "content": "Hello! Ask me anything about the content of the indexed PDF files.",
        }
    ]
# Function to retrieve answer using the RAG system
def test_rag(qa, query):
    """Run *query* through the RetrievalQA chain *qa* and return the answer text.

    Uses ``Chain.invoke`` instead of the deprecated ``Chain.run``. RetrievalQA
    reads its input from the ``"query"`` key and writes the generated answer
    to the ``"result"`` key.

    Args:
        qa: A RetrievalQA chain (or any object with a compatible ``invoke``).
        query: The user's question.

    Returns:
        The chain's ``"result"`` value (the LLM output text).
    """
    return qa.invoke({"query": query})["result"]


# The name starts with "test_", so pytest would otherwise collect this helper
# as a test (and error on the missing ``qa``/``query`` fixtures) whenever a
# test module imports it.
test_rag.__test__ = False
logger = logging.getLogger(__name__)


def _mark_prompt_submitted():
    """Flag that the user just submitted (changed) the question text."""
    st.session_state["prompt_pending"] = True


user_prompt = st.text_input(
    "Ask me anything about the content of the PDF(s):",
    on_change=_mark_prompt_submitted,
)
logger.debug("user input: %r", user_prompt)

# Streamlit reruns this whole script on every interaction, and st.text_input
# keeps returning its last value. Only answer when the on_change callback has
# reported a fresh submission. Otherwise any unrelated rerun would call the LLM
# again and append a duplicate question/answer pair to the history. The flag is
# always consumed so a stale one cannot trigger a later answer.
prompt_pending = st.session_state.pop("prompt_pending", False)
if user_prompt and prompt_pending:
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    raw_output = test_rag(books_db_client_retriever, user_prompt)
    logger.debug("chain output: %r", raw_output)
    # The pipeline echoes the prompt (return_full_text=True), whose template
    # ends with "Helpful Answer:"; keep only the text that follows it.
    answer_match = re.search(r"Helpful Answer:(.*)", raw_output, re.DOTALL)
    answer = (
        answer_match.group(1).strip() if answer_match else "No helpful answer found."
    )
    logger.debug("extracted answer: %r", answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})

# Render the full conversation on every run.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])