Commit 8a27036 · 1 Parent(s): f00c2f9 · lihuigu committed

speed up score calculation

src/utils/paper_client.py CHANGED
@@ -79,6 +79,21 @@ class PaperClient:
         logger.error(f"paper id {paper_id} get {attribute_name} failed.")
         return None
 
+    def get_papers_attribute(self, paper_id_list, attribute_name):
+        query = """
+        UNWIND $paper_ids AS paper_id
+        MATCH (p:Paper {hash_id: paper_id})
+        RETURN p.hash_id AS hash_id, p[$attribute_name] AS attributeValue
+        """
+        with self.driver.session() as session:
+            result = session.execute_read(
+                lambda tx: tx.run(
+                    query, paper_ids=paper_id_list, attribute_name=attribute_name
+                ).data()
+            )
+        paper_attributes = [record["attributeValue"] for record in result]
+        return paper_attributes
+
     def get_paper_by_attribute(self, attribute_name, anttribute_value):
         query = f"""
         MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
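
The speedup in this file comes from replacing one Cypher query per paper with a single UNWIND query that fetches all papers in one round-trip. Below is a minimal before/after sketch, not taken from the repository: it assumes a reachable Neo4j instance and the official neo4j Python driver, and the URI, credentials, and ids are placeholders. One caveat: an id absent from the graph produces no row, so the batched result can be shorter than the input list.

from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
paper_ids = ["hash_a", "hash_b", "hash_c"]  # hypothetical hash_id values

# Before: one query, and one network round-trip, per paper id.
def get_one(tx, paper_id, attribute_name):
    record = tx.run(
        "MATCH (p:Paper {hash_id: $paper_id}) RETURN p[$attribute_name] AS v",
        paper_id=paper_id, attribute_name=attribute_name,
    ).single()
    return record["v"] if record else None

# After: UNWIND expands the id list into rows server-side, so every
# paper is matched by a single query in a single round-trip.
def get_many(tx, paper_id_list, attribute_name):
    return tx.run(
        """
        UNWIND $paper_ids AS paper_id
        MATCH (p:Paper {hash_id: paper_id})
        RETURN p[$attribute_name] AS v
        """,
        paper_ids=paper_id_list, attribute_name=attribute_name,
    ).data()

with driver.session() as session:
    slow = [session.execute_read(get_one, pid, "embedding") for pid in paper_ids]
    fast = [r["v"] for r in session.execute_read(get_many, paper_ids, "embedding")]
driver.close()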
src/utils/paper_retriever.py CHANGED
@@ -184,12 +184,11 @@ class Retriever(object):
         self, embedding, related_paper_id_list, type_name="embedding"
     ):
         score_1 = np.zeros((len(related_paper_id_list)))
-        score_2 = np.zeros((len(related_paper_id_list)))
+        # score_2 = np.zeros((len(related_paper_id_list)))
         origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
-        context_embeddings = [
-            self.paper_client.get_paper_attribute(paper_id, type_name)
-            for paper_id in related_paper_id_list
-        ]
+        context_embeddings = self.paper_client.get_papers_attribute(
+            related_paper_id_list, type_name
+        )
         if len(context_embeddings) > 0:
             context_embeddings = torch.tensor(context_embeddings).to(self.device)
             score_1 = torch.nn.functional.cosine_similarity(
@@ -198,8 +197,9 @@
         score_1 = score_1.cpu().numpy()
         if self.config.RETRIEVE.need_normalize:
             score_1 = score_1 / np.max(score_1)
-        score_sn_dict = dict(zip(related_paper_id_list, score_1))
-        score_en_dict = dict(zip(related_paper_id_list, score_2))
+        score_all_dict = dict(zip(related_paper_id_list, score_1))
+        # score_en_dict = dict(zip(related_paper_id_list, score_2))
+        """
         score_all_dict = dict(
             zip(
                 related_paper_id_list,
@@ -207,7 +207,8 @@
                 + score_2 * self.config.RETRIEVE.beta,
             )
         )
-        return score_sn_dict, score_en_dict, score_all_dict
+        """
+        return {}, {}, score_all_dict
 
     def filter_related_paper(self, score_dict, top_k):
         if len(score_dict) <= top_k:
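
After this commit the scorer computes only the embedding-based cosine score: the second score term and its dict are commented out, and the old three-value return keeps its shape by returning two empty dicts. Below is a minimal, self-contained sketch of the resulting scoring path, an assumption-laden illustration rather than the project's actual class (no config object or Neo4j client; a division-by-zero guard is added here that the original code does not have):

import numpy as np
import torch

def cal_related_score(embedding, related_paper_id_list, context_embeddings,
                      need_normalize=True, device="cpu"):
    # Cosine similarity between the query embedding and each paper embedding,
    # computed in one batched tensor op.
    score_1 = np.zeros(len(related_paper_id_list))
    origin_vector = torch.tensor(embedding).to(device).unsqueeze(0)
    if len(context_embeddings) > 0:
        context = torch.tensor(context_embeddings).to(device)
        score_1 = torch.nn.functional.cosine_similarity(
            origin_vector, context, dim=1
        ).cpu().numpy()
    if need_normalize and np.max(score_1) > 0:  # guard added for this sketch
        score_1 = score_1 / np.max(score_1)
    # Same return shape as the commit: two empty dicts, then the score dict.
    return {}, {}, dict(zip(related_paper_id_list, score_1))

# Toy usage: two 2-d "paper embeddings", one parallel and one orthogonal
# to the query vector.
_, _, scores = cal_related_score([1.0, 0.0], ["a", "b"], [[1.0, 0.0], [0.0, 1.0]])
print(scores)  # ~{'a': 1.0, 'b': 0.0}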