Carlosito16 commited on
Commit
8f6cc9b
Β·
1 Parent(s): 55dc1fb

Create conversational-app-with.py

Browse files

Change the UI form from one-query-based into chat-like flow

Files changed (1) hide show
  1. conversational-app-with.py +223 -0
conversational-app-with.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This version is the same model with only different UI, to be a chat-like experience
2
+
3
+ import streamlit as st
4
+ from streamlit_chat import message as st_message
5
+ import pandas as pd
6
+ import numpy as np
7
+ import datetime
8
+ import gspread
9
+ import pickle
10
+ import os
11
+ import csv
12
+ import json
13
+ import torch
14
+ from tqdm.auto import tqdm
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+
17
+
18
+ # from langchain.vectorstores import Chroma
19
+ from langchain.vectorstores import FAISS
20
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
21
+
22
+
23
+ from langchain import HuggingFacePipeline
24
+ from langchain.chains import RetrievalQA
25
+
26
+
27
+
28
+
29
+ st.set_page_config(
30
+ page_title = 'aitGPT',
31
+ page_icon = 'βœ…')
32
+
33
+
34
+
35
+
36
+ @st.cache_data
37
+ def load_scraped_web_info():
38
+ with open("ait-web-document", "rb") as fp:
39
+ ait_web_documents = pickle.load(fp)
40
+
41
+
42
+ text_splitter = RecursiveCharacterTextSplitter(
43
+ # Set a really small chunk size, just to show.
44
+ chunk_size = 500,
45
+ chunk_overlap = 100,
46
+ length_function = len,
47
+ )
48
+
49
+ chunked_text = text_splitter.create_documents([doc for doc in tqdm(ait_web_documents)])
50
+
51
+
52
+ @st.cache_resource
53
+ def load_embedding_model():
54
+ embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
55
+ model_kwargs = {'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')})
56
+ return embedding_model
57
+
58
+ @st.cache_data
59
+ def load_faiss_index():
60
+ vector_database = FAISS.load_local("faiss_index_web_and_curri_new", embedding_model) #CHANGE THIS FAISS EMBEDDED KNOWLEDGE
61
+ return vector_database
62
+
63
+ @st.cache_resource
64
+ def load_llm_model():
65
+ # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
66
+ # task= 'text2text-generation',
67
+ # model_kwargs={ "device_map": "auto",
68
+ # "load_in_8bit": True,"max_length": 256, "temperature": 0,
69
+ # "repetition_penalty": 1.5})
70
+
71
+
72
+ llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
73
+ task= 'text2text-generation',
74
+
75
+ model_kwargs={ "max_length": 256, "temperature": 0,
76
+ "torch_dtype":torch.float32,
77
+ "repetition_penalty": 1.3})
78
+ return llm
79
+
80
+
81
+ def load_retriever(llm, db):
82
+ qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
83
+ retriever=db.as_retriever())
84
+
85
+ return qa_retriever
86
+
87
+ def retrieve_document(query_input):
88
+ related_doc = vector_database.similarity_search(query_input)
89
+ return related_doc
90
+
91
+ def retrieve_answer(query_input):
92
+ prompt_answer= query_input + " " + "Try to elaborate as much as you can."
93
+ answer = qa_retriever.run(prompt_answer)
94
+ output = st.text_area(label="Retrieved documents", value=answer[6:]) #this positional slicing helps remove "<pad> " at the beginning
95
+
96
+ st.markdown('---')
97
+ # score = st.radio(label = 'please select the rating score for overall satifaction and helpfullness of the bot answer', options=[0, 1,2,3,4,5], horizontal=True,
98
+ # on_change=update_worksheet_qa, key='rating')
99
+
100
+ return answer[6:] #this positional slicing helps remove "<pad> " at the beginning
101
+
102
+ # def update_score():
103
+ # st.session_state.session_rating = st.session_state.rating
104
+
105
+
106
+ def update_worksheet_qa():
107
+ # st.session_state.session_rating = st.session_state.rating
108
+ #This if helps validate the initiated rating, if 0, then the google sheet would not be updated
109
+ #(edited) now even with the score of 0, we still want to store the log because some users do not give the score to complete the logging
110
+ # if st.session_state.session_rating == 0:
111
+ worksheet_qa.append_row([st.session_state.history[-1]['timestamp'].strftime(datetime_format),
112
+ st.session_state.history[-1]['question'],
113
+ st.session_state.history[-1]['generated_answer'],
114
+ 0])
115
+ # else:
116
+ # worksheet_qa.append_row([st.session_state.history[-1]['timestamp'].strftime(datetime_format),
117
+ # st.session_state.history[-1]['question'],
118
+ # st.session_state.history[-1]['generated_answer'],
119
+ # st.session_state.session_rating
120
+ # ])
121
+
122
+ def update_worksheet_comment():
123
+ worksheet_comment.append_row([datetime.datetime.now().strftime(datetime_format),
124
+ feedback_input])
125
+ success_message = st.success('Feedback successfully submitted, thank you', icon="βœ…",
126
+ )
127
+ time.sleep(3)
128
+ success_message.empty()
129
+
130
+
131
+ def clean_chat_history():
132
+ st.session_state.chat_history = []
133
+
134
+ #--------------
135
+
136
+
137
+ if "history" not in st.session_state: #this one is for the google sheet logging
138
+ st.session_state.history = []
139
+
140
+
141
+ if "chat_history" not in st.session_state: #this one is to pass previous messages into chat flow
142
+ st.session_state.chat_history = []
143
+ # if "session_rating" not in st.session_state:
144
+ # st.session_state.session_rating = 0
145
+
146
+
147
+ credentials= json.loads(st.secrets['google_sheet_credential'])
148
+
149
+ service_account = gspread.service_account_from_dict(credentials)
150
+ workbook= service_account.open("aitGPT-qa-log")
151
+ worksheet_qa = workbook.worksheet("Sheet1")
152
+ worksheet_comment = workbook.worksheet("Sheet2")
153
+ datetime_format= "%Y-%m-%d %H:%M:%S"
154
+
155
+
156
+
157
+ load_scraped_web_info()
158
+ embedding_model = load_embedding_model()
159
+ vector_database = load_faiss_index()
160
+ llm_model = load_llm_model()
161
+ qa_retriever = load_retriever(llm= llm_model, db= vector_database)
162
+
163
+
164
+ print("all load done")
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+ st.write("# aitGPT πŸ€– ")
174
+ st.markdown("""
175
+ #### The aitGPT project is a virtual assistant developed by the :green[Asian Institute of Technology] that contains a vast amount of information gathered from 205 AIT-related websites.
176
+ The goal of this chatbot is to provide an alternative way for applicants and current students to access information about the institute, including admission procedures, campus facilities, and more.
177
+ """)
178
+ st.write(' ⚠️ Please expect to wait **~ 10 - 20 seconds per question** as thi app is running on CPU against 3-billion-parameter LLM')
179
+
180
+ st.markdown("---")
181
+ st.write(" ")
182
+ st.write("""
183
+ ### ❔ Ask a question
184
+ """)
185
+
186
+
187
+ for chat in st.session_state.chat_history:
188
+ st_message(**chat)
189
+
190
+ query_input = st.text_area(label= 'What would you like to know about AIT?' , key = 'my_text_input')
191
+ generate_button = st.button(label = 'Ask question!')
192
+
193
+ if generate_button:
194
+ answer = retrieve_answer(query_input)
195
+ log = {"timestamp": datetime.datetime.now(),
196
+ "question":query_input,
197
+ "generated_answer": answer,
198
+ "rating":0 }
199
+
200
+ st.session_state.history.append(log)
201
+ update_worksheet_qa()
202
+ st.session_state.chat_history.append({"message": query_input, "is_user": True})
203
+ st.session_state.chat_history.append({"message": answer, "is_user": False})
204
+
205
+
206
+ clear_button = st.button("Start new convo",
207
+ on_click=clean_chat_history)
208
+
209
+
210
+ st.write(" ")
211
+ st.write(" ")
212
+
213
+ st.markdown("---")
214
+ st.write("""
215
+ ### πŸ’Œ Your voice matters
216
+ """)
217
+
218
+ feedback_input = st.text_area(label= 'please leave your feedback or any ideas to make this bot more knowledgeable and fun')
219
+ feedback_button = st.button(label = 'Submit feedback!')
220
+
221
+ if feedback_button:
222
+ update_worksheet_comment()
223
+