dkdaniz committed on
Commit
d233be3
•
1 Parent(s): 20fbd22

Update app.py

Files changed (1)
app.py +164 -102
app.py CHANGED
@@ -1,49 +1,22 @@
- import torch
  import subprocess
- import streamlit as st
- from run_localGPT import load_model
- from langchain.vectorstores import Chroma
- from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME
- from langchain.embeddings import HuggingFaceInstructEmbeddings
- from langchain.chains import RetrievalQA
- from streamlit_extras.add_vertical_space import add_vertical_space
- from langchain.prompts import PromptTemplate
- from langchain.memory import ConversationBufferMemory
-
-
- def model_memory():
-     # Adding history to the model.
-     template = """Use the following pieces of context to answer the question at the end. If you don't know the answer,\
- just say that you don't know, don't try to make up an answer.
-
-     {context}
-
-     {history}
-     Question: {question}
-     Helpful Answer:"""
-
-     prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
-     memory = ConversationBufferMemory(input_key="question", memory_key="history")
-
-     return prompt, memory
-
-
- # Sidebar contents
- with st.sidebar:
-     st.title("🤗💬 Converse with your Data")
-     st.markdown(
-         """
-     ## About
-     This app is an LLM-powered chatbot built using:
-     - [Streamlit](https://streamlit.io/)
-     - [LangChain](https://python.langchain.com/)
-     - [LocalGPT](https://github.com/PromtEngineer/localGPT)
-
-     """
-     )
-     add_vertical_space(5)
-     st.write("Made with ❤️ by [Prompt Engineer](https://youtube.com/@engineerprompt)")
-

  if torch.backends.mps.is_available():
      DEVICE_TYPE = "mps"
@@ -52,71 +25,160 @@ elif torch.cuda.is_available():
  else:
      DEVICE_TYPE = "cpu"


- # if "result" not in st.session_state:
- #     # Run the document ingestion process.
- #     run_langest_commands = ["python", "ingest.py"]
  #     run_langest_commands.append("--device_type")
  #     run_langest_commands.append(DEVICE_TYPE)

- #     result = subprocess.run(run_langest_commands, capture_output=True)
- #     st.session_state.result = result

- # Define the retriever
  # load the vectorstore
- if "EMBEDDINGS" not in st.session_state:
-     EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
-     st.session_state.EMBEDDINGS = EMBEDDINGS
-
- if "DB" not in st.session_state:
-     DB = Chroma(
-         persist_directory=PERSIST_DIRECTORY,
-         embedding_function=st.session_state.EMBEDDINGS,
-         client_settings=CHROMA_SETTINGS,
-     )
-     st.session_state.DB = DB
-
- if "RETRIEVER" not in st.session_state:
-     RETRIEVER = DB.as_retriever()
-     st.session_state.RETRIEVER = RETRIEVER
-
- if "LLM" not in st.session_state:
-     LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
-     st.session_state["LLM"] = LLM
-
-
- if "QA" not in st.session_state:
-     prompt, memory = model_memory()
-
-     QA = RetrievalQA.from_chain_type(
-         llm=LLM,
-         chain_type="stuff",
-         retriever=RETRIEVER,
-         return_source_documents=True,
-         chain_type_kwargs={"prompt": prompt, "memory": memory},
      )
-     st.session_state["QA"] = QA
-
- st.title("LocalGPT App 💬")
- # Create a text input box for the user
- prompt = st.text_input("Input your prompt here")
- # while True:
-
- # If the user hits enter
- if prompt:
-     # Then pass the prompt to the LLM
-     response = st.session_state["QA"](prompt)
-     answer, docs = response["result"], response["source_documents"]
-     # ...and write it out to the screen
-     st.write(answer)
-
-     # With a streamlit expander
-     with st.expander("Document Similarity Search"):
-         # Find the relevant pages
-         search = st.session_state.DB.similarity_search_with_score(prompt)
-         # Write out the first
-         for i, doc in enumerate(search):
-             # print(doc)
-             st.write(f"Source Document # {i+1} : {doc[0].metadata['source'].split('/')[-1]}")
-             st.write(doc[0].page_content)
-             st.write("--------------------------------")

+ import logging
+ import os
+ import shutil
  import subprocess

+ import torch
+ from flask import Flask, jsonify, request
+ from langchain.chains import RetrievalQA
+ from langchain.embeddings import HuggingFaceInstructEmbeddings

+ # from langchain.embeddings import HuggingFaceEmbeddings
+ from run_localGPT import load_model
+ from prompt_template_utils import get_prompt_template

+ # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.vectorstores import Chroma
+ from werkzeug.utils import secure_filename

+ from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME

  if torch.backends.mps.is_available():
      DEVICE_TYPE = "mps"
  else:
      DEVICE_TYPE = "cpu"

+ SHOW_SOURCES = True
+ logging.info(f"Running on: {DEVICE_TYPE}")
+ logging.info(f"Display Source Documents set to: {SHOW_SOURCES}")
+
+ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
+
+ # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
+ # EMBEDDINGS = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
+ # if os.path.exists(PERSIST_DIRECTORY):
+ #     try:
+ #         shutil.rmtree(PERSIST_DIRECTORY)
+ #     except OSError as e:
+ #         print(f"Error: {e.filename} - {e.strerror}.")
+ # else:
+ #     print("The directory does not exist")

+ # run_langest_commands = ["python", "ingest.py"]
+ # if DEVICE_TYPE == "cpu":
  #     run_langest_commands.append("--device_type")
  #     run_langest_commands.append(DEVICE_TYPE)

+ # result = subprocess.run(run_langest_commands, capture_output=True)
+ # if result.returncode != 0:
+ #     raise FileNotFoundError(
+ #         "No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!"
+ #     )

  # load the vectorstore
+ DB = Chroma(
+     persist_directory=PERSIST_DIRECTORY,
+     embedding_function=EMBEDDINGS,
+     client_settings=CHROMA_SETTINGS,
+ )
+
+ RETRIEVER = DB.as_retriever()
+
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
+ prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
+
+ QA = RetrievalQA.from_chain_type(
+     llm=LLM,
+     chain_type="stuff",
+     retriever=RETRIEVER,
+     return_source_documents=SHOW_SOURCES,
+     chain_type_kwargs={
+         "prompt": prompt,
+     },
+ )
+
+ app = Flask(__name__)
+
+
+ @app.route("/api/delete_source", methods=["GET"])
+ def delete_source_route():
+     folder_name = "SOURCE_DOCUMENTS"
+
+     if os.path.exists(folder_name):
+         shutil.rmtree(folder_name)
+
+     os.makedirs(folder_name)
+
+     return jsonify({"message": f"Folder '{folder_name}' successfully deleted and recreated."})
+
+
+ @app.route("/api/save_document", methods=["GET", "POST"])
+ def save_document_route():
+     if "document" not in request.files:
+         return "No document part", 400
+     file = request.files["document"]
+     if file.filename == "":
+         return "No selected file", 400
+     if file:
+         filename = secure_filename(file.filename)
+         folder_path = "SOURCE_DOCUMENTS"
+         if not os.path.exists(folder_path):
+             os.makedirs(folder_path)
+         file_path = os.path.join(folder_path, filename)
+         file.save(file_path)
+         return "File saved successfully", 200
+
+
+ @app.route("/api/run_ingest", methods=["GET"])
+ def run_ingest_route():
+     global DB
+     global RETRIEVER
+     global QA
+     try:
+         if os.path.exists(PERSIST_DIRECTORY):
+             try:
+                 shutil.rmtree(PERSIST_DIRECTORY)
+             except OSError as e:
+                 print(f"Error: {e.filename} - {e.strerror}.")
+         else:
+             print("The directory does not exist")
+
+         run_langest_commands = ["python", "ingest.py"]
+         if DEVICE_TYPE == "cpu":
+             run_langest_commands.append("--device_type")
+             run_langest_commands.append(DEVICE_TYPE)
+
+         result = subprocess.run(run_langest_commands, capture_output=True)
+         if result.returncode != 0:
+             return "Script execution failed: {}".format(result.stderr.decode("utf-8")), 500
+         # load the vectorstore
+         DB = Chroma(
+             persist_directory=PERSIST_DIRECTORY,
+             embedding_function=EMBEDDINGS,
+             client_settings=CHROMA_SETTINGS,
+         )
+         RETRIEVER = DB.as_retriever()
+         prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
+
+         QA = RetrievalQA.from_chain_type(
+             llm=LLM,
+             chain_type="stuff",
+             retriever=RETRIEVER,
+             return_source_documents=SHOW_SOURCES,
+             chain_type_kwargs={
+                 "prompt": prompt,
+             },
+         )
+         return "Script executed successfully: {}".format(result.stdout.decode("utf-8")), 200
+     except Exception as e:
+         return f"Error occurred: {str(e)}", 500
+
+
+ @app.route("/api/prompt_route", methods=["GET", "POST"])
+ def prompt_route():
+     global QA
+     user_prompt = request.form.get("user_prompt")
+     if user_prompt:
+         # print(f'User Prompt: {user_prompt}')
+         # Get the answer from the chain
+         res = QA(user_prompt)
+         answer, docs = res["result"], res["source_documents"]
+
+         prompt_response_dict = {
+             "Prompt": user_prompt,
+             "Answer": answer,
+         }
+
+         prompt_response_dict["Sources"] = []
+         for document in docs:
+             prompt_response_dict["Sources"].append(
+                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
+             )
+
+         return jsonify(prompt_response_dict), 200
+     else:
+         return "No user prompt received", 400
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
      )
+     app.run(debug=False, port=5110)
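
For reference, a minimal client-side sketch (not part of this commit) that exercises the new Flask API. The routes, form-field names, and port 5110 are taken from the code above; the base URL and the sample file name example.pdf are assumptions.

import requests

BASE_URL = "http://localhost:5110"  # assumption: the server was started locally via `python app.py`

# Upload a file into SOURCE_DOCUMENTS (multipart field "document", per save_document_route).
with open("example.pdf", "rb") as f:  # hypothetical sample file
    resp = requests.post(f"{BASE_URL}/api/save_document", files={"document": f})
    print(resp.status_code, resp.text)

# Re-run ingestion so the vectorstore picks up the new file.
resp = requests.get(f"{BASE_URL}/api/run_ingest")
print(resp.status_code, resp.text)

# Ask a question (form field "user_prompt", per prompt_route); the JSON reply carries Answer and Sources.
resp = requests.post(f"{BASE_URL}/api/prompt_route", data={"user_prompt": "Summarize the uploaded documents."})
print(resp.json()["Answer"])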