NanBags commited on
Commit
c9a97bb
·
verified ·
1 Parent(s): ae2b92b

Create app.py

Browse files

This is the app.py file

Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import streamlit as st
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.runnables import RunnablePassthrough
8
+ from langchain_community.llms import Together
9
+ from langchain_community.document_loaders import UnstructuredPDFLoader
10
+ from langchain.text_splitter import CharacterTextSplitter
11
+ from langchain.embeddings import HuggingFaceEmbeddings
12
+
13
+ os.environ["TOGETHER_API_KEY"] = os.getenv("TOGETHER_API_KEY")
14
+
15
+
16
+ def inference(chain, input_query):
17
+ """Invoke the processing chain with the input query."""
18
+ result = chain.invoke(input_query)
19
+ return result
20
+
21
+
22
+ def create_chain(retriever, prompt, model):
23
+ """Compose the processing chain with the specified components."""
24
+ chain = (
25
+ {"context": retriever, "question": RunnablePassthrough()}
26
+ | prompt
27
+ | model
28
+ | StrOutputParser()
29
+ )
30
+ return chain
31
+
32
+
33
+ def generate_prompt():
34
+ """Define the prompt template for question answering."""
35
+ template = """<s>[INST] Answer the question in a simple sentence based only on the following context:
36
+ {context}
37
+ Question: {question} [/INST]
38
+ """
39
+ return ChatPromptTemplate.from_template(template)
40
+
41
+
42
+ def configure_model():
43
+ """Configure the language model with specified parameters."""
44
+ return Together(
45
+ model="mistralai/Mixtral-8x7B-Instruct-v0.1",
46
+ temperature=0.1,
47
+ max_tokens=3000,
48
+ top_k=50,
49
+ top_p=0.7,
50
+ repetition_penalty=1.1,
51
+ )
52
+
53
+
54
+ def configure_retriever(pdf_loader):
55
+ """Configure the retriever with embeddings and a FAISS vector store."""
56
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
57
+ vector_db = FAISS.from_documents(pdf_loader, embeddings)
58
+ return vector_db.as_retriever()
59
+
60
+
61
+ def load_documents(path):
62
+ """Load and preprocess documents from PDF files located at the specified path."""
63
+ pdf_loader = []
64
+ for file in os.listdir(path):
65
+ if file.endswith('.pdf'):
66
+ filepath = os.path.join(path, file)
67
+ loader = UnstructuredPDFLoader(filepath)
68
+ documents = loader.load()
69
+ text_splitter = CharacterTextSplitter(chunk_size=18000, chunk_overlap=10)
70
+ docs = text_splitter.split_documents(documents)
71
+ pdf_loader.extend(docs)
72
+ return pdf_loader
73
+
74
+
75
+ def process_document(path, input_query):
76
+ """Process the document by setting up the chain and invoking it with the input query."""
77
+ pdf_loader = load_documents(path)
78
+ llm_model = configure_model()
79
+ prompt = generate_prompt()
80
+ retriever = configure_retriever(pdf_loader)
81
+ chain = create_chain(retriever, prompt, llm_model)
82
+ response = inference(chain, input_query)
83
+ return response
84
+
85
+
86
+ def main():
87
+ """Main function to run the Streamlit app."""
88
+ tmp_folder = '/tmp/1'
89
+ os.makedirs(tmp_folder,exist_ok=True)
90
+
91
+ st.title("Q&A PDF AI RAG Chatbot")
92
+
93
+ uploaded_files = st.sidebar.file_uploader("Choose PDF files", accept_multiple_files=True, type='pdf')
94
+ if uploaded_files:
95
+ for file in uploaded_files:
96
+ with open(os.path.join(tmp_folder, file.name), 'wb') as f:
97
+ f.write(file.getbuffer())
98
+ st.success('File successfully uploaded. Start prompting!')
99
+ if 'chat_history' not in st.session_state:
100
+ st.session_state.chat_history = []
101
+
102
+ if uploaded_files:
103
+ with st.form(key='question_form'):
104
+ user_query = st.text_input("Ask a question:", key="query_input")
105
+ if st.form_submit_button("Ask") and user_query:
106
+ response = process_document(tmp_folder, user_query)
107
+ st.session_state.chat_history.append({"question": user_query, "answer": response})
108
+
109
+ if st.button("Clear Chat History"):
110
+ st.session_state.chat_history = []
111
+ for chat in st.session_state.chat_history:
112
+ st.markdown(f"**Q:** {chat['question']}")
113
+ st.markdown(f"**A:** {chat['answer']}")
114
+ st.markdown("---")
115
+ else:
116
+ st.success('Upload Document to Start Process !')
117
+
118
+ if st.sidebar.button("REMOVE UPLOADED FILES"):
119
+ document_count = os.listdir(tmp_folder)
120
+ if len(document_count) > 0:
121
+ shutil.rmtree(tmp_folder)
122
+ st.sidebar.write("FILES DELETED SUCCESSFULLY !!!")
123
+ else:
124
+ st.sidebar.write("NO DOCUMENT FOUND TO DELETE !!! PLEASE UPLOAD DOCUMENTS TO START PROCESS !! ")
125
+
126
+
127
+ if __name__ == "__main__":
128
+ main()