Speed up score calculation
Browse files- src/utils/paper_client.py +15 -0
- src/utils/paper_retriever.py +9 -8
src/utils/paper_client.py
CHANGED
@@ -79,6 +79,21 @@ class PaperClient:
|
|
79 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
80 |
return None
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
83 |
query = f"""
|
84 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
|
|
79 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
80 |
return None
|
81 |
|
82 |
+
def get_papers_attribute(self, paper_id_list, attribute_name):
    """Batch-fetch one attribute for many papers in a single Cypher query.

    Replaces a per-paper query loop with one round trip, which is the
    speed-up this change is about.

    Args:
        paper_id_list: list of paper ``hash_id`` values to look up.
        attribute_name: name of the node property to read from each paper.

    Returns:
        List of attribute values aligned 1:1 with ``paper_id_list``
        (``None`` for ids that have no matching ``Paper`` node).
    """
    # OPTIONAL MATCH keeps exactly one output row per UNWOUND input id, so
    # the result stays aligned with paper_id_list even when some paper is
    # missing from the graph. A plain MATCH would silently drop those rows
    # and misalign the scores the caller zips back onto the id list.
    # Both the id list and the property name are passed as parameters, so
    # no string interpolation / injection concern here.
    query = """
    UNWIND $paper_ids AS paper_id
    OPTIONAL MATCH (p:Paper {hash_id: paper_id})
    RETURN paper_id AS hash_id, p[$attribute_name] AS attributeValue
    """
    with self.driver.session() as session:
        result = session.execute_read(
            lambda tx: tx.run(
                query, paper_ids=paper_id_list, attribute_name=attribute_name
            ).data()
        )
    return [record["attributeValue"] for record in result]
|
96 |
+
|
97 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
98 |
query = f"""
|
99 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
src/utils/paper_retriever.py
CHANGED
@@ -184,12 +184,11 @@ class Retriever(object):
|
|
184 |
self, embedding, related_paper_id_list, type_name="embedding"
|
185 |
):
|
186 |
score_1 = np.zeros((len(related_paper_id_list)))
|
187 |
-
score_2 = np.zeros((len(related_paper_id_list)))
|
188 |
origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
|
189 |
-
context_embeddings =
|
190 |
-
|
191 |
-
|
192 |
-
]
|
193 |
if len(context_embeddings) > 0:
|
194 |
context_embeddings = torch.tensor(context_embeddings).to(self.device)
|
195 |
score_1 = torch.nn.functional.cosine_similarity(
|
@@ -198,8 +197,9 @@ class Retriever(object):
|
|
198 |
score_1 = score_1.cpu().numpy()
|
199 |
if self.config.RETRIEVE.need_normalize:
|
200 |
score_1 = score_1 / np.max(score_1)
|
201 |
-
|
202 |
-
score_en_dict = dict(zip(related_paper_id_list, score_2))
|
|
|
203 |
score_all_dict = dict(
|
204 |
zip(
|
205 |
related_paper_id_list,
|
@@ -207,7 +207,8 @@ class Retriever(object):
|
|
207 |
+ score_2 * self.config.RETRIEVE.beta,
|
208 |
)
|
209 |
)
|
210 |
-
|
|
|
211 |
|
212 |
def filter_related_paper(self, score_dict, top_k):
|
213 |
if len(score_dict) <= top_k:
|
|
|
184 |
self, embedding, related_paper_id_list, type_name="embedding"
|
185 |
):
|
186 |
score_1 = np.zeros((len(related_paper_id_list)))
|
187 |
+
# score_2 = np.zeros((len(related_paper_id_list)))
|
188 |
origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
|
189 |
+
context_embeddings = self.paper_client.get_papers_attribute(
|
190 |
+
related_paper_id_list, type_name
|
191 |
+
)
|
|
|
192 |
if len(context_embeddings) > 0:
|
193 |
context_embeddings = torch.tensor(context_embeddings).to(self.device)
|
194 |
score_1 = torch.nn.functional.cosine_similarity(
|
|
|
197 |
score_1 = score_1.cpu().numpy()
|
198 |
if self.config.RETRIEVE.need_normalize:
|
199 |
score_1 = score_1 / np.max(score_1)
|
200 |
+
score_all_dict = dict(zip(related_paper_id_list, score_1))
|
201 |
+
# score_en_dict = dict(zip(related_paper_id_list, score_2))
|
202 |
+
"""
|
203 |
score_all_dict = dict(
|
204 |
zip(
|
205 |
related_paper_id_list,
|
|
|
207 |
+ score_2 * self.config.RETRIEVE.beta,
|
208 |
)
|
209 |
)
|
210 |
+
"""
|
211 |
+
return {}, {}, score_all_dict
|
212 |
|
213 |
def filter_related_paper(self, score_dict, top_k):
|
214 |
if len(score_dict) <= top_k:
|