Commit 14df537 · Parent(s): f98bd2b
demo_dpr update

demo_dpr.py  CHANGED  (+30 -6)
@@ -20,11 +20,13 @@ DATAFRAME_FILE_BSBS = 'basecamp.csv'
 selectbox_selections = {
     'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
     'Dense Passage Retrieval':2,
+    # 'Base Dense Passage Retrieval': 3,
     'Retrieve - Rerank':4
 }
 imagebox_selections = {
     'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
     'Dense Passage Retrieval': 'DPR_pipeline.png',
+    'Base Dense Passage Retrieval': 'base-dpr.png',
     'Retrieve - Rerank': 'retrieve-rerank.png'
 }
 
@@ -71,7 +73,7 @@ def load_paragraphs(path):
 def load_dataframes():
     # data_original = pd.read_csv(DATAFRAME_FILE_ORIGINAL, index_col=0, sep='|')
     data_bsbs = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
-    data_bsbs.drop('context_id', axis=1, inplace=True)
+    data_bsbs.drop(['context_id', 'answer', 'answer_start', 'answer_end'], axis=1, inplace=True)
     # data_original = data_original.sample(frac=1).reset_index(drop=True)
     data_bsbs = data_bsbs.sample(frac=1).reset_index(drop=True)
     return data_bsbs
@@ -82,11 +84,31 @@ def dot_product(question_output, context_output):
     result = torch.dot(mat1, mat2)
     return result
 
+def base_dpr_pipeline(question):
+    tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt")
+    question_embedding = base_dpr_context_encoder(**tokenized_question)
+    question_embedding = mean_pooling(question_embedding[0], tokenized_question['attention_mask'])
+    # question_embedding = question_embedding['pooler_output']
+    results_list = []
+    for i, context_embedding in enumerate(base_dpr_context_embeddings):
+        score = dot_product(question_embedding, context_embedding)
+        results_list.append(score.detach().cpu())
+
+    hits = sorted(range(len(results_list)), key=lambda i: results_list[i], reverse=True)
+    top_5_contexes = []
+    top_5_scores = []
+    for j in hits[0:5]:
+        top_5_contexes.append(base_contexes[j])
+        top_5_scores.append(results_list[j])
+    return top_5_contexes, top_5_scores
+
 def search_pipeline(question, search_method):
     if search_method == 1: #Retrieve - rerank with fine-tuned cross encoder
         return retrieve_rerank_with_trained_cross_encoder(question)
     if search_method == 2:
         return custom_dpr_pipeline(question) # DPR only
+    # if search_method == 3:
+    #     return base_dpr_pipeline(question) # DPR only
     if search_method == 4:
         return retrieve_rerank(question)
 
@@ -213,8 +235,8 @@ def qa_main_widgetsv2():
     st.write(selection['context'])
     st.markdown("### Question:")
     st.write(selection['question'])
-    st.markdown("### Answer:")
-    st.write(selection['answer'])
+    # st.markdown("### Answer:")
+    # st.write(selection['answer'])
     st.session_state.grid_click_2 = False
 
 @st.cache(show_spinner=False, allow_output_mutation = True)
@@ -226,15 +248,17 @@ def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
     cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
     bi_encoder.max_seq_length = 500
     trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
+    base_dpr_context_encoder = AutoModel.from_pretrained('facebook/contriever-msmarco')
     question_tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
-    return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
+    return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer, base_dpr_context_encoder
 
 context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
 dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-contriever-embeddings.pkl')
+base_dpr_context_embeddings, base_contexes = load_paragraphs('basecamp-base-dpr-contriever-embeddings.pkl')
 dataframe_bsbs = load_dataframes()
-dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
+dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer, base_dpr_context_encoder = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
 qa_main_widgetsv2()
 
 #if __name__ == '__main__':
-# top_5_contexes, top_5_scores = search_pipeline('What
+# top_5_contexes, top_5_scores = search_pipeline('What contributions does 37Signals make to open-source projects?', 3)
 
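Note: the new base_dpr_pipeline calls a mean_pooling helper that is defined elsewhere in demo_dpr.py and is not part of this diff. If it follows the usual contriever-style pooling, it would look roughly like the sketch below; the exact signature and body here are assumptions, not part of the commit.

import torch

# Assumed helper (not shown in this diff): attention-mask-aware mean pooling
# over the encoder's token embeddings, as commonly used with facebook/contriever.
def mean_pooling(token_embeddings, attention_mask):
    # Zero out padding positions, then average over the sequence dimension.
    mask = attention_mask.unsqueeze(-1).bool()
    token_embeddings = token_embeddings.masked_fill(~mask, 0.0)
    return token_embeddings.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)

Under that assumption, the question becomes a single (1, hidden_size) vector that the existing dot_product helper can score against each precomputed context embedding.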
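The commit also starts loading basecamp-base-dpr-contriever-embeddings.pkl, but the script that produced that file is not part of this change. A plausible offline step, assuming load_paragraphs unpacks a simple (embeddings, contexes) pickle and that the contexts are encoded with the same facebook/contriever-msmarco model and mean pooling, could look like the sketch below; the file layout, batching, and example paragraphs are all assumptions.

import pickle
import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical offline script: embed each context once with the base
# (untrained) contriever encoder and pickle the results for the demo to load.
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
encoder = AutoModel.from_pretrained('facebook/contriever-msmarco')
encoder.eval()

def mean_pooling(token_embeddings, attention_mask):
    # Same assumed pooling as in the demo: mask out padding, average over tokens.
    mask = attention_mask.unsqueeze(-1).bool()
    token_embeddings = token_embeddings.masked_fill(~mask, 0.0)
    return token_embeddings.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)

def embed_contexts(contexes, batch_size=16):
    embeddings = []
    with torch.no_grad():
        for start in range(0, len(contexes), batch_size):
            batch = contexes[start:start + batch_size]
            tokens = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
            output = encoder(**tokens)
            # One 1-D embedding per context, matching the torch.dot scoring in the demo.
            embeddings.extend(mean_pooling(output[0], tokens['attention_mask']))
    return embeddings

# contexes would be the same Basecamp paragraphs used elsewhere in the demo.
contexes = ["Example Basecamp paragraph one.", "Example Basecamp paragraph two."]
with open('basecamp-base-dpr-contriever-embeddings.pkl', 'wb') as f:
    pickle.dump((embed_contexts(contexes), contexes), f)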