hotchpotch committed on
Commit 2569947 · 1 Parent(s): d8b8e7e
Files changed (1)
  1. app.py +304 -4
app.py CHANGED
@@ -1,7 +1,307 @@
+ """
+ streamlit run app.py --server.address 0.0.0.0
+ """
+
+ from __future__ import annotations
+
  import streamlit as st
  import os

- # Format and display all environment variables
- all_env_dict = os.environ
- for key in all_env_dict.keys():
-     st.write(f"{key}: {all_env_dict[key]}")
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ import torch
+ from openai import OpenAI
+ import pandas as pd
+ from time import time
+ from datasets.download import DownloadManager
+ from datasets import load_dataset  # type: ignore
+
+
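+ # Hub locations: the chunked Japanese Wikipedia passages dataset, its config name,
+ # and the repo holding the prebuilt embeddings / faiss indexes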
+ WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
+ WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
+ WIKIPEDIA_JA_EMB_DS = "hotchpotch/wikipedia-passages-jawiki-embeddings"
+
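+ # Product-quantization (PQ) size baked into each model's prebuilt index filename
+ # (index_IVF2048_PQ{n}.faiss, assembled in app() below)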
+ EMB_MODEL_PQ = {
+     "intfloat/multilingual-e5-small": 96,
+     "intfloat/multilingual-e5-base": 192,
+     "intfloat/multilingual-e5-large": 256,
+     "cl-nagoya/sup-simcse-ja-base": 192,
+     "pkshatech/GLuCoSE-base-ja": 192,
+ }
+
+ EMB_MODEL_NAMES = list(EMB_MODEL_PQ.keys())
+
+ OPENAI_MODEL_NAMES = [
+     "gpt-3.5-turbo-1106",
+     "gpt-4-1106-preview",
+ ]
+
+ E5_QUERY_TYPES = [
+     "passage",
+     "query",
+ ]
+
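+ # Default RAG prompt; qa() fills {contexts} with the retrieved passages and
+ # {question} with the user's query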
+ DEFAULT_QA_PROMPT = """
+ ## Instruction
+
+ Prepare an explanatory statement for the question, including as much detailed explanation as possible.
+ Avoid speculation or information not contained in the contexts. Heavily favor knowledge provided in the contexts before falling back to baseline knowledge or other sources. If searching the contexts didn't yield any answer, just say that.
+
+ Responses must be given in Japanese.
+
+ ## Contexts
+
+ {contexts}
+
+ ## Question
+
+ {question}
+ """.strip()
+
+
+ # Running on a Hugging Face Space: keep the HF cache on the persistent /data volume
+ if os.getenv("SPACE_ID"):
+     USE_HF_SPACE = True
+     os.environ["HF_HOME"] = "/data/.huggingface"
+ else:
+     USE_HF_SPACE = False
+
+ # Silence the tokenizers fork/parallelism warning
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
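+ # When the env var is set, app() skips the key input field and uses this key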
+
+
+ # cache_resource (not cache_data): the model is cached as a shared resource
+ # rather than serialized on every rerun
+ @st.cache_resource
+ def get_model(name: str, max_seq_length=512):
+     # Pick the best available device: CUDA, then Apple MPS, then CPU
+     device = "cpu"
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     model = SentenceTransformer(name, device=device)
+     model.max_seq_length = max_seq_length
+     return model
+
+
+ @st.cache_resource
+ def get_wikija_ds(name: str = WIKIPEDIA_JS_DS_NAME):
+     # Japanese Wikipedia passages, pre-chunked (the "c400" config)
+     ds = load_dataset(path=WIKIPEDIA_JA_DS, name=name, split="train")
+     return ds
+
+
+ # cache_resource: faiss index objects cannot be serialized by st.cache_data
+ @st.cache_resource
+ def get_faiss_index(
+     index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME
+ ):
+     # Download the prebuilt faiss index from the HF Hub (cached locally)
+     target_path = f"faiss_indexes/{name}/{index_name}"
+     dm = DownloadManager()
+     index_local_path = dm.download(
+         f"https://huggingface.co/datasets/{ja_emb_ds}/resolve/main/{target_path}"
+     )
+     index = faiss.read_index(index_local_path)
+     # Default number of IVF clusters to probe; overridden from the UI slider in app()
+     index.nprobe = 128
+     return index
+
+
+ def text_to_emb(model, text: str, prefix: str):
+     # e5 models expect a "query: " / "passage: " prefix; normalize embeddings
+     # so inner-product search behaves like cosine similarity
+     return model.encode([prefix + text], normalize_embeddings=True)
+
+
+ def search(
+     faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int
+ ):
+     start_time = time()
+     emb = text_to_emb(emb_model, question, search_text_prefix)
+     emb_exec_time = time() - start_time
+     scores, indexes = faiss_index.search(emb, top_k)
+     faiss_search_time = time() - emb_exec_time - start_time
+     scores = scores[0]
+     indexes = indexes[0]
+     results = []
+     for idx, score in zip(indexes, scores):  # type: ignore
+         idx = int(idx)
+         passage = ds[idx]
+         results.append((score, passage))
+     return results, emb_exec_time, faiss_search_time
+
+
+ def to_contexts(passages):
+     # Concatenate the retrieved passages into a bullet list for the QA prompt
+     contexts = ""
+     for passage in passages:
+         title = passage["title"]
+         text = passage["text"]
+         # section = passage["section"]
+         contexts += f"- {title}: {text}\n"
+     return contexts
+
+
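+ # Ask an OpenAI chat model to answer from the retrieved contexts, streaming the response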
+ def qa(
+     question: str,
+     passages: list,
+     model_name: str,
+     temperature: float,
+     qa_prompt: str,
+     max_tokens=2000,
+ ):
+     # Use the key from session state, so a key entered in the UI also works
+     client = OpenAI(api_key=st.session_state.openai_api_key)
+     contexts = to_contexts(passages)
+     prompt = qa_prompt.format(contexts=contexts, question=question)
+     response = client.chat.completions.create(
+         model=model_name,
+         messages=[
+             {"role": "user", "content": prompt},
+         ],
+         stream=True,
+         temperature=temperature,
+         max_tokens=max_tokens,
+         seed=42,
+     )
+     # Yield text deltas as they arrive from the stream
+     for chunk in response:
+         delta = chunk.choices[0].delta
+         yield delta.content or ""
+
+
+ def generate_answer(
+     buf, question, passages, model_name, temperature, qa_prompt, max_tokens
+ ):
+     buf.write("⏳回答の生成中...")
+     texts = ""
+     for chunk in qa(
+         question=question,
+         passages=passages,
+         model_name=model_name,
+         temperature=temperature,
+         qa_prompt=qa_prompt,
+         max_tokens=max_tokens,
+     ):
+         texts += chunk
+         # Rewrite the placeholder with everything accumulated so far
+         buf.write(texts)
+
+
+ def to_df(scores, passages):
+     df = pd.DataFrame(passages)
+     df["score"] = scores
+     # Reorder the columns for display
+     columns = ["score", "title", "text", "section"]
+     return df[columns]
+
+
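+ # Streamlit page: question form, options, retrieval results table, streamed answer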
+ def app():
+     st.title("Wikipedia 日本語 RAG 検索")
+     st.subheader("⭐️大元へのリンクを貼る")
+
+     st.text_area(
+         "Question",
+         key="question",
+         value="1975年に『アザミ嬢のララバイ』でデビューした女性歌手で、『わかれうた』『地上の星』などの曲を出しているのは誰?",
+     )
+     # Without an env key, ask for one in the UI; search-only mode if left empty
+     if not OPENAI_API_KEY:
+         st.text_input(
+             "OpenAI API Key",
+             key="openai_api_key",
+             type="password",
+             placeholder="※ API_KEY 未入力時は、回答生成せずに検索のみ",
+         )
+     else:
+         st.session_state.openai_api_key = OPENAI_API_KEY
+
+     with st.expander("オプション"):
+         option_cols_main = st.columns(2)
+         with option_cols_main[0]:
+             st.selectbox("Emb Model", EMB_MODEL_NAMES, index=0, key="emb_model_name")
+         with option_cols_main[1]:
+             st.selectbox(
+                 "OpenAI Model", OPENAI_MODEL_NAMES, index=0, key="openai_model_name"
+             )
+         emb_model_name = st.session_state.emb_model_name
+         option_cols_sub = st.columns(2)
+         with option_cols_sub[0]:
+             st.number_input("Top K", value=5, key="top_k", min_value=1, max_value=20)
+         with option_cols_sub[1]:
+             if "-e5-" in emb_model_name:
+                 st.radio(
+                     "Passage or Query (only e5)",
+                     E5_QUERY_TYPES,
+                     index=0,
+                     key="e5_query_or_passage",
+                     horizontal=True,
+                 )
+                 e5_query_or_passage = st.session_state.e5_query_or_passage
+                 # e5 indexes are built per prefix type; pick the matching index
+                 # name and the matching query prefix
+                 index_emb_model_name = (
+                     f"{emb_model_name.split('/')[-1]}-{e5_query_or_passage}"
+                 )
+                 search_text_prefix = f"{e5_query_or_passage}: "
+             else:
+                 index_emb_model_name = emb_model_name.split("/")[-1]
+                 search_text_prefix = ""
+         option_cols = st.columns(3)
+         with option_cols[0]:
+             st.slider("Temperature", 0.0, 1.0, value=0.8, key="temperature")
+         with option_cols[1]:
+             st.slider("nprobe", 16, 1024, value=128, key="nprobe")
+         with option_cols[2]:
+             st.number_input(
+                 "max_tokens", value=2000, key="max_tokens", min_value=1, max_value=16000
+             )
+         st.text_area("QA Prompt", value=DEFAULT_QA_PROMPT, key="qa_prompt")
+
+     loading_placeholder = st.empty()
+     loading_placeholder.text("⏳ Loading - Embedding Model...")
+     emb_model = get_model(st.session_state.emb_model_name)
+     loading_placeholder.text("⏳ Loading - Faiss Index...")
+     emb_model_pq = EMB_MODEL_PQ[emb_model_name]
+     # e.g. "multilingual-e5-small-query/index_IVF2048_PQ96.faiss"
+     index_name = f"{index_emb_model_name}/index_IVF2048_PQ{emb_model_pq}.faiss"
+     faiss_index = get_faiss_index(index_name=index_name)
+     faiss_index.nprobe = st.session_state.nprobe
+     loading_placeholder.text("⏳ Loading - Huggingface Dataset...")
+     ds = get_wikija_ds()
+     loading_placeholder.empty()
+
+     if st.button("Search"):
+         # Placeholders so the streamed answer appears above the results table
+         answer_header = st.empty()
+         answer_text_buffer = st.empty()
+
+         question = st.session_state.question
+         top_k = st.session_state.top_k
+         scores = []
+         passages = []
+         search_results, emb_exec_time, faiss_search_time = search(
+             faiss_index,
+             emb_model,
+             ds,
+             question,
+             search_text_prefix=search_text_prefix,
+             top_k=top_k,
+         )
+         st.subheader("Search Results:")
+         st.write(
+             f"⏱️ generate embedding: {emb_exec_time*1000:.2f}ms / faiss search: {faiss_search_time*1000:.2f}ms"
+         )
+         for score, passage in search_results:
+             scores.append(score)
+             passages.append(passage)
+         df = to_df(scores, passages)
+         st.dataframe(df, hide_index=True)
+
+         openai_api_key = st.session_state.openai_api_key
+         if openai_api_key:
+             answer_header.subheader("Answer:")
+             openai_model_name = st.session_state.openai_model_name
+             temperature = st.session_state.temperature
+             qa_prompt = st.session_state.qa_prompt
+             max_tokens = st.session_state.max_tokens
+             generate_answer(
+                 buf=answer_text_buffer,
+                 question=question,
+                 passages=passages,
+                 model_name=openai_model_name,
+                 temperature=temperature,
+                 qa_prompt=qa_prompt,
+                 max_tokens=max_tokens,
+             )
+
+
+ if __name__ == "__main__":
+     app()