secilozksen committed on
Commit
14df537
·
1 Parent(s): f98bd2b

demo_dpr update

Browse files
Files changed (1) hide show
  1. demo_dpr.py +30 -6
demo_dpr.py CHANGED
@@ -20,11 +20,13 @@ DATAFRAME_FILE_BSBS = 'basecamp.csv'
20
  selectbox_selections = {
21
  'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
22
  'Dense Passage Retrieval':2,
 
23
  'Retrieve - Rerank':4
24
  }
25
  imagebox_selections = {
26
  'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
27
  'Dense Passage Retrieval': 'DPR_pipeline.png',
 
28
  'Retrieve - Rerank': 'retrieve-rerank.png'
29
  }
30
 
@@ -71,7 +73,7 @@ def load_paragraphs(path):
71
  def load_dataframes():
72
  # data_original = pd.read_csv(DATAFRAME_FILE_ORIGINAL, index_col=0, sep='|')
73
  data_bsbs = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
74
- data_bsbs.drop('context_id', axis=1, inplace=True)
75
  # data_original = data_original.sample(frac=1).reset_index(drop=True)
76
  data_bsbs = data_bsbs.sample(frac=1).reset_index(drop=True)
77
  return data_bsbs
@@ -82,11 +84,31 @@ def dot_product(question_output, context_output):
82
  result = torch.dot(mat1, mat2)
83
  return result
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def search_pipeline(question, search_method):
86
  if search_method == 1: #Retrieve - rerank with fine-tuned cross encoder
87
  return retrieve_rerank_with_trained_cross_encoder(question)
88
  if search_method == 2:
89
  return custom_dpr_pipeline(question) # DPR only
 
 
90
  if search_method == 4:
91
  return retrieve_rerank(question)
92
 
@@ -213,8 +235,8 @@ def qa_main_widgetsv2():
213
  st.write(selection['context'])
214
  st.markdown("### Question:")
215
  st.write(selection['question'])
216
- st.markdown("### Answer:")
217
- st.write(selection['answer'])
218
  st.session_state.grid_click_2 = False
219
 
220
  @st.cache(show_spinner=False, allow_output_mutation = True)
@@ -226,15 +248,17 @@ def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
226
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
227
  bi_encoder.max_seq_length = 500
228
  trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
 
229
  question_tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
230
- return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
231
 
232
  context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
233
  dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-contriever-embeddings.pkl')
 
234
  dataframe_bsbs = load_dataframes()
235
- dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
236
  qa_main_widgetsv2()
237
 
238
  #if __name__ == '__main__':
239
- # top_5_contexes, top_5_scores = search_pipeline('What are the benefits of 37Signals Visa Card?', 1)
240
 
 
20
# Maps each Streamlit selectbox label to the numeric search_method id
# consumed by search_pipeline(); id 3 (base DPR) is currently disabled.
selectbox_selections = {
    'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
    'Dense Passage Retrieval': 2,
    # 'Base Dense Passage Retrieval': 3,
    'Retrieve - Rerank': 4
}
# Maps each selectbox label to the pipeline-diagram image shown in the UI.
imagebox_selections = {
    'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
    'Dense Passage Retrieval': 'DPR_pipeline.png',
    'Base Dense Passage Retrieval': 'base-dpr.png',
    'Retrieve - Rerank': 'retrieve-rerank.png'
}
32
 
 
73
def load_dataframes():
    """Load the basecamp QA dataframe, drop answer-related columns, and shuffle rows.

    Returns a pandas DataFrame with a fresh integer index; the
    'context_id', 'answer', 'answer_start' and 'answer_end' columns are
    removed so answers are not exposed in the demo UI.
    """
    # data_original = pd.read_csv(DATAFRAME_FILE_ORIGINAL, index_col=0, sep='|')
    frame = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
    frame = frame.drop(columns=['context_id', 'answer', 'answer_start', 'answer_end'])
    # data_original = data_original.sample(frac=1).reset_index(drop=True)
    # Shuffle all rows and discard the old index.
    return frame.sample(frac=1).reset_index(drop=True)
 
84
  result = torch.dot(mat1, mat2)
85
  return result
86
 
87
def base_dpr_pipeline(question, top_k=5):
    """Rank the pre-computed base-DPR context embeddings against *question*.

    The question is embedded with the (untrained) contriever encoder,
    mean-pooled over tokens, and scored against every cached context
    embedding with a dot product.

    Args:
        question: free-text question string.
        top_k: number of best contexts to return (default 5, matching the
            original hard-coded behavior).

    Returns:
        (top_contexes, top_scores): the top_k context strings and their
        dot-product scores, best first.
    """
    tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt")
    question_embedding = base_dpr_context_encoder(**tokenized_question)
    # Mean-pool token embeddings; presumably contriever exposes no usable
    # pooler_output here (the commented-out line below was tried) — TODO confirm.
    question_embedding = mean_pooling(question_embedding[0], tokenized_question['attention_mask'])
    # question_embedding = question_embedding['pooler_output']
    results_list = []
    # Fixed: the original `for i, context_embedding in enumerate(...)` bound an
    # unused `i` that was then shadowed by the sort lambda's parameter.
    for context_embedding in base_dpr_context_embeddings:
        score = dot_product(question_embedding, context_embedding)
        results_list.append(score.detach().cpu())

    # Indices of contexts sorted by score, descending.
    hits = sorted(range(len(results_list)), key=lambda idx: results_list[idx], reverse=True)
    top_contexes = []
    top_scores = []
    for j in hits[:top_k]:
        top_contexes.append(base_contexes[j])
        top_scores.append(results_list[j])
    return top_contexes, top_scores
+
105
def search_pipeline(question, search_method):
    """Route *question* to the retrieval pipeline selected by *search_method*.

    Method ids match the values in selectbox_selections; an unknown id
    returns None, as before.
    """
    dispatch = {
        1: retrieve_rerank_with_trained_cross_encoder,  # retrieve-rerank, fine-tuned cross encoder
        2: custom_dpr_pipeline,                         # DPR only
        # 3: base_dpr_pipeline,                         # base DPR (currently disabled)
        4: retrieve_rerank,                             # plain retrieve-rerank
    }
    handler = dispatch.get(search_method)
    if handler is not None:
        return handler(question)
114
 
 
235
  st.write(selection['context'])
236
  st.markdown("### Question:")
237
  st.write(selection['question'])
238
+ # st.markdown("### Answer:")
239
+ # st.write(selection['answer'])
240
  st.session_state.grid_click_2 = False
241
 
242
  @st.cache(show_spinner=False, allow_output_mutation = True)
 
248
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
249
  bi_encoder.max_seq_length = 500
250
  trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
251
+ base_dpr_context_encoder = AutoModel.from_pretrained('facebook/contriever-msmarco')
252
  question_tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
253
+ return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer, base_dpr_context_encoder
254
 
# Module-level startup: load the cached context embeddings and paragraph
# texts, the QA dataframe, and all models, then launch the Streamlit UI.
context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-contriever-embeddings.pkl')
base_dpr_context_embeddings, base_contexes = load_paragraphs('basecamp-base-dpr-contriever-embeddings.pkl')
dataframe_bsbs = load_dataframes()
# deepcopy of the st.cache'd models — presumably so cached objects are not
# mutated in place across Streamlit reruns; TODO confirm.
dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer, base_dpr_context_encoder = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
qa_main_widgetsv2()

#if __name__ == '__main__':
#    top_5_contexes, top_5_scores = search_pipeline('What contributions does 37Signals make to open-source projects?', 3)