kamau1 commited on
Commit
061584a
1 Parent(s): a9997cf

Create semapdf1.2.py

Browse files
Files changed (1) hide show
  1. version/semapdf1.2.py +177 -0
version/semapdf1.2.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.agents import AgentType, Tool, initialize_agent
2
+ from langchain.callbacks import StreamlitCallbackHandler
3
+ from langchain.chains import RetrievalQA
4
+ from langchain.chains.conversation.memory import ConversationBufferMemory
5
+ from utils.ask_human import CustomAskHumanTool
6
+ from utils.model_params import get_model_params
7
+ from utils.prompts import create_agent_prompt, create_qa_prompt
8
+ from PyPDF2 import PdfReader
9
+ from langchain.vectorstores import FAISS
10
+ from langchain.embeddings import HuggingFaceEmbeddings
11
+ from langchain.embeddings import HuggingFaceHubEmbeddings
12
+ from langchain import HuggingFaceHub
13
+ import torch
14
+ import streamlit as st
15
+ from langchain.utilities import SerpAPIWrapper
16
+ from langchain.tools import DuckDuckGoSearchRun
17
+ import os
18
+ hf_token = os.environ['HF_TOKEN']
19
+ serp_token = os.environ['SERP_TOKEN']
20
+ repo_id = "sentence-transformers/all-mpnet-base-v2"
21
+
22
+
23
+
24
+
25
+ HUGGINGFACEHUB_API_TOKEN= hf_token
26
+ hf = HuggingFaceHubEmbeddings(
27
+ repo_id=repo_id,
28
+ task="feature-extraction",
29
+ huggingfacehub_api_token= HUGGINGFACEHUB_API_TOKEN,
30
+ )
31
+
32
+
33
+
34
+
35
+
36
+ llm = HuggingFaceHub(
37
+ repo_id='mistralai/Mistral-7B-Instruct-v0.2',
38
+ huggingfacehub_api_token = HUGGINGFACEHUB_API_TOKEN,
39
+
40
+
41
+ )
42
+
43
+
44
+
45
+ from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
46
+ from langchain.vectorstores import Chroma
47
+ from langchain.chains import RetrievalQA
48
+ from langchain import PromptTemplate
49
+
50
+ ### PAGE ELEMENTS
51
+
52
+ # st.set_page_config(
53
+ # page_title="RAG Agent Demo",
54
+ # page_icon="🦜",
55
+ # layout="centered",
56
+ # initial_sidebar_state="collapsed",
57
+ # )
58
+ # st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")
59
+
60
+
61
+ def main():
62
+
63
+ st.set_page_config(page_title="Ask your PDF powered by Search Agents")
64
+ st.header("Ask your PDF powered by Search Agents 💬")
65
+
66
+ # upload file
67
+ pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
68
+
69
+ # extract the text
70
+ if pdf is not None:
71
+ pdf_reader = PdfReader(pdf)
72
+ text = ""
73
+ for page in pdf_reader.pages:
74
+ text += page.extract_text()
75
+
76
+ # Split documents and create text snippets
77
+
78
+ text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
79
+ texts = text_splitter.split_text(text)
80
+
81
+ embeddings = hf
82
+ knowledge_base = FAISS.from_texts(texts, embeddings)
83
+
84
+ retriever = knowledge_base.as_retriever(search_kwargs={"k":3})
85
+
86
+
87
+
88
+
89
+ qa_chain = RetrievalQA.from_chain_type(
90
+ llm=llm,
91
+ chain_type="stuff",
92
+ retriever=retriever,
93
+ return_source_documents=False,
94
+ chain_type_kwargs={
95
+ "prompt": create_qa_prompt(),
96
+ },
97
+ )
98
+
99
+ conversational_memory = ConversationBufferMemory(
100
+ memory_key="chat_history", k=3, return_messages=True
101
+ )
102
+
103
+ # tool for db search
104
+ db_search_tool = Tool(
105
+ name="dbRetrievalTool",
106
+ func=qa_chain,
107
+ description="""Use this tool to answer document related questions. The input to this tool should be the question.""",
108
+ )
109
+
110
+ # search = SerpAPIWrapper(serpapi_api_key=serp_token)
111
+
112
+ # google_searchtool= Tool(
113
+ # name="Current Search",
114
+ # func=search.run,
115
+ # description="use this tool to answer real time or current search related questions.",
116
+ # )
117
+ search = DuckDuckGoSearchRun()
118
+ search_tool = Tool(
119
+ name="search",
120
+ func=search,
121
+ description="use this tool to answer real time or current search related questions."
122
+ )
123
+ # tool for asking human
124
+ human_ask_tool = CustomAskHumanTool()
125
+ # agent prompt
126
+ prefix, format_instructions, suffix = create_agent_prompt()
127
+ mode = "Agent with AskHuman tool"
128
+
129
+ # initialize agent
130
+ agent = initialize_agent(
131
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
132
+ tools=[db_search_tool,search_tool],
133
+ llm=llm,
134
+ verbose=True,
135
+ max_iterations=5,
136
+ early_stopping_method="generate",
137
+ memory=conversational_memory,
138
+ agent_kwargs={
139
+ "prefix": prefix,
140
+ "format_instructions": format_instructions,
141
+ "suffix": suffix,
142
+ },
143
+ handle_parsing_errors=True,
144
+
145
+ )
146
+
147
+ # question form
148
+ with st.form(key="form"):
149
+ user_input = st.text_input("Ask your question")
150
+ submit_clicked = st.form_submit_button("Submit Question")
151
+
152
+ # output container
153
+ output_container = st.empty()
154
+ if submit_clicked:
155
+ # st_callback = StreamlitCallbackHandler(st.container())
156
+ # response = agent.run(user_input,callbacks = [st_callback])
157
+ response = agent.run(user_input)
158
+ st.write(response)
159
+ # output_container = output_container.container()
160
+ # output_container.chat_message("user").write(user_input)
161
+ # with st.chat_message("assistant"):
162
+ # st_callback = StreamlitCallbackHandler(st.container())
163
+ # response = agent.run(user_input, callbacks=[st_callback])
164
+ # st.write(response)
165
+
166
+ # answer_container = output_container.chat_message("assistant", avatar="🦜")
167
+ # st_callback = StreamlitCallbackHandler(answer_container,)
168
+
169
+ # answer = agent.run(user_input, callbacks=[st_callback])
170
+
171
+ # answer_container = output_container.container()
172
+ # answer_container.chat_message("assistant").write(answer)
173
+
174
+
175
+
176
+ if __name__ == '__main__':
177
+ main()