Spaces:
Paused
Paused
Carlosito16
commited on
Commit
Β·
8f6cc9b
1
Parent(s):
55dc1fb
Create conversational-app-with.py
Browse filesChange the UI form from one-query-based into chat-like flow
- conversational-app-with.py +223 -0
conversational-app-with.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This version is the same model with only different UI, to be a chat-like experience
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from streamlit_chat import message as st_message
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
import datetime
|
8 |
+
import gspread
|
9 |
+
import pickle
|
10 |
+
import os
|
11 |
+
import csv
|
12 |
+
import json
|
13 |
+
import torch
|
14 |
+
from tqdm.auto import tqdm
|
15 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
16 |
+
|
17 |
+
|
18 |
+
# from langchain.vectorstores import Chroma
|
19 |
+
from langchain.vectorstores import FAISS
|
20 |
+
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
21 |
+
|
22 |
+
|
23 |
+
from langchain import HuggingFacePipeline
|
24 |
+
from langchain.chains import RetrievalQA
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
st.set_page_config(
|
30 |
+
page_title = 'aitGPT',
|
31 |
+
page_icon = 'β
')
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache_data
|
37 |
+
def load_scraped_web_info():
|
38 |
+
with open("ait-web-document", "rb") as fp:
|
39 |
+
ait_web_documents = pickle.load(fp)
|
40 |
+
|
41 |
+
|
42 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
43 |
+
# Set a really small chunk size, just to show.
|
44 |
+
chunk_size = 500,
|
45 |
+
chunk_overlap = 100,
|
46 |
+
length_function = len,
|
47 |
+
)
|
48 |
+
|
49 |
+
chunked_text = text_splitter.create_documents([doc for doc in tqdm(ait_web_documents)])
|
50 |
+
|
51 |
+
|
52 |
+
@st.cache_resource
|
53 |
+
def load_embedding_model():
|
54 |
+
embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
|
55 |
+
model_kwargs = {'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')})
|
56 |
+
return embedding_model
|
57 |
+
|
58 |
+
@st.cache_data
|
59 |
+
def load_faiss_index():
|
60 |
+
vector_database = FAISS.load_local("faiss_index_web_and_curri_new", embedding_model) #CHANGE THIS FAISS EMBEDDED KNOWLEDGE
|
61 |
+
return vector_database
|
62 |
+
|
63 |
+
@st.cache_resource
|
64 |
+
def load_llm_model():
|
65 |
+
# llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
|
66 |
+
# task= 'text2text-generation',
|
67 |
+
# model_kwargs={ "device_map": "auto",
|
68 |
+
# "load_in_8bit": True,"max_length": 256, "temperature": 0,
|
69 |
+
# "repetition_penalty": 1.5})
|
70 |
+
|
71 |
+
|
72 |
+
llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
|
73 |
+
task= 'text2text-generation',
|
74 |
+
|
75 |
+
model_kwargs={ "max_length": 256, "temperature": 0,
|
76 |
+
"torch_dtype":torch.float32,
|
77 |
+
"repetition_penalty": 1.3})
|
78 |
+
return llm
|
79 |
+
|
80 |
+
|
81 |
+
def load_retriever(llm, db):
|
82 |
+
qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
|
83 |
+
retriever=db.as_retriever())
|
84 |
+
|
85 |
+
return qa_retriever
|
86 |
+
|
87 |
+
def retrieve_document(query_input):
|
88 |
+
related_doc = vector_database.similarity_search(query_input)
|
89 |
+
return related_doc
|
90 |
+
|
91 |
+
def retrieve_answer(query_input):
|
92 |
+
prompt_answer= query_input + " " + "Try to elaborate as much as you can."
|
93 |
+
answer = qa_retriever.run(prompt_answer)
|
94 |
+
output = st.text_area(label="Retrieved documents", value=answer[6:]) #this positional slicing helps remove "<pad> " at the beginning
|
95 |
+
|
96 |
+
st.markdown('---')
|
97 |
+
# score = st.radio(label = 'please select the rating score for overall satifaction and helpfullness of the bot answer', options=[0, 1,2,3,4,5], horizontal=True,
|
98 |
+
# on_change=update_worksheet_qa, key='rating')
|
99 |
+
|
100 |
+
return answer[6:] #this positional slicing helps remove "<pad> " at the beginning
|
101 |
+
|
102 |
+
# def update_score():
|
103 |
+
# st.session_state.session_rating = st.session_state.rating
|
104 |
+
|
105 |
+
|
106 |
+
def update_worksheet_qa():
|
107 |
+
# st.session_state.session_rating = st.session_state.rating
|
108 |
+
#This if helps validate the initiated rating, if 0, then the google sheet would not be updated
|
109 |
+
#(edited) now even with the score of 0, we still want to store the log because some users do not give the score to complete the logging
|
110 |
+
# if st.session_state.session_rating == 0:
|
111 |
+
worksheet_qa.append_row([st.session_state.history[-1]['timestamp'].strftime(datetime_format),
|
112 |
+
st.session_state.history[-1]['question'],
|
113 |
+
st.session_state.history[-1]['generated_answer'],
|
114 |
+
0])
|
115 |
+
# else:
|
116 |
+
# worksheet_qa.append_row([st.session_state.history[-1]['timestamp'].strftime(datetime_format),
|
117 |
+
# st.session_state.history[-1]['question'],
|
118 |
+
# st.session_state.history[-1]['generated_answer'],
|
119 |
+
# st.session_state.session_rating
|
120 |
+
# ])
|
121 |
+
|
122 |
+
def update_worksheet_comment():
|
123 |
+
worksheet_comment.append_row([datetime.datetime.now().strftime(datetime_format),
|
124 |
+
feedback_input])
|
125 |
+
success_message = st.success('Feedback successfully submitted, thank you', icon="β
",
|
126 |
+
)
|
127 |
+
time.sleep(3)
|
128 |
+
success_message.empty()
|
129 |
+
|
130 |
+
|
131 |
+
def clean_chat_history():
|
132 |
+
st.session_state.chat_history = []
|
133 |
+
|
134 |
+
#--------------
|
135 |
+
|
136 |
+
|
137 |
+
if "history" not in st.session_state: #this one is for the google sheet logging
|
138 |
+
st.session_state.history = []
|
139 |
+
|
140 |
+
|
141 |
+
if "chat_history" not in st.session_state: #this one is to pass previous messages into chat flow
|
142 |
+
st.session_state.chat_history = []
|
143 |
+
# if "session_rating" not in st.session_state:
|
144 |
+
# st.session_state.session_rating = 0
|
145 |
+
|
146 |
+
|
147 |
+
credentials= json.loads(st.secrets['google_sheet_credential'])
|
148 |
+
|
149 |
+
service_account = gspread.service_account_from_dict(credentials)
|
150 |
+
workbook= service_account.open("aitGPT-qa-log")
|
151 |
+
worksheet_qa = workbook.worksheet("Sheet1")
|
152 |
+
worksheet_comment = workbook.worksheet("Sheet2")
|
153 |
+
datetime_format= "%Y-%m-%d %H:%M:%S"
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
load_scraped_web_info()
|
158 |
+
embedding_model = load_embedding_model()
|
159 |
+
vector_database = load_faiss_index()
|
160 |
+
llm_model = load_llm_model()
|
161 |
+
qa_retriever = load_retriever(llm= llm_model, db= vector_database)
|
162 |
+
|
163 |
+
|
164 |
+
print("all load done")
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
st.write("# aitGPT π€ ")
|
174 |
+
st.markdown("""
|
175 |
+
#### The aitGPT project is a virtual assistant developed by the :green[Asian Institute of Technology] that contains a vast amount of information gathered from 205 AIT-related websites.
|
176 |
+
The goal of this chatbot is to provide an alternative way for applicants and current students to access information about the institute, including admission procedures, campus facilities, and more.
|
177 |
+
""")
|
178 |
+
st.write(' β οΈ Please expect to wait **~ 10 - 20 seconds per question** as thi app is running on CPU against 3-billion-parameter LLM')
|
179 |
+
|
180 |
+
st.markdown("---")
|
181 |
+
st.write(" ")
|
182 |
+
st.write("""
|
183 |
+
### β Ask a question
|
184 |
+
""")
|
185 |
+
|
186 |
+
|
187 |
+
for chat in st.session_state.chat_history:
|
188 |
+
st_message(**chat)
|
189 |
+
|
190 |
+
query_input = st.text_area(label= 'What would you like to know about AIT?' , key = 'my_text_input')
|
191 |
+
generate_button = st.button(label = 'Ask question!')
|
192 |
+
|
193 |
+
if generate_button:
|
194 |
+
answer = retrieve_answer(query_input)
|
195 |
+
log = {"timestamp": datetime.datetime.now(),
|
196 |
+
"question":query_input,
|
197 |
+
"generated_answer": answer,
|
198 |
+
"rating":0 }
|
199 |
+
|
200 |
+
st.session_state.history.append(log)
|
201 |
+
update_worksheet_qa()
|
202 |
+
st.session_state.chat_history.append({"message": query_input, "is_user": True})
|
203 |
+
st.session_state.chat_history.append({"message": answer, "is_user": False})
|
204 |
+
|
205 |
+
|
206 |
+
clear_button = st.button("Start new convo",
|
207 |
+
on_click=clean_chat_history)
|
208 |
+
|
209 |
+
|
210 |
+
st.write(" ")
|
211 |
+
st.write(" ")
|
212 |
+
|
213 |
+
st.markdown("---")
|
214 |
+
st.write("""
|
215 |
+
### π Your voice matters
|
216 |
+
""")
|
217 |
+
|
218 |
+
feedback_input = st.text_area(label= 'please leave your feedback or any ideas to make this bot more knowledgeable and fun')
|
219 |
+
feedback_button = st.button(label = 'Submit feedback!')
|
220 |
+
|
221 |
+
if feedback_button:
|
222 |
+
update_worksheet_comment()
|
223 |
+
|