reduce neo4j query time in retrieve
Browse files- app.py +19 -12
- src/app_pages/button_interface.py +21 -10
- src/generator.py +99 -93
- src/paper_manager.py +7 -6
- src/retriever.py +2 -2
- src/utils/api/__init__.py +2 -0
- src/utils/api/base_helper.py +70 -18
- src/utils/api/local_helper.py +39 -0
- src/utils/hash.py +35 -10
- src/utils/llms_api.py +31 -26
- src/utils/paper_client.py +480 -200
- src/utils/paper_retriever.py +12 -21
app.py
CHANGED
@@ -1,25 +1,32 @@
|
|
1 |
import sys
|
|
|
2 |
sys.path.append("./src")
|
3 |
import streamlit as st
|
4 |
-
from app_pages import
|
|
|
|
|
|
|
|
|
|
|
5 |
from app_pages.locale import _
|
6 |
-
from utils.hash import check_env, check_embedding
|
7 |
|
8 |
if __name__ == "__main__":
|
9 |
-
check_env()
|
10 |
-
check_embedding()
|
11 |
backend = button_interface.Backend()
|
12 |
# backend = None
|
13 |
st.set_page_config(layout="wide")
|
14 |
-
|
15 |
-
|
16 |
def fn1():
|
17 |
one_click_generation.one_click_generation(backend)
|
|
|
18 |
def fn2():
|
19 |
step_by_step_generation.step_by_step_generation(backend)
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
1 |
import sys
|
2 |
+
|
3 |
sys.path.append("./src")
|
4 |
import streamlit as st
|
5 |
+
from app_pages import (
|
6 |
+
button_interface,
|
7 |
+
step_by_step_generation,
|
8 |
+
one_click_generation,
|
9 |
+
homepage,
|
10 |
+
)
|
11 |
from app_pages.locale import _
|
|
|
12 |
|
13 |
if __name__ == "__main__":
|
|
|
|
|
14 |
backend = button_interface.Backend()
|
15 |
# backend = None
|
16 |
st.set_page_config(layout="wide")
|
17 |
+
|
18 |
+
# st.logo("./assets/pic/logo.jpg", size="large")
|
19 |
def fn1():
|
20 |
one_click_generation.one_click_generation(backend)
|
21 |
+
|
22 |
def fn2():
|
23 |
step_by_step_generation.step_by_step_generation(backend)
|
24 |
+
|
25 |
+
pg = st.navigation(
|
26 |
+
[
|
27 |
+
st.Page(homepage.home_page, title=_("🏠️ Homepage")),
|
28 |
+
st.Page(fn1, title=_("💧 One-click Generation")),
|
29 |
+
st.Page(fn2, title=_("💦 Step-by-step Generation")),
|
30 |
+
]
|
31 |
+
)
|
32 |
+
pg.run()
|
src/app_pages/button_interface.py
CHANGED
@@ -2,8 +2,10 @@ import json
|
|
2 |
from utils.paper_retriever import RetrieverFactory
|
3 |
from utils.llms_api import APIHelper
|
4 |
from utils.header import ConfigReader
|
|
|
5 |
from generator import IdeaGenerator
|
6 |
|
|
|
7 |
class Backend(object):
|
8 |
def __init__(self) -> None:
|
9 |
CONFIG_PATH = "./configs/datasets.yaml"
|
@@ -12,11 +14,14 @@ class Backend(object):
|
|
12 |
BRAINSTORM_MODE = "mode_c"
|
13 |
|
14 |
self.config = ConfigReader.load(CONFIG_PATH)
|
|
|
|
|
15 |
RETRIEVER_NAME = self.config.RETRIEVE.retriever_name
|
16 |
self.api_helper = APIHelper(self.config)
|
17 |
-
self.retriever_factory =
|
18 |
-
|
19 |
-
|
|
|
20 |
)
|
21 |
self.idea_generator = IdeaGenerator(self.config, None)
|
22 |
self.use_inspiration = USE_INSPIRATION
|
@@ -33,14 +38,14 @@ class Backend(object):
|
|
33 |
return []
|
34 |
|
35 |
def background2brainstorm_callback(self, background, json_strs=None):
|
36 |
-
if json_strs is not None:
|
37 |
json_contents = json.loads(json_strs)
|
38 |
return json_contents["brainstorm"]
|
39 |
else:
|
40 |
return self.api_helper.generate_brainstorm(background)
|
41 |
|
42 |
def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
|
43 |
-
if json_strs is not None:
|
44 |
json_contents = json.loads(json_strs)
|
45 |
entities_bg = json_contents["entities_bg"]
|
46 |
entities_bs = json_contents["entities_bs"]
|
@@ -71,13 +76,17 @@ class Backend(object):
|
|
71 |
for i, p in enumerate(result["related_paper"]):
|
72 |
res.append(str(p))
|
73 |
else:
|
74 |
-
result = self.retriever_factory.retrieve(
|
|
|
|
|
75 |
res = []
|
76 |
for i, p in enumerate(result["related_paper"]):
|
77 |
res.append(f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.')
|
78 |
return res, result["related_paper"]
|
79 |
|
80 |
-
def literature2initial_ideas_callback(
|
|
|
|
|
81 |
if json_strs is not None:
|
82 |
json_contents = json.loads(json_strs)
|
83 |
return json_contents["median"]["initial_idea"]
|
@@ -86,15 +95,16 @@ class Backend(object):
|
|
86 |
self.idea_generator.brainstorm = brainstorms
|
87 |
if self.use_inspiration:
|
88 |
message_input, idea_modified, median = (
|
89 |
-
|
90 |
-
|
|
|
91 |
)
|
92 |
else:
|
93 |
message_input, idea_modified, median = self.idea_generator.generate(
|
94 |
background, "new_idea", self.brainstorm_mode, False
|
95 |
)
|
96 |
return median["initial_idea"], idea_modified
|
97 |
-
|
98 |
def initial2final_callback(self, initial_ideas, final_ideas, json_strs=None):
|
99 |
if json_strs is not None:
|
100 |
json_contents = json.loads(json_strs)
|
@@ -107,6 +117,7 @@ class Backend(object):
|
|
107 |
return self.examples[i].get("background", "Background not found.")
|
108 |
else:
|
109 |
return "Example not found. Please select a valid index."
|
|
|
110 |
# return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
|
111 |
# "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
|
112 |
# "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
|
|
|
2 |
from utils.paper_retriever import RetrieverFactory
|
3 |
from utils.llms_api import APIHelper
|
4 |
from utils.header import ConfigReader
|
5 |
+
from utils.hash import check_env, check_embedding
|
6 |
from generator import IdeaGenerator
|
7 |
|
8 |
+
|
9 |
class Backend(object):
|
10 |
def __init__(self) -> None:
|
11 |
CONFIG_PATH = "./configs/datasets.yaml"
|
|
|
14 |
BRAINSTORM_MODE = "mode_c"
|
15 |
|
16 |
self.config = ConfigReader.load(CONFIG_PATH)
|
17 |
+
check_env()
|
18 |
+
check_embedding(self.config.DEFAULT.embedding)
|
19 |
RETRIEVER_NAME = self.config.RETRIEVE.retriever_name
|
20 |
self.api_helper = APIHelper(self.config)
|
21 |
+
self.retriever_factory = (
|
22 |
+
RetrieverFactory.get_retriever_factory().create_retriever(
|
23 |
+
RETRIEVER_NAME, self.config
|
24 |
+
)
|
25 |
)
|
26 |
self.idea_generator = IdeaGenerator(self.config, None)
|
27 |
self.use_inspiration = USE_INSPIRATION
|
|
|
38 |
return []
|
39 |
|
40 |
def background2brainstorm_callback(self, background, json_strs=None):
|
41 |
+
if json_strs is not None: # only for DEBUG_MODE
|
42 |
json_contents = json.loads(json_strs)
|
43 |
return json_contents["brainstorm"]
|
44 |
else:
|
45 |
return self.api_helper.generate_brainstorm(background)
|
46 |
|
47 |
def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
|
48 |
+
if json_strs is not None: # only for DEBUG_MODE
|
49 |
json_contents = json.loads(json_strs)
|
50 |
entities_bg = json_contents["entities_bg"]
|
51 |
entities_bs = json_contents["entities_bs"]
|
|
|
76 |
for i, p in enumerate(result["related_paper"]):
|
77 |
res.append(str(p))
|
78 |
else:
|
79 |
+
result = self.retriever_factory.retrieve(
|
80 |
+
background, entities, need_evaluate=False, target_paper_id_list=[]
|
81 |
+
)
|
82 |
res = []
|
83 |
for i, p in enumerate(result["related_paper"]):
|
84 |
res.append(f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.')
|
85 |
return res, result["related_paper"]
|
86 |
|
87 |
+
def literature2initial_ideas_callback(
|
88 |
+
self, background, brainstorms, retrieved_literature, json_strs=None
|
89 |
+
):
|
90 |
if json_strs is not None:
|
91 |
json_contents = json.loads(json_strs)
|
92 |
return json_contents["median"]["initial_idea"]
|
|
|
95 |
self.idea_generator.brainstorm = brainstorms
|
96 |
if self.use_inspiration:
|
97 |
message_input, idea_modified, median = (
|
98 |
+
self.idea_generator.generate_by_inspiration(
|
99 |
+
background, "new_idea", self.brainstorm_mode, False
|
100 |
+
)
|
101 |
)
|
102 |
else:
|
103 |
message_input, idea_modified, median = self.idea_generator.generate(
|
104 |
background, "new_idea", self.brainstorm_mode, False
|
105 |
)
|
106 |
return median["initial_idea"], idea_modified
|
107 |
+
|
108 |
def initial2final_callback(self, initial_ideas, final_ideas, json_strs=None):
|
109 |
if json_strs is not None:
|
110 |
json_contents = json.loads(json_strs)
|
|
|
117 |
return self.examples[i].get("background", "Background not found.")
|
118 |
else:
|
119 |
return "Example not found. Please select a valid index."
|
120 |
+
|
121 |
# return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
|
122 |
# "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
|
123 |
# "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
|
src/generator.py
CHANGED
@@ -10,6 +10,7 @@ import warnings
|
|
10 |
import time
|
11 |
import os
|
12 |
from utils.hash import check_env, check_embedding
|
|
|
13 |
warnings.filterwarnings("ignore")
|
14 |
|
15 |
|
@@ -24,9 +25,14 @@ def extract_problem(problem, background):
|
|
24 |
research_problem = background
|
25 |
return research_problem
|
26 |
|
|
|
27 |
class IdeaGenerator:
|
28 |
def __init__(
|
29 |
-
self,
|
|
|
|
|
|
|
|
|
30 |
) -> None:
|
31 |
self.api_helper = APIHelper(config)
|
32 |
self.paper_list = paper_list
|
@@ -58,7 +64,9 @@ class IdeaGenerator:
|
|
58 |
idea = self.api_helper.generate_idea_with_cue_words(
|
59 |
problem, self.paper_list, self.cue_words
|
60 |
)
|
61 |
-
idea_filtered = self.api_helper.integrate_idea(
|
|
|
|
|
62 |
return message_input, problem, idea, idea_filtered
|
63 |
|
64 |
def generate_without_cue_words_bs(self, background: str):
|
@@ -66,7 +74,9 @@ class IdeaGenerator:
|
|
66 |
background, self.paper_list
|
67 |
)
|
68 |
idea = self.api_helper.generate_idea(problem, self.paper_list)
|
69 |
-
idea_filtered = self.api_helper.integrate_idea(
|
|
|
|
|
70 |
return message_input, problem, idea, idea_filtered
|
71 |
|
72 |
def generate_with_cue_words_ins(self, background: str):
|
@@ -93,16 +103,12 @@ class IdeaGenerator:
|
|
93 |
research_problem = extract_problem(problem, background)
|
94 |
inspirations = []
|
95 |
for paper in self.paper_list:
|
96 |
-
inspiration = self.api_helper.generate_inspiration(
|
97 |
-
research_problem, paper
|
98 |
-
)
|
99 |
inspirations.append(inspiration)
|
100 |
-
idea = self.api_helper.generate_idea_by_inspiration(
|
101 |
-
problem, inspirations
|
102 |
-
)
|
103 |
idea_filtered = self.api_helper.filter_idea(idea, background)
|
104 |
return message_input, problem, inspirations, idea, idea_filtered
|
105 |
-
|
106 |
def generate_with_cue_words_ins_bs(self, background: str):
|
107 |
problem, message_input = self.api_helper.generate_problem_with_cue_words(
|
108 |
background, self.paper_list, self.cue_words
|
@@ -117,7 +123,9 @@ class IdeaGenerator:
|
|
117 |
idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
|
118 |
problem, inspirations, self.cue_words
|
119 |
)
|
120 |
-
idea_filtered = self.api_helper.integrate_idea(
|
|
|
|
|
121 |
return message_input, problem, inspirations, idea, idea_filtered
|
122 |
|
123 |
def generate_without_cue_words_ins_bs(self, background: str):
|
@@ -127,14 +135,12 @@ class IdeaGenerator:
|
|
127 |
research_problem = extract_problem(problem, background)
|
128 |
inspirations = []
|
129 |
for paper in self.paper_list:
|
130 |
-
inspiration = self.api_helper.generate_inspiration(
|
131 |
-
research_problem, paper
|
132 |
-
)
|
133 |
inspirations.append(inspiration)
|
134 |
-
idea = self.api_helper.generate_idea_by_inspiration(
|
135 |
-
|
|
|
136 |
)
|
137 |
-
idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
|
138 |
return message_input, problem, inspirations, idea, idea_filtered
|
139 |
|
140 |
def generate(
|
@@ -151,44 +157,34 @@ class IdeaGenerator:
|
|
151 |
mode_name = "Generate new idea"
|
152 |
if bs_mode == "mode_a":
|
153 |
if use_cue_words:
|
154 |
-
logger.info(
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
idea,
|
159 |
-
idea_filtered
|
160 |
-
) = (
|
161 |
self.generate_with_cue_words(background)
|
162 |
)
|
163 |
else:
|
164 |
-
logger.info(
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
idea,
|
169 |
-
idea_filtered
|
170 |
-
) = (
|
171 |
self.generate_without_cue_words(background)
|
172 |
)
|
173 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
174 |
if use_cue_words:
|
175 |
-
logger.info(
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
idea,
|
180 |
-
idea_filtered
|
181 |
-
) = (
|
182 |
self.generate_with_cue_words_bs(background)
|
183 |
)
|
184 |
else:
|
185 |
-
logger.info(
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
) = (
|
192 |
self.generate_without_cue_words_bs(background)
|
193 |
)
|
194 |
|
@@ -214,48 +210,34 @@ class IdeaGenerator:
|
|
214 |
mode_name = "Generate new idea"
|
215 |
if bs_mode == "mode_a":
|
216 |
if use_cue_words:
|
217 |
-
logger.info(
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
inspirations,
|
222 |
-
idea,
|
223 |
-
idea_filtered
|
224 |
-
) = (
|
225 |
self.generate_with_cue_words_ins(background)
|
226 |
)
|
227 |
else:
|
228 |
-
logger.info(
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
inspirations,
|
233 |
-
idea,
|
234 |
-
idea_filtered
|
235 |
-
) = (
|
236 |
self.generate_without_cue_words_ins(background)
|
237 |
)
|
238 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
239 |
if use_cue_words:
|
240 |
-
logger.info(
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
inspirations,
|
245 |
-
idea,
|
246 |
-
idea_filtered
|
247 |
-
) = (
|
248 |
self.generate_with_cue_words_ins_bs(background)
|
249 |
)
|
250 |
else:
|
251 |
-
logger.info(
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
idea_filtered
|
258 |
-
) = (
|
259 |
self.generate_without_cue_words_ins_bs(background)
|
260 |
)
|
261 |
|
@@ -330,9 +312,18 @@ def main(ctx):
|
|
330 |
required=False,
|
331 |
help="The number of papers you want to process",
|
332 |
)
|
333 |
-
def backtracking(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
check_env()
|
335 |
-
check_embedding()
|
336 |
# Configuration
|
337 |
config = ConfigReader.load(config_path, **kwargs)
|
338 |
logger.add(
|
@@ -349,7 +340,10 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
349 |
batch_size = 2
|
350 |
output_dir = "./assets/output_idea/"
|
351 |
os.makedirs(output_dir, exist_ok=True)
|
352 |
-
output_file = os.path.join(
|
|
|
|
|
|
|
353 |
if os.path.exists(output_file):
|
354 |
with open(output_file, "r", encoding="utf-8") as f:
|
355 |
try:
|
@@ -388,7 +382,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
388 |
if brainstorm_mode == "mode_c":
|
389 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
390 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
391 |
-
entities_all = list(set(entities)|set(entities_bs))
|
392 |
else:
|
393 |
entities_bs = None
|
394 |
entities_all = entities
|
@@ -404,8 +398,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
404 |
continue
|
405 |
# 3. 检索相关论文
|
406 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
407 |
-
retriever_name,
|
408 |
-
config
|
409 |
)
|
410 |
result = rt.retrieve(
|
411 |
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
|
@@ -438,7 +431,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
438 |
"hash_id": paper["hash_id"],
|
439 |
"background": bg,
|
440 |
"entities_bg": entities,
|
441 |
-
"brainstorm"
|
442 |
"entities_bs": entities_bs,
|
443 |
"entities_rt": entities_rt,
|
444 |
"related_paper": [p["hash_id"] for p in related_paper],
|
@@ -467,6 +460,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
467 |
) as f:
|
468 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
469 |
|
|
|
470 |
@main.command()
|
471 |
@click.option(
|
472 |
"-c",
|
@@ -512,9 +506,16 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
512 |
required=False,
|
513 |
help="The number of data you want to process",
|
514 |
)
|
515 |
-
def new_idea(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
check_env()
|
517 |
-
check_embedding()
|
518 |
logger.add(
|
519 |
"log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
|
520 |
) # 添加文件输出
|
@@ -522,6 +523,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
522 |
# Configuration
|
523 |
config = ConfigReader.load(config_path, **kwargs)
|
524 |
api_helper = APIHelper(config)
|
|
|
525 |
eval_data = []
|
526 |
cur_num = 0
|
527 |
data_num = 0
|
@@ -529,7 +531,9 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
529 |
bg_ids = set()
|
530 |
output_dir = "./assets/output_idea/"
|
531 |
os.makedirs(output_dir, exist_ok=True)
|
532 |
-
output_file = os.path.join(
|
|
|
|
|
533 |
if os.path.exists(output_file):
|
534 |
with open(output_file, "r", encoding="utf-8") as f:
|
535 |
try:
|
@@ -538,7 +542,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
538 |
cur_num = len(eval_data)
|
539 |
except json.JSONDecodeError:
|
540 |
eval_data = []
|
541 |
-
|
542 |
for line in ids_path:
|
543 |
# 解析每行的JSON数据
|
544 |
data = json.loads(line)
|
@@ -568,16 +572,17 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
568 |
if brainstorm_mode == "mode_c":
|
569 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
570 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
571 |
-
entities_all = list(set(entities)|set(entities_bs))
|
572 |
else:
|
573 |
entities_bs = None
|
574 |
entities_all = entities
|
575 |
# 2. 检索相关论文
|
576 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
577 |
-
retriever_name,
|
578 |
-
|
|
|
|
|
579 |
)
|
580 |
-
result = rt.retrieve(bg, entities_all, need_evaluate=False, target_paper_id_list=[])
|
581 |
related_paper = result["related_paper"]
|
582 |
logger.info("Find {} related papers...".format(len(related_paper)))
|
583 |
entities_rt = result["entities"]
|
@@ -597,7 +602,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
597 |
{
|
598 |
"background": bg,
|
599 |
"entities_bg": entities,
|
600 |
-
"brainstorm"
|
601 |
"entities_bs": entities_bs,
|
602 |
"entities_rt": entities_rt,
|
603 |
"related_paper": [p["hash_id"] for p in related_paper],
|
@@ -621,5 +626,6 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
621 |
with open(output_file, "w", encoding="utf-8") as f:
|
622 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
623 |
|
|
|
624 |
if __name__ == "__main__":
|
625 |
main()
|
|
|
10 |
import time
|
11 |
import os
|
12 |
from utils.hash import check_env, check_embedding
|
13 |
+
|
14 |
warnings.filterwarnings("ignore")
|
15 |
|
16 |
|
|
|
25 |
research_problem = background
|
26 |
return research_problem
|
27 |
|
28 |
+
|
29 |
class IdeaGenerator:
|
30 |
def __init__(
|
31 |
+
self,
|
32 |
+
config,
|
33 |
+
paper_list: list[dict] = [],
|
34 |
+
cue_words: list = None,
|
35 |
+
brainstorm: str = None,
|
36 |
) -> None:
|
37 |
self.api_helper = APIHelper(config)
|
38 |
self.paper_list = paper_list
|
|
|
64 |
idea = self.api_helper.generate_idea_with_cue_words(
|
65 |
problem, self.paper_list, self.cue_words
|
66 |
)
|
67 |
+
idea_filtered = self.api_helper.integrate_idea(
|
68 |
+
background, self.brainstorm, idea
|
69 |
+
)
|
70 |
return message_input, problem, idea, idea_filtered
|
71 |
|
72 |
def generate_without_cue_words_bs(self, background: str):
|
|
|
74 |
background, self.paper_list
|
75 |
)
|
76 |
idea = self.api_helper.generate_idea(problem, self.paper_list)
|
77 |
+
idea_filtered = self.api_helper.integrate_idea(
|
78 |
+
background, self.brainstorm, idea
|
79 |
+
)
|
80 |
return message_input, problem, idea, idea_filtered
|
81 |
|
82 |
def generate_with_cue_words_ins(self, background: str):
|
|
|
103 |
research_problem = extract_problem(problem, background)
|
104 |
inspirations = []
|
105 |
for paper in self.paper_list:
|
106 |
+
inspiration = self.api_helper.generate_inspiration(research_problem, paper)
|
|
|
|
|
107 |
inspirations.append(inspiration)
|
108 |
+
idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
|
|
|
|
|
109 |
idea_filtered = self.api_helper.filter_idea(idea, background)
|
110 |
return message_input, problem, inspirations, idea, idea_filtered
|
111 |
+
|
112 |
def generate_with_cue_words_ins_bs(self, background: str):
|
113 |
problem, message_input = self.api_helper.generate_problem_with_cue_words(
|
114 |
background, self.paper_list, self.cue_words
|
|
|
123 |
idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
|
124 |
problem, inspirations, self.cue_words
|
125 |
)
|
126 |
+
idea_filtered = self.api_helper.integrate_idea(
|
127 |
+
background, self.brainstorm, idea
|
128 |
+
)
|
129 |
return message_input, problem, inspirations, idea, idea_filtered
|
130 |
|
131 |
def generate_without_cue_words_ins_bs(self, background: str):
|
|
|
135 |
research_problem = extract_problem(problem, background)
|
136 |
inspirations = []
|
137 |
for paper in self.paper_list:
|
138 |
+
inspiration = self.api_helper.generate_inspiration(research_problem, paper)
|
|
|
|
|
139 |
inspirations.append(inspiration)
|
140 |
+
idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
|
141 |
+
idea_filtered = self.api_helper.integrate_idea(
|
142 |
+
background, self.brainstorm, idea
|
143 |
)
|
|
|
144 |
return message_input, problem, inspirations, idea, idea_filtered
|
145 |
|
146 |
def generate(
|
|
|
157 |
mode_name = "Generate new idea"
|
158 |
if bs_mode == "mode_a":
|
159 |
if use_cue_words:
|
160 |
+
logger.info(
|
161 |
+
"{} using brainstorm_mode_a with cue words.".format(mode_name)
|
162 |
+
)
|
163 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
|
|
|
|
164 |
self.generate_with_cue_words(background)
|
165 |
)
|
166 |
else:
|
167 |
+
logger.info(
|
168 |
+
"{} using brainstorm_mode_a without cue words.".format(mode_name)
|
169 |
+
)
|
170 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
|
|
|
|
171 |
self.generate_without_cue_words(background)
|
172 |
)
|
173 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
174 |
if use_cue_words:
|
175 |
+
logger.info(
|
176 |
+
"{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
|
177 |
+
)
|
178 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
|
|
|
|
179 |
self.generate_with_cue_words_bs(background)
|
180 |
)
|
181 |
else:
|
182 |
+
logger.info(
|
183 |
+
"{} using brainstorm_{} without cue words.".format(
|
184 |
+
mode_name, bs_mode
|
185 |
+
)
|
186 |
+
)
|
187 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
188 |
self.generate_without_cue_words_bs(background)
|
189 |
)
|
190 |
|
|
|
210 |
mode_name = "Generate new idea"
|
211 |
if bs_mode == "mode_a":
|
212 |
if use_cue_words:
|
213 |
+
logger.info(
|
214 |
+
"{} using brainstorm_mode_a with cue words.".format(mode_name)
|
215 |
+
)
|
216 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
217 |
self.generate_with_cue_words_ins(background)
|
218 |
)
|
219 |
else:
|
220 |
+
logger.info(
|
221 |
+
"{} using brainstorm_mode_a without cue words.".format(mode_name)
|
222 |
+
)
|
223 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
224 |
self.generate_without_cue_words_ins(background)
|
225 |
)
|
226 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
227 |
if use_cue_words:
|
228 |
+
logger.info(
|
229 |
+
"{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
|
230 |
+
)
|
231 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
232 |
self.generate_with_cue_words_ins_bs(background)
|
233 |
)
|
234 |
else:
|
235 |
+
logger.info(
|
236 |
+
"{} using brainstorm_{} without cue words.".format(
|
237 |
+
mode_name, bs_mode
|
238 |
+
)
|
239 |
+
)
|
240 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
241 |
self.generate_without_cue_words_ins_bs(background)
|
242 |
)
|
243 |
|
|
|
312 |
required=False,
|
313 |
help="The number of papers you want to process",
|
314 |
)
|
315 |
+
def backtracking(
|
316 |
+
config_path,
|
317 |
+
ids_path,
|
318 |
+
retriever_name,
|
319 |
+
brainstorm_mode,
|
320 |
+
use_cue_words,
|
321 |
+
use_inspiration,
|
322 |
+
num,
|
323 |
+
**kwargs,
|
324 |
+
):
|
325 |
check_env()
|
326 |
+
check_embedding()
|
327 |
# Configuration
|
328 |
config = ConfigReader.load(config_path, **kwargs)
|
329 |
logger.add(
|
|
|
340 |
batch_size = 2
|
341 |
output_dir = "./assets/output_idea/"
|
342 |
os.makedirs(output_dir, exist_ok=True)
|
343 |
+
output_file = os.path.join(
|
344 |
+
output_dir,
|
345 |
+
f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json",
|
346 |
+
)
|
347 |
if os.path.exists(output_file):
|
348 |
with open(output_file, "r", encoding="utf-8") as f:
|
349 |
try:
|
|
|
382 |
if brainstorm_mode == "mode_c":
|
383 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
384 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
385 |
+
entities_all = list(set(entities) | set(entities_bs))
|
386 |
else:
|
387 |
entities_bs = None
|
388 |
entities_all = entities
|
|
|
398 |
continue
|
399 |
# 3. 检索相关论文
|
400 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
401 |
+
retriever_name, config
|
|
|
402 |
)
|
403 |
result = rt.retrieve(
|
404 |
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
|
|
|
431 |
"hash_id": paper["hash_id"],
|
432 |
"background": bg,
|
433 |
"entities_bg": entities,
|
434 |
+
"brainstorm": brainstorm,
|
435 |
"entities_bs": entities_bs,
|
436 |
"entities_rt": entities_rt,
|
437 |
"related_paper": [p["hash_id"] for p in related_paper],
|
|
|
460 |
) as f:
|
461 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
462 |
|
463 |
+
|
464 |
@main.command()
|
465 |
@click.option(
|
466 |
"-c",
|
|
|
506 |
required=False,
|
507 |
help="The number of data you want to process",
|
508 |
)
|
509 |
+
def new_idea(
|
510 |
+
config_path,
|
511 |
+
ids_path,
|
512 |
+
retriever_name,
|
513 |
+
brainstorm_mode,
|
514 |
+
use_inspiration,
|
515 |
+
num,
|
516 |
+
**kwargs,
|
517 |
+
):
|
518 |
check_env()
|
|
|
519 |
logger.add(
|
520 |
"log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
|
521 |
) # 添加文件输出
|
|
|
523 |
# Configuration
|
524 |
config = ConfigReader.load(config_path, **kwargs)
|
525 |
api_helper = APIHelper(config)
|
526 |
+
check_embedding(config.DEFAULT.embedding)
|
527 |
eval_data = []
|
528 |
cur_num = 0
|
529 |
data_num = 0
|
|
|
531 |
bg_ids = set()
|
532 |
output_dir = "./assets/output_idea/"
|
533 |
os.makedirs(output_dir, exist_ok=True)
|
534 |
+
output_file = os.path.join(
|
535 |
+
output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json"
|
536 |
+
)
|
537 |
if os.path.exists(output_file):
|
538 |
with open(output_file, "r", encoding="utf-8") as f:
|
539 |
try:
|
|
|
542 |
cur_num = len(eval_data)
|
543 |
except json.JSONDecodeError:
|
544 |
eval_data = []
|
545 |
+
logger.debug(f"{cur_num} datas have been processed.")
|
546 |
for line in ids_path:
|
547 |
# 解析每行的JSON数据
|
548 |
data = json.loads(line)
|
|
|
572 |
if brainstorm_mode == "mode_c":
|
573 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
574 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
575 |
+
entities_all = list(set(entities) | set(entities_bs))
|
576 |
else:
|
577 |
entities_bs = None
|
578 |
entities_all = entities
|
579 |
# 2. 检索相关论文
|
580 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
581 |
+
retriever_name, config
|
582 |
+
)
|
583 |
+
result = rt.retrieve(
|
584 |
+
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
|
585 |
)
|
|
|
586 |
related_paper = result["related_paper"]
|
587 |
logger.info("Find {} related papers...".format(len(related_paper)))
|
588 |
entities_rt = result["entities"]
|
|
|
602 |
{
|
603 |
"background": bg,
|
604 |
"entities_bg": entities,
|
605 |
+
"brainstorm": brainstorm,
|
606 |
"entities_bs": entities_bs,
|
607 |
"entities_rt": entities_rt,
|
608 |
"related_paper": [p["hash_id"] for p in related_paper],
|
|
|
626 |
with open(output_file, "w", encoding="utf-8") as f:
|
627 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
628 |
|
629 |
+
|
630 |
if __name__ == "__main__":
|
631 |
main()
|
src/paper_manager.py
CHANGED
@@ -389,10 +389,8 @@ class PaperManager:
|
|
389 |
)
|
390 |
|
391 |
if need_summary:
|
392 |
-
print(paper.keys())
|
393 |
if not self.check_parse(paper):
|
394 |
logger.error(f"paper {paper['hash_id']} need parse first...")
|
395 |
-
|
396 |
result = self.api_helper(
|
397 |
paper["title"], paper["abstract"], paper["introduction"]
|
398 |
)
|
@@ -628,9 +626,11 @@ class PaperManager:
|
|
628 |
|
629 |
def insert_embedding(self, hash_id=None):
|
630 |
self.paper_client.add_paper_abstract_embedding(self.embedding_model, hash_id)
|
631 |
-
# self.
|
632 |
-
# self.
|
633 |
-
#
|
|
|
|
|
634 |
|
635 |
def cosine_similarity_search(self, data_type, context, k=1):
|
636 |
"""
|
@@ -837,8 +837,9 @@ def local(config_path, year, venue_name, output, **kwargs):
|
|
837 |
os.makedirs(os.path.dirname(output_path))
|
838 |
config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
|
839 |
pm = PaperManager(config, venue_name, year)
|
|
|
840 |
pm.update_paper_from_json_to_json(
|
841 |
-
need_download=True, need_parse=True, need_summary=True
|
842 |
)
|
843 |
|
844 |
|
|
|
389 |
)
|
390 |
|
391 |
if need_summary:
|
|
|
392 |
if not self.check_parse(paper):
|
393 |
logger.error(f"paper {paper['hash_id']} need parse first...")
|
|
|
394 |
result = self.api_helper(
|
395 |
paper["title"], paper["abstract"], paper["introduction"]
|
396 |
)
|
|
|
626 |
|
627 |
def insert_embedding(self, hash_id=None):
|
628 |
self.paper_client.add_paper_abstract_embedding(self.embedding_model, hash_id)
|
629 |
+
# self.paper_client.add_paper_bg_embedding(self.embedding_model, hash_id)
|
630 |
+
# self.paper_client.add_paper_contribution_embedding(
|
631 |
+
# self.embedding_model, hash_id
|
632 |
+
# )
|
633 |
+
# self.paper_client.add_paper_summary_embedding(self.embedding_model, hash_id)
|
634 |
|
635 |
def cosine_similarity_search(self, data_type, context, k=1):
|
636 |
"""
|
|
|
837 |
os.makedirs(os.path.dirname(output_path))
|
838 |
config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
|
839 |
pm = PaperManager(config, venue_name, year)
|
840 |
+
print("###")
|
841 |
pm.update_paper_from_json_to_json(
|
842 |
+
need_download=True, need_parse=True, need_summary=True
|
843 |
)
|
844 |
|
845 |
|
src/retriever.py
CHANGED
@@ -41,9 +41,9 @@ def main(ctx):
|
|
41 |
def retrieve(
|
42 |
config_path, ids_path, **kwargs
|
43 |
):
|
44 |
-
check_env()
|
45 |
-
check_embedding()
|
46 |
config = ConfigReader.load(config_path, **kwargs)
|
|
|
|
|
47 |
log_dir = config.DEFAULT.log_dir
|
48 |
retriever_name = config.RETRIEVE.retriever_name
|
49 |
cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
|
|
|
41 |
def retrieve(
|
42 |
config_path, ids_path, **kwargs
|
43 |
):
|
|
|
|
|
44 |
config = ConfigReader.load(config_path, **kwargs)
|
45 |
+
check_embedding(config.DEFAULT.embedding)
|
46 |
+
check_env()
|
47 |
log_dir = config.DEFAULT.log_dir
|
48 |
retriever_name = config.RETRIEVE.retriever_name
|
49 |
cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
|
src/utils/api/__init__.py
CHANGED
@@ -22,8 +22,10 @@ Creation Date : 2024-10-29
|
|
22 |
|
23 |
Author : Frank Kang([email protected])
|
24 |
"""
|
|
|
25 |
from .base_helper import HelperCompany
|
26 |
from .openai_helper import OpenAIHelper # noqa: F401, ensure autoregister
|
27 |
from .zhipuai_helper import ZhipuAIHelper # noqa: F401, ensure autoregister
|
|
|
28 |
|
29 |
__all__ = ["HelperCompany"]
|
|
|
22 |
|
23 |
Author : Frank Kang([email protected])
|
24 |
"""
|
25 |
+
|
26 |
from .base_helper import HelperCompany
|
27 |
from .openai_helper import OpenAIHelper # noqa: F401, ensure autoregister
|
28 |
from .zhipuai_helper import ZhipuAIHelper # noqa: F401, ensure autoregister
|
29 |
+
from .local_helper import LocalHelper # noqa: F401, ensure autoregister
|
30 |
|
31 |
__all__ = ["HelperCompany"]
|
src/utils/api/base_helper.py
CHANGED
@@ -17,6 +17,9 @@ from abc import ABCMeta
|
|
17 |
from typing_extensions import Literal, override
|
18 |
from ..base_company import BaseCompany
|
19 |
from typing import Union
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
class NotGiven:
|
@@ -109,6 +112,31 @@ class BaseHelper:
|
|
109 |
self.base_url = base_url
|
110 |
self.client = None
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
def create(
|
113 |
self,
|
114 |
*args,
|
@@ -124,7 +152,7 @@ class BaseHelper:
|
|
124 |
extra_headers: None | NotGiven = None,
|
125 |
extra_body: None | NotGiven = None,
|
126 |
timeout: float | None | NotGiven = None,
|
127 |
-
**kwargs
|
128 |
):
|
129 |
"""
|
130 |
Creates a model response for the given chat conversation.
|
@@ -187,20 +215,44 @@ class BaseHelper:
|
|
187 |
|
188 |
timeout: Override the client-level default timeout for this request, in seconds
|
189 |
"""
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
from typing_extensions import Literal, override
|
18 |
from ..base_company import BaseCompany
|
19 |
from typing import Union
|
20 |
+
import requests
|
21 |
+
import json
|
22 |
+
from requests.exceptions import RequestException
|
23 |
|
24 |
|
25 |
class NotGiven:
|
|
|
112 |
self.base_url = base_url
|
113 |
self.client = None
|
114 |
|
115 |
+
def apply_for_service(self, data_param, max_attempts=4):
|
116 |
+
attempt = 0
|
117 |
+
while attempt < max_attempts:
|
118 |
+
try:
|
119 |
+
# print(f"尝试 #{attempt + 1}")
|
120 |
+
r = requests.post(
|
121 |
+
self.base_url + "/llm/generate",
|
122 |
+
headers={"Content-Type": "application/json"},
|
123 |
+
data=json.dumps(data_param),
|
124 |
+
)
|
125 |
+
# 检查请求是否成功
|
126 |
+
if r.status_code == 200:
|
127 |
+
# print("服务请求成功。")
|
128 |
+
response = r.json()["data"]["output"]
|
129 |
+
return response # 或者根据需要返回其他内容
|
130 |
+
else:
|
131 |
+
print("服务请求失败,响应状态码:", response.status_code)
|
132 |
+
except RequestException as e:
|
133 |
+
print("请求发生错误:", e)
|
134 |
+
|
135 |
+
attempt += 1
|
136 |
+
if attempt == max_attempts:
|
137 |
+
print("达到最大尝试次数,服务请求失败。")
|
138 |
+
return None # 或者根据你的情况抛出异常
|
139 |
+
|
140 |
def create(
|
141 |
self,
|
142 |
*args,
|
|
|
152 |
extra_headers: None | NotGiven = None,
|
153 |
extra_body: None | NotGiven = None,
|
154 |
timeout: float | None | NotGiven = None,
|
155 |
+
**kwargs,
|
156 |
):
|
157 |
"""
|
158 |
Creates a model response for the given chat conversation.
|
|
|
215 |
|
216 |
timeout: Override the client-level default timeout for this request, in seconds
|
217 |
"""
|
218 |
+
if self.model != "local":
|
219 |
+
return (
|
220 |
+
self.client.chat.completions.create(
|
221 |
+
*args,
|
222 |
+
model=self.model,
|
223 |
+
messages=messages,
|
224 |
+
stream=stream,
|
225 |
+
temperature=temperature,
|
226 |
+
top_p=top_p,
|
227 |
+
max_tokens=max_tokens,
|
228 |
+
seed=seed,
|
229 |
+
stop=stop,
|
230 |
+
tools=tools,
|
231 |
+
tool_choice=tool_choice,
|
232 |
+
extra_headers=extra_headers,
|
233 |
+
extra_body=extra_body,
|
234 |
+
timeout=timeout,
|
235 |
+
**kwargs,
|
236 |
+
)
|
237 |
+
.choices[0]
|
238 |
+
.message.content
|
239 |
+
)
|
240 |
+
else:
|
241 |
+
default_system = "You are a helpful assistant."
|
242 |
+
input_content = ""
|
243 |
+
for message in messages:
|
244 |
+
if message["role"] == "system":
|
245 |
+
default_system = message["content"]
|
246 |
+
else:
|
247 |
+
input_content += message["content"]
|
248 |
+
data_param = {}
|
249 |
+
data_param["input"] = input_content
|
250 |
+
data_param["serviceParams"] = {"stream": False, "system": default_system}
|
251 |
+
data_param["ModelParams"] = {
|
252 |
+
"temperature": 0.8,
|
253 |
+
"presence_penalty": 2.0,
|
254 |
+
"frequency_penalty": 0.0,
|
255 |
+
"top_p": 0.8,
|
256 |
+
}
|
257 |
+
response = self.apply_for_service(data_param)
|
258 |
+
return response
|
src/utils/api/local_helper.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""_summary_
|
2 |
+
-*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
Module : data.utils.api.zhipuai_helper
|
5 |
+
|
6 |
+
File Name : zhipuai_helper.py
|
7 |
+
|
8 |
+
Description : Helper class for ZhipuAI interface, generally not used directly.
|
9 |
+
For example:
|
10 |
+
```
|
11 |
+
from data.utils.api import HelperCompany
|
12 |
+
helper = HelperCompany.get()['ZhipuAI']
|
13 |
+
...
|
14 |
+
```
|
15 |
+
|
16 |
+
Creation Date : 2024-11-28
|
17 |
+
|
18 |
+
Author : lihuigu([email protected])
|
19 |
+
"""
|
20 |
+
|
21 |
+
from .base_helper import register_helper, BaseHelper
|
22 |
+
|
23 |
+
|
24 |
+
@register_helper("Local")
|
25 |
+
class LocalHelper(BaseHelper):
|
26 |
+
"""_summary_
|
27 |
+
|
28 |
+
Helper class for ZhipuAI interface, generally not used directly.
|
29 |
+
|
30 |
+
For example:
|
31 |
+
```
|
32 |
+
from data.utils.api import HelperCompany
|
33 |
+
helper = HelperCompany.get()['Local']
|
34 |
+
...
|
35 |
+
```
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, api_key, model, base_url=None, timeout=None):
|
39 |
+
super().__init__(api_key, model, base_url)
|
src/utils/hash.py
CHANGED
@@ -12,18 +12,35 @@ ENV_CHECKED = False
|
|
12 |
EMBEDDING_CHECKED = False
|
13 |
|
14 |
|
15 |
-
def check_embedding():
|
|
|
16 |
global EMBEDDING_CHECKED
|
17 |
if not EMBEDDING_CHECKED:
|
18 |
# Define the repository and files to download
|
19 |
-
repo_id = "sentence-transformers/all-MiniLM-L6-v2" # "BAAI/bge-small-en-v1.5"
|
20 |
local_dir = f"./assets/model/{repo_id}"
|
21 |
-
|
22 |
-
"
|
23 |
-
"
|
24 |
-
"
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
# Download each file and save it to the /model/bge directory
|
28 |
for file_name in files_to_download:
|
29 |
if not os.path.exists(os.path.join(local_dir, file_name)):
|
@@ -47,12 +64,15 @@ def check_env():
|
|
47 |
"NEO4J_PASSWD",
|
48 |
"MODEL_NAME",
|
49 |
"MODEL_TYPE",
|
50 |
-
"MODEL_API_KEY",
|
51 |
"BASE_URL",
|
52 |
]
|
53 |
for env_name in env_name_list:
|
54 |
if env_name not in os.environ or os.environ[env_name] == "":
|
55 |
raise ValueError(f"{env_name} is not set...")
|
|
|
|
|
|
|
|
|
56 |
ENV_CHECKED = True
|
57 |
|
58 |
|
@@ -61,16 +81,21 @@ class EmbeddingModel:
|
|
61 |
|
62 |
def __new__(cls, config):
|
63 |
if cls._instance is None:
|
|
|
64 |
cls._instance = super(EmbeddingModel, cls).__new__(cls)
|
65 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
66 |
cls._instance.embedding_model = SentenceTransformer(
|
67 |
-
model_name_or_path=get_dir(
|
68 |
device=device,
|
|
|
69 |
)
|
70 |
print(f"==== using device {device} ====")
|
71 |
return cls._instance
|
72 |
|
|
|
73 |
def get_embedding_model(config):
|
|
|
|
|
74 |
return EmbeddingModel(config).embedding_model
|
75 |
|
76 |
|
|
|
12 |
EMBEDDING_CHECKED = False
|
13 |
|
14 |
|
15 |
+
def check_embedding(repo_id):
|
16 |
+
print("=== check embedding model ===")
|
17 |
global EMBEDDING_CHECKED
|
18 |
if not EMBEDDING_CHECKED:
|
19 |
# Define the repository and files to download
|
|
|
20 |
local_dir = f"./assets/model/{repo_id}"
|
21 |
+
if repo_id in [
|
22 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
23 |
+
"BAAI/bge-small-en-v1.5",
|
24 |
+
"BAAAI/llm_embedder",
|
25 |
+
]:
|
26 |
+
# repo_id = "sentence-transformers/all-MiniLM-L6-v2"
|
27 |
+
# repo_id = "BAAI/bge-small-en-v1.5"
|
28 |
+
files_to_download = [
|
29 |
+
"config.json",
|
30 |
+
"pytorch_model.bin",
|
31 |
+
"tokenizer_config.json",
|
32 |
+
"vocab.txt",
|
33 |
+
]
|
34 |
+
elif repo_id in ["Alibaba-NLP/gte-base-en-v1.5"]:
|
35 |
+
files_to_download = [
|
36 |
+
"config.json",
|
37 |
+
"model.safetensors",
|
38 |
+
"modules.json",
|
39 |
+
"tokenizer.json",
|
40 |
+
"sentence_bert_config.json",
|
41 |
+
"tokenizer_config.json",
|
42 |
+
"vocab.txt",
|
43 |
+
]
|
44 |
# Download each file and save it to the /model/bge directory
|
45 |
for file_name in files_to_download:
|
46 |
if not os.path.exists(os.path.join(local_dir, file_name)):
|
|
|
64 |
"NEO4J_PASSWD",
|
65 |
"MODEL_NAME",
|
66 |
"MODEL_TYPE",
|
|
|
67 |
"BASE_URL",
|
68 |
]
|
69 |
for env_name in env_name_list:
|
70 |
if env_name not in os.environ or os.environ[env_name] == "":
|
71 |
raise ValueError(f"{env_name} is not set...")
|
72 |
+
if os.environ["MODEL_TYPE"] != "Local":
|
73 |
+
env_name = "MODEL_API_KEY"
|
74 |
+
if env_name not in os.environ or os.environ[env_name] == "":
|
75 |
+
raise ValueError(f"{env_name} is not set...")
|
76 |
ENV_CHECKED = True
|
77 |
|
78 |
|
|
|
81 |
|
82 |
def __new__(cls, config):
|
83 |
if cls._instance is None:
|
84 |
+
local_dir = f"./assets/model/{config.DEFAULT.embedding}"
|
85 |
cls._instance = super(EmbeddingModel, cls).__new__(cls)
|
86 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
87 |
cls._instance.embedding_model = SentenceTransformer(
|
88 |
+
model_name_or_path=get_dir(local_dir),
|
89 |
device=device,
|
90 |
+
trust_remote_code=True,
|
91 |
)
|
92 |
print(f"==== using device {device} ====")
|
93 |
return cls._instance
|
94 |
|
95 |
+
|
96 |
def get_embedding_model(config):
|
97 |
+
print("=== get embedding model ===")
|
98 |
+
check_embedding(config.DEFAULT.embedding)
|
99 |
return EmbeddingModel(config).embedding_model
|
100 |
|
101 |
|
src/utils/llms_api.py
CHANGED
@@ -49,7 +49,10 @@ class APIHelper(object):
|
|
49 |
def get_helper(self):
|
50 |
MODEL_TYPE = os.environ["MODEL_TYPE"]
|
51 |
MODEL_NAME = os.environ["MODEL_NAME"]
|
52 |
-
|
|
|
|
|
|
|
53 |
BASE_URL = os.environ["BASE_URL"]
|
54 |
return HelperCompany.get()[MODEL_TYPE](
|
55 |
MODEL_API_KEY, MODEL_NAME, BASE_URL, timeout=None
|
@@ -64,6 +67,8 @@ class APIHelper(object):
|
|
64 |
"glm4-air",
|
65 |
"qwen-max",
|
66 |
"qwen-plus",
|
|
|
|
|
67 |
]:
|
68 |
raise ValueError(f"Check model name...")
|
69 |
|
@@ -78,13 +83,13 @@ class APIHelper(object):
|
|
78 |
response1 = self.generator.create(
|
79 |
messages=message,
|
80 |
)
|
81 |
-
summary = clean_text(response1
|
82 |
message.append({"role": "assistant", "content": summary})
|
83 |
message.append(self.prompt.queries[1][0]())
|
84 |
response2 = self.generator.create(
|
85 |
messages=message,
|
86 |
)
|
87 |
-
detail = response2
|
88 |
motivation = clean_text(detail.split(TAG_moti)[1].split(TAG_contr)[0])
|
89 |
contribution = clean_text(detail.split(TAG_contr)[1])
|
90 |
result = {
|
@@ -116,7 +121,7 @@ class APIHelper(object):
|
|
116 |
response = self.generator.create(
|
117 |
messages=message,
|
118 |
)
|
119 |
-
entities = response
|
120 |
entity_list = entities.strip().split(", ")
|
121 |
clean_entity_list = []
|
122 |
for entity in entity_list:
|
@@ -151,7 +156,7 @@ class APIHelper(object):
|
|
151 |
response_brainstorming = self.generator.create(
|
152 |
messages=message,
|
153 |
)
|
154 |
-
brainstorming_ideas = response_brainstorming
|
155 |
|
156 |
except Exception:
|
157 |
traceback.print_exc()
|
@@ -178,7 +183,7 @@ class APIHelper(object):
|
|
178 |
response = self.generator.create(
|
179 |
messages=message,
|
180 |
)
|
181 |
-
problem = response
|
182 |
except Exception:
|
183 |
traceback.print_exc()
|
184 |
return None
|
@@ -207,7 +212,7 @@ class APIHelper(object):
|
|
207 |
response = self.generator.create(
|
208 |
messages=message,
|
209 |
)
|
210 |
-
problem = response
|
211 |
except Exception:
|
212 |
traceback.print_exc()
|
213 |
return None
|
@@ -228,7 +233,7 @@ class APIHelper(object):
|
|
228 |
response = self.generator.create(
|
229 |
messages=message,
|
230 |
)
|
231 |
-
inspiration = response
|
232 |
except Exception:
|
233 |
traceback.print_exc()
|
234 |
return None
|
@@ -254,7 +259,7 @@ class APIHelper(object):
|
|
254 |
response = self.generator.create(
|
255 |
messages=message,
|
256 |
)
|
257 |
-
inspiration = response
|
258 |
except Exception:
|
259 |
traceback.print_exc()
|
260 |
return None
|
@@ -282,7 +287,7 @@ class APIHelper(object):
|
|
282 |
response = self.generator.create(
|
283 |
messages=message,
|
284 |
)
|
285 |
-
idea = response
|
286 |
except Exception:
|
287 |
traceback.print_exc()
|
288 |
return None
|
@@ -314,7 +319,7 @@ class APIHelper(object):
|
|
314 |
response = self.generator.create(
|
315 |
messages=message,
|
316 |
)
|
317 |
-
idea = response
|
318 |
except Exception:
|
319 |
traceback.print_exc()
|
320 |
return None
|
@@ -340,7 +345,7 @@ class APIHelper(object):
|
|
340 |
response = self.generator.create(
|
341 |
messages=message,
|
342 |
)
|
343 |
-
idea = response
|
344 |
except Exception:
|
345 |
traceback.print_exc()
|
346 |
return None
|
@@ -372,7 +377,7 @@ class APIHelper(object):
|
|
372 |
response = self.generator.create(
|
373 |
messages=message,
|
374 |
)
|
375 |
-
idea = response
|
376 |
except Exception:
|
377 |
traceback.print_exc()
|
378 |
return None
|
@@ -391,7 +396,7 @@ class APIHelper(object):
|
|
391 |
response = self.generator.create(
|
392 |
messages=message,
|
393 |
)
|
394 |
-
idea = response
|
395 |
except Exception:
|
396 |
traceback.print_exc()
|
397 |
return None
|
@@ -413,7 +418,7 @@ class APIHelper(object):
|
|
413 |
response = self.generator.create(
|
414 |
messages=message,
|
415 |
)
|
416 |
-
idea_filtered = response
|
417 |
except Exception:
|
418 |
traceback.print_exc()
|
419 |
return None
|
@@ -435,7 +440,7 @@ class APIHelper(object):
|
|
435 |
response = self.generator.create(
|
436 |
messages=message,
|
437 |
)
|
438 |
-
idea_modified = response
|
439 |
except Exception:
|
440 |
traceback.print_exc()
|
441 |
return None
|
@@ -454,7 +459,7 @@ class APIHelper(object):
|
|
454 |
response = self.generator.create(
|
455 |
messages=message,
|
456 |
)
|
457 |
-
ground_truth = response
|
458 |
except Exception:
|
459 |
traceback.print_exc()
|
460 |
return ground_truth
|
@@ -469,7 +474,7 @@ class APIHelper(object):
|
|
469 |
response = self.generator.create(
|
470 |
messages=message,
|
471 |
)
|
472 |
-
idea_norm = response
|
473 |
except Exception:
|
474 |
traceback.print_exc()
|
475 |
return None
|
@@ -492,7 +497,7 @@ class APIHelper(object):
|
|
492 |
messages=message,
|
493 |
max_tokens=10,
|
494 |
)
|
495 |
-
index = response
|
496 |
except Exception:
|
497 |
traceback.print_exc()
|
498 |
return None
|
@@ -509,7 +514,7 @@ class APIHelper(object):
|
|
509 |
messages=message,
|
510 |
max_tokens=10,
|
511 |
)
|
512 |
-
score = response
|
513 |
except Exception:
|
514 |
traceback.print_exc()
|
515 |
return None
|
@@ -548,7 +553,7 @@ class APIHelper(object):
|
|
548 |
stop=None,
|
549 |
seed=0,
|
550 |
)
|
551 |
-
content = response
|
552 |
new_msg_history = new_msg_history + [
|
553 |
{"role": "assistant", "content": content}
|
554 |
]
|
@@ -577,7 +582,7 @@ class APIHelper(object):
|
|
577 |
response = self.generator.create(
|
578 |
messages=message,
|
579 |
)
|
580 |
-
result = response
|
581 |
except Exception:
|
582 |
traceback.print_exc()
|
583 |
return None
|
@@ -601,7 +606,7 @@ class APIHelper(object):
|
|
601 |
response = self.generator.create(
|
602 |
messages=message,
|
603 |
)
|
604 |
-
result = response
|
605 |
except Exception:
|
606 |
traceback.print_exc()
|
607 |
return None
|
@@ -625,7 +630,7 @@ class APIHelper(object):
|
|
625 |
response = self.generator.create(
|
626 |
messages=message,
|
627 |
)
|
628 |
-
result = response
|
629 |
except Exception:
|
630 |
traceback.print_exc()
|
631 |
return None
|
@@ -649,7 +654,7 @@ class APIHelper(object):
|
|
649 |
response = self.generator.create(
|
650 |
messages=message,
|
651 |
)
|
652 |
-
result = response
|
653 |
except Exception:
|
654 |
traceback.print_exc()
|
655 |
return None
|
@@ -673,7 +678,7 @@ class APIHelper(object):
|
|
673 |
response = self.generator.create(
|
674 |
messages=message,
|
675 |
)
|
676 |
-
result = response
|
677 |
except Exception:
|
678 |
traceback.print_exc()
|
679 |
return None
|
|
|
49 |
def get_helper(self):
|
50 |
MODEL_TYPE = os.environ["MODEL_TYPE"]
|
51 |
MODEL_NAME = os.environ["MODEL_NAME"]
|
52 |
+
if MODEL_NAME != "local":
|
53 |
+
MODEL_API_KEY = os.environ["MODEL_API_KEY"]
|
54 |
+
else:
|
55 |
+
MODEL_API_KEY = ""
|
56 |
BASE_URL = os.environ["BASE_URL"]
|
57 |
return HelperCompany.get()[MODEL_TYPE](
|
58 |
MODEL_API_KEY, MODEL_NAME, BASE_URL, timeout=None
|
|
|
67 |
"glm4-air",
|
68 |
"qwen-max",
|
69 |
"qwen-plus",
|
70 |
+
"gpt-4o-mini",
|
71 |
+
"local",
|
72 |
]:
|
73 |
raise ValueError(f"Check model name...")
|
74 |
|
|
|
83 |
response1 = self.generator.create(
|
84 |
messages=message,
|
85 |
)
|
86 |
+
summary = clean_text(response1)
|
87 |
message.append({"role": "assistant", "content": summary})
|
88 |
message.append(self.prompt.queries[1][0]())
|
89 |
response2 = self.generator.create(
|
90 |
messages=message,
|
91 |
)
|
92 |
+
detail = response2
|
93 |
motivation = clean_text(detail.split(TAG_moti)[1].split(TAG_contr)[0])
|
94 |
contribution = clean_text(detail.split(TAG_contr)[1])
|
95 |
result = {
|
|
|
121 |
response = self.generator.create(
|
122 |
messages=message,
|
123 |
)
|
124 |
+
entities = response
|
125 |
entity_list = entities.strip().split(", ")
|
126 |
clean_entity_list = []
|
127 |
for entity in entity_list:
|
|
|
156 |
response_brainstorming = self.generator.create(
|
157 |
messages=message,
|
158 |
)
|
159 |
+
brainstorming_ideas = response_brainstorming
|
160 |
|
161 |
except Exception:
|
162 |
traceback.print_exc()
|
|
|
183 |
response = self.generator.create(
|
184 |
messages=message,
|
185 |
)
|
186 |
+
problem = response
|
187 |
except Exception:
|
188 |
traceback.print_exc()
|
189 |
return None
|
|
|
212 |
response = self.generator.create(
|
213 |
messages=message,
|
214 |
)
|
215 |
+
problem = response
|
216 |
except Exception:
|
217 |
traceback.print_exc()
|
218 |
return None
|
|
|
233 |
response = self.generator.create(
|
234 |
messages=message,
|
235 |
)
|
236 |
+
inspiration = response
|
237 |
except Exception:
|
238 |
traceback.print_exc()
|
239 |
return None
|
|
|
259 |
response = self.generator.create(
|
260 |
messages=message,
|
261 |
)
|
262 |
+
inspiration = response
|
263 |
except Exception:
|
264 |
traceback.print_exc()
|
265 |
return None
|
|
|
287 |
response = self.generator.create(
|
288 |
messages=message,
|
289 |
)
|
290 |
+
idea = response
|
291 |
except Exception:
|
292 |
traceback.print_exc()
|
293 |
return None
|
|
|
319 |
response = self.generator.create(
|
320 |
messages=message,
|
321 |
)
|
322 |
+
idea = response
|
323 |
except Exception:
|
324 |
traceback.print_exc()
|
325 |
return None
|
|
|
345 |
response = self.generator.create(
|
346 |
messages=message,
|
347 |
)
|
348 |
+
idea = response
|
349 |
except Exception:
|
350 |
traceback.print_exc()
|
351 |
return None
|
|
|
377 |
response = self.generator.create(
|
378 |
messages=message,
|
379 |
)
|
380 |
+
idea = response
|
381 |
except Exception:
|
382 |
traceback.print_exc()
|
383 |
return None
|
|
|
396 |
response = self.generator.create(
|
397 |
messages=message,
|
398 |
)
|
399 |
+
idea = response
|
400 |
except Exception:
|
401 |
traceback.print_exc()
|
402 |
return None
|
|
|
418 |
response = self.generator.create(
|
419 |
messages=message,
|
420 |
)
|
421 |
+
idea_filtered = response
|
422 |
except Exception:
|
423 |
traceback.print_exc()
|
424 |
return None
|
|
|
440 |
response = self.generator.create(
|
441 |
messages=message,
|
442 |
)
|
443 |
+
idea_modified = response
|
444 |
except Exception:
|
445 |
traceback.print_exc()
|
446 |
return None
|
|
|
459 |
response = self.generator.create(
|
460 |
messages=message,
|
461 |
)
|
462 |
+
ground_truth = response
|
463 |
except Exception:
|
464 |
traceback.print_exc()
|
465 |
return ground_truth
|
|
|
474 |
response = self.generator.create(
|
475 |
messages=message,
|
476 |
)
|
477 |
+
idea_norm = response
|
478 |
except Exception:
|
479 |
traceback.print_exc()
|
480 |
return None
|
|
|
497 |
messages=message,
|
498 |
max_tokens=10,
|
499 |
)
|
500 |
+
index = response
|
501 |
except Exception:
|
502 |
traceback.print_exc()
|
503 |
return None
|
|
|
514 |
messages=message,
|
515 |
max_tokens=10,
|
516 |
)
|
517 |
+
score = response
|
518 |
except Exception:
|
519 |
traceback.print_exc()
|
520 |
return None
|
|
|
553 |
stop=None,
|
554 |
seed=0,
|
555 |
)
|
556 |
+
content = response
|
557 |
new_msg_history = new_msg_history + [
|
558 |
{"role": "assistant", "content": content}
|
559 |
]
|
|
|
582 |
response = self.generator.create(
|
583 |
messages=message,
|
584 |
)
|
585 |
+
result = response
|
586 |
except Exception:
|
587 |
traceback.print_exc()
|
588 |
return None
|
|
|
606 |
response = self.generator.create(
|
607 |
messages=message,
|
608 |
)
|
609 |
+
result = response
|
610 |
except Exception:
|
611 |
traceback.print_exc()
|
612 |
return None
|
|
|
630 |
response = self.generator.create(
|
631 |
messages=message,
|
632 |
)
|
633 |
+
result = response
|
634 |
except Exception:
|
635 |
traceback.print_exc()
|
636 |
return None
|
|
|
654 |
response = self.generator.create(
|
655 |
messages=message,
|
656 |
)
|
657 |
+
result = response
|
658 |
except Exception:
|
659 |
traceback.print_exc()
|
660 |
return None
|
|
|
678 |
response = self.generator.create(
|
679 |
messages=message,
|
680 |
)
|
681 |
+
result = response
|
682 |
except Exception:
|
683 |
traceback.print_exc()
|
684 |
return None
|
src/utils/paper_client.py
CHANGED
@@ -8,6 +8,7 @@ from collections import defaultdict, deque
|
|
8 |
from py2neo import Graph, Node, Relationship
|
9 |
from loguru import logger
|
10 |
|
|
|
11 |
class PaperClient:
|
12 |
_instance = None
|
13 |
_initialized = False
|
@@ -43,10 +44,28 @@ class PaperClient:
|
|
43 |
with self.driver.session() as session:
|
44 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
45 |
if result:
|
46 |
-
paper_from_client = result[0][
|
47 |
if paper_from_client is not None:
|
48 |
paper.update(paper_from_client)
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def get_paper_attribute(self, paper_id, attribute_name):
|
51 |
query = f"""
|
52 |
MATCH (p:Paper {{hash_id: {paper_id}}})
|
@@ -55,11 +74,11 @@ class PaperClient:
|
|
55 |
with self.driver.session() as session:
|
56 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
57 |
if result:
|
58 |
-
return result[0][
|
59 |
else:
|
60 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
61 |
return None
|
62 |
-
|
63 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
64 |
query = f"""
|
65 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
@@ -68,7 +87,7 @@ class PaperClient:
|
|
68 |
with self.driver.session() as session:
|
69 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
70 |
if result:
|
71 |
-
return result[0][
|
72 |
else:
|
73 |
return None
|
74 |
|
@@ -81,71 +100,50 @@ class PaperClient:
|
|
81 |
RETURN p.hash_id as hash_id
|
82 |
"""
|
83 |
with self.driver.session() as session:
|
84 |
-
result = session.execute_read(
|
|
|
|
|
85 |
if result:
|
86 |
-
return [record[
|
87 |
else:
|
88 |
return []
|
89 |
-
|
90 |
-
def
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
visited = set([entity_name])
|
95 |
-
related_entities = set()
|
96 |
-
|
97 |
-
while queue:
|
98 |
-
batch_queue = [queue.popleft() for _ in range(len(queue))]
|
99 |
-
batch_entities = [item[0] for item in batch_queue]
|
100 |
-
batch_depths = [item[1] for item in batch_queue]
|
101 |
-
|
102 |
-
if all(depth >= n for depth in batch_depths):
|
103 |
-
continue
|
104 |
-
if relation_name == "related":
|
105 |
-
query = """
|
106 |
-
UNWIND $batch_entities AS entity_name
|
107 |
-
MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
|
108 |
-
WHERE e1 <> e2
|
109 |
-
WITH e1, e2, COUNT(p) AS common_papers, entity_name
|
110 |
-
WHERE common_papers > $k
|
111 |
-
RETURN e2.name AS entities, entity_name AS source_entity, common_papers
|
112 |
-
"""
|
113 |
-
elif relation_name == "connect":
|
114 |
-
query = """
|
115 |
-
UNWIND $batch_entities AS entity_name
|
116 |
-
MATCH (e1:Entity {name: entity_name})-[r:CONNECT]-(e2:Entity)
|
117 |
-
WHERE e1 <> e2 and r.strength >= $k
|
118 |
-
WITH e1, e2, entity_name
|
119 |
-
RETURN e2.name AS entities, entity_name AS source_entity
|
120 |
-
"""
|
121 |
-
with self.driver.session() as session:
|
122 |
-
result = session.execute_read(lambda tx: tx.run(query, batch_entities=batch_entities, k=k).data())
|
123 |
-
|
124 |
-
for record in result:
|
125 |
-
entity = record['entities']
|
126 |
-
source_entity = record['source_entity']
|
127 |
-
if entity not in visited:
|
128 |
-
visited.add(entity)
|
129 |
-
queue.append((entity, batch_depths[batch_entities.index(source_entity)] + 1))
|
130 |
-
related_entities.add(entity)
|
131 |
-
|
132 |
-
return list(related_entities)
|
133 |
-
related_entities = bfs_query(entity_name, n, k)
|
134 |
-
if entity_name in related_entities:
|
135 |
-
related_entities.remove(entity_name)
|
136 |
-
return related_entities
|
137 |
-
|
138 |
-
def find_entities_by_paper(self, hash_id: int):
|
139 |
query = """
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
142 |
"""
|
143 |
with self.driver.session() as session:
|
144 |
-
result = session.execute_read(
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
def find_paper_by_entity(self, entity_name):
|
151 |
query = """
|
@@ -153,18 +151,19 @@ class PaperClient:
|
|
153 |
RETURN p.hash_id AS hash_id
|
154 |
"""
|
155 |
with self.driver.session() as session:
|
156 |
-
result = session.execute_read(
|
|
|
|
|
157 |
if result:
|
158 |
-
return [record[
|
159 |
else:
|
160 |
return []
|
161 |
-
|
162 |
# TODO: @云翔
|
163 |
# 增加通过entity返回包含entity语句的功能
|
164 |
def find_sentence_by_entity(self, entity_name):
|
165 |
# Return: list(str)
|
166 |
return []
|
167 |
-
|
168 |
|
169 |
def find_sentences_by_entity(self, entity_name):
|
170 |
query = """
|
@@ -178,14 +177,25 @@ class PaperClient:
|
|
178 |
p.hash_id AS hash_id
|
179 |
"""
|
180 |
sentences = []
|
181 |
-
|
182 |
with self.driver.session() as session:
|
183 |
-
result = session.execute_read(
|
|
|
|
|
184 |
for record in result:
|
185 |
-
for key in [
|
186 |
if record[key]:
|
187 |
-
filtered_sentences = [
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
return sentences
|
191 |
|
@@ -194,9 +204,11 @@ class PaperClient:
|
|
194 |
MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
|
195 |
"""
|
196 |
with self.driver.session() as session:
|
197 |
-
result = session.execute_read(
|
|
|
|
|
198 |
if result:
|
199 |
-
return [record[
|
200 |
else:
|
201 |
return []
|
202 |
|
@@ -230,7 +242,26 @@ class PaperClient:
|
|
230 |
RETURN p
|
231 |
"""
|
232 |
with self.driver.session() as session:
|
233 |
-
result = session.execute_write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
def check_entity_node_count(self, hash_id: int):
|
236 |
query_check_count = """
|
@@ -239,7 +270,9 @@ class PaperClient:
|
|
239 |
"""
|
240 |
with self.driver.session() as session:
|
241 |
# Check the number of related entities
|
242 |
-
result = session.execute_read(
|
|
|
|
|
243 |
if result[0]["entity_count"] > 3:
|
244 |
return False
|
245 |
return True
|
@@ -254,16 +287,30 @@ class PaperClient:
|
|
254 |
"""
|
255 |
with self.driver.session() as session:
|
256 |
for entity_name in entities:
|
257 |
-
result = session.execute_write(
|
258 |
-
|
|
|
|
|
|
|
|
|
259 |
def add_paper_citation(self, paper: dict):
|
260 |
query = """
|
261 |
MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
|
262 |
"""
|
263 |
with self.driver.session() as session:
|
264 |
-
result = session.execute_write(
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
if hash_id is not None:
|
268 |
query = """
|
269 |
MATCH (p:Paper {hash_id: $hash_id})
|
@@ -271,119 +318,302 @@ class PaperClient:
|
|
271 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
272 |
"""
|
273 |
with self.driver.session() as session:
|
274 |
-
results = session.execute_write(
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
MATCH (p:Paper)
|
278 |
WHERE p.abstract IS NOT NULL
|
279 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
|
|
280 |
"""
|
281 |
with self.driver.session() as session:
|
282 |
-
results = session.execute_write(
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
with self.driver.session() as session:
|
294 |
-
|
|
|
|
|
|
|
|
|
295 |
|
296 |
-
def add_paper_bg_embedding(self, embedding_model, hash_id=None):
|
297 |
if hash_id is not None:
|
298 |
query = """
|
299 |
MATCH (p:Paper {hash_id: $hash_id})
|
300 |
WHERE p.motivation IS NOT NULL
|
301 |
-
RETURN p.motivation AS context, p.hash_id AS hash_id
|
302 |
"""
|
303 |
with self.driver.session() as session:
|
304 |
-
results = session.execute_write(
|
305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
MATCH (p:Paper)
|
308 |
WHERE p.motivation IS NOT NULL
|
309 |
-
RETURN p.motivation AS context, p.hash_id AS hash_id
|
|
|
310 |
"""
|
311 |
with self.driver.session() as session:
|
312 |
-
results = session.execute_write(
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
with self.driver.session() as session:
|
324 |
-
|
325 |
-
|
326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
if hash_id is not None:
|
328 |
query = """
|
329 |
MATCH (p:Paper {hash_id: $hash_id})
|
330 |
WHERE p.contribution IS NOT NULL
|
331 |
-
RETURN p.contribution AS context, p.hash_id AS hash_id
|
332 |
"""
|
333 |
with self.driver.session() as session:
|
334 |
-
results = session.execute_write(
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
MATCH (p:Paper)
|
338 |
WHERE p.contribution IS NOT NULL
|
339 |
-
RETURN p.contribution AS context, p.hash_id AS hash_id
|
|
|
340 |
"""
|
341 |
with self.driver.session() as session:
|
342 |
-
results = session.execute_write(
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
with self.driver.session() as session:
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
|
|
|
|
|
|
|
|
|
|
358 |
if hash_id is not None:
|
359 |
query = """
|
360 |
MATCH (p:Paper {hash_id: $hash_id})
|
361 |
WHERE p.summary IS NOT NULL
|
362 |
-
RETURN p.summary AS context, p.hash_id AS hash_id
|
363 |
"""
|
364 |
with self.driver.session() as session:
|
365 |
-
results = session.execute_write(
|
366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
MATCH (p:Paper)
|
369 |
WHERE p.summary IS NOT NULL
|
370 |
-
RETURN p.summary AS context, p.hash_id AS hash_id
|
|
|
371 |
"""
|
372 |
with self.driver.session() as session:
|
373 |
-
results = session.execute_write(
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
with self.driver.session() as session:
|
385 |
-
|
386 |
-
|
|
|
|
|
|
|
|
|
387 |
def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
|
388 |
query = f"""
|
389 |
MATCH (paper:Paper)
|
@@ -394,8 +624,10 @@ class PaperClient:
|
|
394 |
ORDER BY score DESC LIMIT {k}
|
395 |
"""
|
396 |
with self.driver.session() as session:
|
397 |
-
results = session.execute_read(
|
398 |
-
|
|
|
|
|
399 |
for result in results:
|
400 |
related_paper.append(result["paper"]["hash_id"])
|
401 |
return related_paper
|
@@ -417,7 +649,7 @@ class PaperClient:
|
|
417 |
"""
|
418 |
with self.driver.session() as session:
|
419 |
session.execute_write(lambda tx: tx.run(query).data())
|
420 |
-
|
421 |
def filter_paper_id_list(self, paper_id_list, year="2024"):
|
422 |
if not paper_id_list:
|
423 |
return []
|
@@ -429,12 +661,14 @@ class PaperClient:
|
|
429 |
RETURN p.hash_id AS hash_id
|
430 |
"""
|
431 |
with self.driver.session() as session:
|
432 |
-
result = session.execute_read(
|
433 |
-
|
434 |
-
|
|
|
|
|
435 |
existing_paper_ids = list(set(existing_paper_ids))
|
436 |
return existing_paper_ids
|
437 |
-
|
438 |
def check_index_exists(self):
|
439 |
query = "SHOW INDEXES"
|
440 |
with self.driver.session() as session:
|
@@ -451,7 +685,7 @@ class PaperClient:
|
|
451 |
"""
|
452 |
with self.driver.session() as session:
|
453 |
session.execute_write(lambda tx: tx.run(query).data())
|
454 |
-
|
455 |
def get_entity_related_paper_num(self, entity_name):
|
456 |
query = """
|
457 |
MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
|
@@ -459,10 +693,30 @@ class PaperClient:
|
|
459 |
RETURN PaperCount
|
460 |
"""
|
461 |
with self.driver.session() as session:
|
462 |
-
result = session.execute_read(
|
463 |
-
|
|
|
|
|
464 |
return paper_num
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
def get_entity_text(self):
|
467 |
query = """
|
468 |
MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
|
@@ -472,11 +726,13 @@ class PaperClient:
|
|
472 |
"""
|
473 |
with self.driver.session() as session:
|
474 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
475 |
-
text_list = [record[
|
476 |
return text_list
|
477 |
-
|
478 |
def get_entity_combinations(self, venue_name, year):
|
479 |
-
def process_paper_relationships(
|
|
|
|
|
480 |
if entity_name_2 < entity_name_1:
|
481 |
entity_name_1, entity_name_2 = entity_name_2, entity_name_1
|
482 |
query = """
|
@@ -486,13 +742,17 @@ class PaperClient:
|
|
486 |
ON CREATE SET r.strength = 1
|
487 |
ON MATCH SET r.strength = r.strength + 1
|
488 |
"""
|
489 |
-
sentences = re.split(r
|
490 |
for sentence in sentences:
|
491 |
sentence = sentence.lower()
|
492 |
if entity_name_1 in sentence and entity_name_2 in sentence:
|
493 |
# 如果两个实体在同一句话中出现过,则创建或更新 CONNECT 关系
|
494 |
session.execute_write(
|
495 |
-
lambda tx: tx.run(
|
|
|
|
|
|
|
|
|
496 |
)
|
497 |
# logger.debug(f"CONNECT relation created or updated between {entity_name_1} and {entity_name_2} for Paper ID {paper_id}")
|
498 |
break # 如果找到一次出现就可以退出循环
|
@@ -506,13 +766,17 @@ class PaperClient:
|
|
506 |
RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
|
507 |
"""
|
508 |
with self.driver.session() as session:
|
509 |
-
result = session.execute_read(
|
|
|
|
|
510 |
for record in tqdm(result):
|
511 |
paper_id = record["hash_id"]
|
512 |
-
entity_name_1 = record[
|
513 |
-
entity_name_2 = record[
|
514 |
abstract = self.get_paper_attribute(paper_id, "abstract")
|
515 |
-
process_paper_relationships(
|
|
|
|
|
516 |
|
517 |
def build_citemap(self):
|
518 |
citemap = defaultdict(set)
|
@@ -523,8 +787,8 @@ class PaperClient:
|
|
523 |
with self.driver.session() as session:
|
524 |
results = session.execute_read(lambda tx: tx.run(query).data())
|
525 |
for result in results:
|
526 |
-
hash_id = result[
|
527 |
-
cite_id_list = result[
|
528 |
if cite_id_list:
|
529 |
for cited_id in cite_id_list:
|
530 |
citemap[hash_id].add(cited_id)
|
@@ -537,12 +801,17 @@ class PaperClient:
|
|
537 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
538 |
graph = Graph(URI, auth=AUTH)
|
539 |
# 创建一个字典来保存数据
|
|
|
540 |
data = {"nodes": [], "relationships": []}
|
541 |
-
|
|
|
|
|
|
|
|
|
542 |
MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
|
543 |
-
WHERE p.venue_name='iclr' and p.year='2024'
|
544 |
RETURN p, e, r
|
545 |
"""
|
|
|
546 |
results = graph.run(query)
|
547 |
# 处理查询结果
|
548 |
for record in tqdm(results):
|
@@ -550,39 +819,46 @@ class PaperClient:
|
|
550 |
entity_node = record["e"]
|
551 |
relationship = record["r"]
|
552 |
# 将节点数据加入字典
|
553 |
-
data["nodes"].append(
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
|
|
|
|
|
|
|
|
563 |
# 将关系数据加入字典
|
564 |
-
data["relationships"].append(
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
|
|
|
|
|
|
570 |
query = """
|
571 |
MATCH (p:Paper)
|
572 |
WHERE p.venue_name='acl' and p.year='2024'
|
573 |
RETURN p
|
574 |
"""
|
575 |
-
"""
|
576 |
results = graph.run(query)
|
577 |
for record in tqdm(results):
|
578 |
paper_node = record["p"]
|
579 |
# 将节点数据加入字典
|
580 |
-
data["nodes"].append(
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
|
|
586 |
# 去除重复节点
|
587 |
# data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
|
588 |
unique_nodes = []
|
@@ -595,9 +871,11 @@ class PaperClient:
|
|
595 |
unique_nodes.append(node)
|
596 |
data["nodes"] = unique_nodes
|
597 |
# 将数据保存为 JSON 文件
|
598 |
-
with open(
|
|
|
|
|
599 |
json.dump(data, f, ensure_ascii=False, indent=4)
|
600 |
-
|
601 |
def neo4j_import_data(self):
|
602 |
# clear_database() # 清空数据库,谨慎执行
|
603 |
URI = os.environ["NEO4J_URL"]
|
@@ -606,7 +884,9 @@ class PaperClient:
|
|
606 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
607 |
graph = Graph(URI, auth=AUTH)
|
608 |
# 从 JSON 文件中读取数据
|
609 |
-
with open(
|
|
|
|
|
610 |
data = json.load(f)
|
611 |
# 创建节点
|
612 |
nodes = {}
|
|
|
8 |
from py2neo import Graph, Node, Relationship
|
9 |
from loguru import logger
|
10 |
|
11 |
+
|
12 |
class PaperClient:
|
13 |
_instance = None
|
14 |
_initialized = False
|
|
|
44 |
with self.driver.session() as session:
|
45 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
46 |
if result:
|
47 |
+
paper_from_client = result[0]["p"]
|
48 |
if paper_from_client is not None:
|
49 |
paper.update(paper_from_client)
|
50 |
+
|
51 |
+
def update_papers_from_client(self, paper_id_list):
|
52 |
+
query = """
|
53 |
+
UNWIND $papers AS paper
|
54 |
+
MATCH (p:Paper {hash_id: paper.hash_id})
|
55 |
+
RETURN p as result
|
56 |
+
"""
|
57 |
+
paper_data = [
|
58 |
+
{
|
59 |
+
"hash_id": hash_id,
|
60 |
+
}
|
61 |
+
for hash_id in paper_id_list
|
62 |
+
]
|
63 |
+
with self.driver.session() as session:
|
64 |
+
result = session.execute_read(
|
65 |
+
lambda tx: tx.run(query, papers=paper_data).data()
|
66 |
+
)
|
67 |
+
return [r["result"] for r in result]
|
68 |
+
|
69 |
def get_paper_attribute(self, paper_id, attribute_name):
|
70 |
query = f"""
|
71 |
MATCH (p:Paper {{hash_id: {paper_id}}})
|
|
|
74 |
with self.driver.session() as session:
|
75 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
76 |
if result:
|
77 |
+
return result[0]["attributeValue"]
|
78 |
else:
|
79 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
80 |
return None
|
81 |
+
|
82 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
83 |
query = f"""
|
84 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
|
|
87 |
with self.driver.session() as session:
|
88 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
89 |
if result:
|
90 |
+
return result[0]["p"]
|
91 |
else:
|
92 |
return None
|
93 |
|
|
|
100 |
RETURN p.hash_id as hash_id
|
101 |
"""
|
102 |
with self.driver.session() as session:
|
103 |
+
result = session.execute_read(
|
104 |
+
lambda tx: tx.run(query, entity=entity).data()
|
105 |
+
)
|
106 |
if result:
|
107 |
+
return [record["hash_id"] for record in result]
|
108 |
else:
|
109 |
return []
|
110 |
+
|
111 |
+
def find_related_entities_by_entity_list(
|
112 |
+
self, entity_names, n=1, k=3, relation_name="related"
|
113 |
+
):
|
114 |
+
related_entities = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
query = """
|
116 |
+
UNWIND $batch_entities AS entity_name
|
117 |
+
MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
|
118 |
+
WHERE e1 <> e2
|
119 |
+
WITH e1, e2, COUNT(p) AS common_papers, entity_name
|
120 |
+
WHERE common_papers > $k
|
121 |
+
RETURN e2.name AS entities, entity_name AS source_entity, common_papers
|
122 |
"""
|
123 |
with self.driver.session() as session:
|
124 |
+
result = session.execute_read(
|
125 |
+
lambda tx: tx.run(query, batch_entities=entity_names, k=k).data()
|
126 |
+
)
|
127 |
+
for record in result:
|
128 |
+
entity = record["entities"]
|
129 |
+
related_entities.add(entity)
|
130 |
+
return list(related_entities)
|
131 |
+
|
132 |
+
def find_entities_by_paper_list(self, hash_ids: list):
|
133 |
+
query = """
|
134 |
+
UNWIND $hash_ids AS hash_id
|
135 |
+
MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: hash_id})
|
136 |
+
RETURN hash_id, e.name AS entity_name
|
137 |
+
"""
|
138 |
+
with self.driver.session() as session:
|
139 |
+
result = session.execute_read(
|
140 |
+
lambda tx: tx.run(query, hash_ids=hash_ids).data()
|
141 |
+
)
|
142 |
+
# 按照每个 hash_id 分组实体
|
143 |
+
entity_list = []
|
144 |
+
for record in result:
|
145 |
+
entity_list.append(record["entity_name"])
|
146 |
+
return entity_list
|
147 |
|
148 |
def find_paper_by_entity(self, entity_name):
|
149 |
query = """
|
|
|
151 |
RETURN p.hash_id AS hash_id
|
152 |
"""
|
153 |
with self.driver.session() as session:
|
154 |
+
result = session.execute_read(
|
155 |
+
lambda tx: tx.run(query, entity_name=entity_name).data()
|
156 |
+
)
|
157 |
if result:
|
158 |
+
return [record["hash_id"] for record in result]
|
159 |
else:
|
160 |
return []
|
161 |
+
|
162 |
# TODO: @云翔
|
163 |
# 增加通过entity返回包含entity语句的功能
|
164 |
def find_sentence_by_entity(self, entity_name):
|
165 |
# Return: list(str)
|
166 |
return []
|
|
|
167 |
|
168 |
def find_sentences_by_entity(self, entity_name):
|
169 |
query = """
|
|
|
177 |
p.hash_id AS hash_id
|
178 |
"""
|
179 |
sentences = []
|
180 |
+
|
181 |
with self.driver.session() as session:
|
182 |
+
result = session.execute_read(
|
183 |
+
lambda tx: tx.run(query, entity_name=entity_name).data()
|
184 |
+
)
|
185 |
for record in result:
|
186 |
+
for key in ["abstract", "introduction", "methodology"]:
|
187 |
if record[key]:
|
188 |
+
filtered_sentences = [
|
189 |
+
sentence.strip() + "."
|
190 |
+
for sentence in record[key].split(".")
|
191 |
+
if entity_name in sentence
|
192 |
+
]
|
193 |
+
sentences.extend(
|
194 |
+
[
|
195 |
+
f"{record['hash_id']}: {sentence}"
|
196 |
+
for sentence in filtered_sentences
|
197 |
+
]
|
198 |
+
)
|
199 |
|
200 |
return sentences
|
201 |
|
|
|
204 |
MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
|
205 |
"""
|
206 |
with self.driver.session() as session:
|
207 |
+
result = session.execute_read(
|
208 |
+
lambda tx: tx.run(query, year=year, venue_name=venue_name).data()
|
209 |
+
)
|
210 |
if result:
|
211 |
+
return [record["n"] for record in result]
|
212 |
else:
|
213 |
return []
|
214 |
|
|
|
242 |
RETURN p
|
243 |
"""
|
244 |
with self.driver.session() as session:
|
245 |
+
result = session.execute_write(
|
246 |
+
lambda tx: tx.run(
|
247 |
+
query,
|
248 |
+
hash_id=paper["hash_id"],
|
249 |
+
venue_name=paper["venue_name"],
|
250 |
+
year=paper["year"],
|
251 |
+
title=paper["title"],
|
252 |
+
pdf_url=paper["pdf_url"],
|
253 |
+
abstract=paper["abstract"],
|
254 |
+
introduction=paper["introduction"],
|
255 |
+
reference=paper["reference"],
|
256 |
+
summary=paper["summary"],
|
257 |
+
motivation=paper["motivation"],
|
258 |
+
contribution=paper["contribution"],
|
259 |
+
methodology=paper["methodology"],
|
260 |
+
ground_truth=paper["ground_truth"],
|
261 |
+
reference_filter=paper["reference_filter"],
|
262 |
+
conclusions=paper["conclusions"],
|
263 |
+
).data()
|
264 |
+
)
|
265 |
|
266 |
def check_entity_node_count(self, hash_id: int):
|
267 |
query_check_count = """
|
|
|
270 |
"""
|
271 |
with self.driver.session() as session:
|
272 |
# Check the number of related entities
|
273 |
+
result = session.execute_read(
|
274 |
+
lambda tx: tx.run(query_check_count, hash_id=hash_id).data()
|
275 |
+
)
|
276 |
if result[0]["entity_count"] > 3:
|
277 |
return False
|
278 |
return True
|
|
|
287 |
"""
|
288 |
with self.driver.session() as session:
|
289 |
for entity_name in entities:
|
290 |
+
result = session.execute_write(
|
291 |
+
lambda tx: tx.run(
|
292 |
+
query, entity_name=entity_name, hash_id=hash_id
|
293 |
+
).data()
|
294 |
+
)
|
295 |
+
|
296 |
def add_paper_citation(self, paper: dict):
|
297 |
query = """
|
298 |
MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
|
299 |
"""
|
300 |
with self.driver.session() as session:
|
301 |
+
result = session.execute_write(
|
302 |
+
lambda tx: tx.run(
|
303 |
+
query,
|
304 |
+
hash_id=paper["hash_id"],
|
305 |
+
cite_id_list=paper["cite_id_list"],
|
306 |
+
entities=paper["entities"],
|
307 |
+
all_cite_id_list=paper["all_cite_id_list"],
|
308 |
+
).data()
|
309 |
+
)
|
310 |
+
|
311 |
+
def add_paper_abstract_embedding(
|
312 |
+
self, embedding_model, hash_id=None, batch_size=512
|
313 |
+
):
|
314 |
if hash_id is not None:
|
315 |
query = """
|
316 |
MATCH (p:Paper {hash_id: $hash_id})
|
|
|
318 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
319 |
"""
|
320 |
with self.driver.session() as session:
|
321 |
+
results = session.execute_write(
|
322 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
323 |
+
)
|
324 |
+
contexts = [result["title"] + result["context"] for result in results]
|
325 |
+
paper_ids = [result["hash_id"] for result in results]
|
326 |
+
context_embeddings = embedding_model.encode(
|
327 |
+
contexts, convert_to_tensor=True, device=self.device
|
328 |
+
)
|
329 |
query = """
|
330 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
331 |
+
ON CREATE SET p.abstract_embedding = $embedding
|
332 |
+
ON MATCH SET p.abstract_embedding = $embedding
|
333 |
+
"""
|
334 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
335 |
+
embedding = (
|
336 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
337 |
+
)
|
338 |
+
with self.driver.session() as session:
|
339 |
+
results = session.execute_write(
|
340 |
+
lambda tx: tx.run(
|
341 |
+
query, hash_id=hash_id, embedding=embedding
|
342 |
+
).data()
|
343 |
+
)
|
344 |
+
return
|
345 |
+
offset = 0
|
346 |
+
while True:
|
347 |
+
query = f"""
|
348 |
MATCH (p:Paper)
|
349 |
WHERE p.abstract IS NOT NULL
|
350 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
351 |
+
SKIP $offset LIMIT $batch_size
|
352 |
"""
|
353 |
with self.driver.session() as session:
|
354 |
+
results = session.execute_write(
|
355 |
+
lambda tx: tx.run(
|
356 |
+
query, offset=offset, batch_size=batch_size
|
357 |
+
).data()
|
358 |
+
)
|
359 |
+
if not results:
|
360 |
+
break
|
361 |
+
contexts = [result["title"] + result["context"] for result in results]
|
362 |
+
paper_ids = [result["hash_id"] for result in results]
|
363 |
+
context_embeddings = embedding_model.encode(
|
364 |
+
contexts,
|
365 |
+
batch_size=batch_size,
|
366 |
+
convert_to_tensor=True,
|
367 |
+
device=self.device,
|
368 |
+
)
|
369 |
+
write_query = """
|
370 |
+
UNWIND $data AS row
|
371 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
372 |
+
ON CREATE SET p.abstract_embedding = row.embedding
|
373 |
+
ON MATCH SET p.abstract_embedding = row.embedding
|
374 |
+
"""
|
375 |
+
data_to_write = []
|
376 |
+
for idx, hash_id in enumerate(paper_ids):
|
377 |
+
embedding = (
|
378 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
379 |
+
)
|
380 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
381 |
with self.driver.session() as session:
|
382 |
+
session.execute_write(
|
383 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
384 |
+
)
|
385 |
+
offset += batch_size
|
386 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
387 |
|
388 |
+
def add_paper_bg_embedding(self, embedding_model, hash_id=None, batch_size=512):
|
389 |
if hash_id is not None:
|
390 |
query = """
|
391 |
MATCH (p:Paper {hash_id: $hash_id})
|
392 |
WHERE p.motivation IS NOT NULL
|
393 |
+
RETURN p.motivation AS context, p.hash_id AS hash_id, p.title AS title
|
394 |
"""
|
395 |
with self.driver.session() as session:
|
396 |
+
results = session.execute_write(
|
397 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
398 |
+
)
|
399 |
+
contexts = [result["context"] for result in results]
|
400 |
+
paper_ids = [result["hash_id"] for result in results]
|
401 |
+
context_embeddings = embedding_model.encode(
|
402 |
+
contexts, convert_to_tensor=True, device=self.device
|
403 |
+
)
|
404 |
query = """
|
405 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
406 |
+
ON CREATE SET p.motivation_embedding = $embedding
|
407 |
+
ON MATCH SET p.motivation_embedding = $embedding
|
408 |
+
"""
|
409 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
410 |
+
embedding = (
|
411 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
412 |
+
)
|
413 |
+
with self.driver.session() as session:
|
414 |
+
results = session.execute_write(
|
415 |
+
lambda tx: tx.run(
|
416 |
+
query, hash_id=hash_id, embedding=embedding
|
417 |
+
).data()
|
418 |
+
)
|
419 |
+
return
|
420 |
+
offset = 0
|
421 |
+
while True:
|
422 |
+
query = f"""
|
423 |
MATCH (p:Paper)
|
424 |
WHERE p.motivation IS NOT NULL
|
425 |
+
RETURN p.motivation AS context, p.hash_id AS hash_id, p.title AS title
|
426 |
+
SKIP $offset LIMIT $batch_size
|
427 |
"""
|
428 |
with self.driver.session() as session:
|
429 |
+
results = session.execute_write(
|
430 |
+
lambda tx: tx.run(
|
431 |
+
query, offset=offset, batch_size=batch_size
|
432 |
+
).data()
|
433 |
+
)
|
434 |
+
if not results:
|
435 |
+
break
|
436 |
+
contexts = [result["title"] + result["context"] for result in results]
|
437 |
+
paper_ids = [result["hash_id"] for result in results]
|
438 |
+
context_embeddings = embedding_model.encode(
|
439 |
+
contexts,
|
440 |
+
batch_size=batch_size,
|
441 |
+
convert_to_tensor=True,
|
442 |
+
device=self.device,
|
443 |
+
)
|
444 |
+
write_query = """
|
445 |
+
UNWIND $data AS row
|
446 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
447 |
+
ON CREATE SET p.motivation_embedding = row.embedding
|
448 |
+
ON MATCH SET p.motivation_embedding = row.embedding
|
449 |
+
"""
|
450 |
+
data_to_write = []
|
451 |
+
for idx, hash_id in enumerate(paper_ids):
|
452 |
+
embedding = (
|
453 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
454 |
+
)
|
455 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
456 |
with self.driver.session() as session:
|
457 |
+
session.execute_write(
|
458 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
459 |
+
)
|
460 |
+
offset += batch_size
|
461 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
462 |
+
|
463 |
+
def add_paper_contribution_embedding(
|
464 |
+
self, embedding_model, hash_id=None, batch_size=512
|
465 |
+
):
|
466 |
if hash_id is not None:
|
467 |
query = """
|
468 |
MATCH (p:Paper {hash_id: $hash_id})
|
469 |
WHERE p.contribution IS NOT NULL
|
470 |
+
RETURN p.contribution AS context, p.hash_id AS hash_id, p.title AS title
|
471 |
"""
|
472 |
with self.driver.session() as session:
|
473 |
+
results = session.execute_write(
|
474 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
475 |
+
)
|
476 |
+
contexts = [result["context"] for result in results]
|
477 |
+
paper_ids = [result["hash_id"] for result in results]
|
478 |
+
context_embeddings = embedding_model.encode(
|
479 |
+
contexts, convert_to_tensor=True, device=self.device
|
480 |
+
)
|
481 |
query = """
|
482 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
483 |
+
ON CREATE SET p.contribution_embedding = $embedding
|
484 |
+
ON MATCH SET p.contribution_embedding = $embedding
|
485 |
+
"""
|
486 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
487 |
+
embedding = (
|
488 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
489 |
+
)
|
490 |
+
with self.driver.session() as session:
|
491 |
+
results = session.execute_write(
|
492 |
+
lambda tx: tx.run(
|
493 |
+
query, hash_id=hash_id, embedding=embedding
|
494 |
+
).data()
|
495 |
+
)
|
496 |
+
return
|
497 |
+
offset = 0
|
498 |
+
while True:
|
499 |
+
query = f"""
|
500 |
MATCH (p:Paper)
|
501 |
WHERE p.contribution IS NOT NULL
|
502 |
+
RETURN p.contribution AS context, p.hash_id AS hash_id, p.title AS title
|
503 |
+
SKIP $offset LIMIT $batch_size
|
504 |
"""
|
505 |
with self.driver.session() as session:
|
506 |
+
results = session.execute_write(
|
507 |
+
lambda tx: tx.run(
|
508 |
+
query, offset=offset, batch_size=batch_size
|
509 |
+
).data()
|
510 |
+
)
|
511 |
+
if not results:
|
512 |
+
break
|
513 |
+
contexts = [result["context"] for result in results]
|
514 |
+
paper_ids = [result["hash_id"] for result in results]
|
515 |
+
context_embeddings = embedding_model.encode(
|
516 |
+
contexts,
|
517 |
+
batch_size=batch_size,
|
518 |
+
convert_to_tensor=True,
|
519 |
+
device=self.device,
|
520 |
+
)
|
521 |
+
write_query = """
|
522 |
+
UNWIND $data AS row
|
523 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
524 |
+
ON CREATE SET p.contribution_embedding = row.embedding
|
525 |
+
ON MATCH SET p.contribution_embedding = row.embedding
|
526 |
+
"""
|
527 |
+
data_to_write = []
|
528 |
+
for idx, hash_id in enumerate(paper_ids):
|
529 |
+
embedding = (
|
530 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
531 |
+
)
|
532 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
533 |
with self.driver.session() as session:
|
534 |
+
session.execute_write(
|
535 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
536 |
+
)
|
537 |
+
offset += batch_size
|
538 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
539 |
+
|
540 |
+
def add_paper_summary_embedding(
|
541 |
+
self, embedding_model, hash_id=None, batch_size=512
|
542 |
+
):
|
543 |
if hash_id is not None:
|
544 |
query = """
|
545 |
MATCH (p:Paper {hash_id: $hash_id})
|
546 |
WHERE p.summary IS NOT NULL
|
547 |
+
RETURN p.summary AS context, p.hash_id AS hash_id, p.title AS title
|
548 |
"""
|
549 |
with self.driver.session() as session:
|
550 |
+
results = session.execute_write(
|
551 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
552 |
+
)
|
553 |
+
contexts = [result["context"] for result in results]
|
554 |
+
paper_ids = [result["hash_id"] for result in results]
|
555 |
+
context_embeddings = embedding_model.encode(
|
556 |
+
contexts, convert_to_tensor=True, device=self.device
|
557 |
+
)
|
558 |
query = """
|
559 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
560 |
+
ON CREATE SET p.summary_embedding = $embedding
|
561 |
+
ON MATCH SET p.summary_embedding = $embedding
|
562 |
+
"""
|
563 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
564 |
+
embedding = (
|
565 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
566 |
+
)
|
567 |
+
with self.driver.session() as session:
|
568 |
+
results = session.execute_write(
|
569 |
+
lambda tx: tx.run(
|
570 |
+
query, hash_id=hash_id, embedding=embedding
|
571 |
+
).data()
|
572 |
+
)
|
573 |
+
return
|
574 |
+
offset = 0
|
575 |
+
while True:
|
576 |
+
query = f"""
|
577 |
MATCH (p:Paper)
|
578 |
WHERE p.summary IS NOT NULL
|
579 |
+
RETURN p.summary AS context, p.hash_id AS hash_id, p.title AS title
|
580 |
+
SKIP $offset LIMIT $batch_size
|
581 |
"""
|
582 |
with self.driver.session() as session:
|
583 |
+
results = session.execute_write(
|
584 |
+
lambda tx: tx.run(
|
585 |
+
query, offset=offset, batch_size=batch_size
|
586 |
+
).data()
|
587 |
+
)
|
588 |
+
if not results:
|
589 |
+
break
|
590 |
+
contexts = [result["context"] for result in results]
|
591 |
+
paper_ids = [result["hash_id"] for result in results]
|
592 |
+
context_embeddings = embedding_model.encode(
|
593 |
+
contexts,
|
594 |
+
batch_size=batch_size,
|
595 |
+
convert_to_tensor=True,
|
596 |
+
device=self.device,
|
597 |
+
)
|
598 |
+
write_query = """
|
599 |
+
UNWIND $data AS row
|
600 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
601 |
+
ON CREATE SET p.summary_embedding = row.embedding
|
602 |
+
ON MATCH SET p.summary_embedding = row.embedding
|
603 |
+
"""
|
604 |
+
data_to_write = []
|
605 |
+
for idx, hash_id in enumerate(paper_ids):
|
606 |
+
embedding = (
|
607 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
608 |
+
)
|
609 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
610 |
with self.driver.session() as session:
|
611 |
+
session.execute_write(
|
612 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
613 |
+
)
|
614 |
+
offset += batch_size
|
615 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
616 |
+
|
617 |
def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
|
618 |
query = f"""
|
619 |
MATCH (paper:Paper)
|
|
|
624 |
ORDER BY score DESC LIMIT {k}
|
625 |
"""
|
626 |
with self.driver.session() as session:
|
627 |
+
results = session.execute_read(
|
628 |
+
lambda tx: tx.run(query, embedding=embedding).data()
|
629 |
+
)
|
630 |
+
related_paper = []
|
631 |
for result in results:
|
632 |
related_paper.append(result["paper"]["hash_id"])
|
633 |
return related_paper
|
|
|
649 |
"""
|
650 |
with self.driver.session() as session:
|
651 |
session.execute_write(lambda tx: tx.run(query).data())
|
652 |
+
|
653 |
def filter_paper_id_list(self, paper_id_list, year="2024"):
|
654 |
if not paper_id_list:
|
655 |
return []
|
|
|
661 |
RETURN p.hash_id AS hash_id
|
662 |
"""
|
663 |
with self.driver.session() as session:
|
664 |
+
result = session.execute_read(
|
665 |
+
lambda tx: tx.run(query, paper_id_list=paper_id_list, year=year).data()
|
666 |
+
)
|
667 |
+
|
668 |
+
existing_paper_ids = [record["hash_id"] for record in result]
|
669 |
existing_paper_ids = list(set(existing_paper_ids))
|
670 |
return existing_paper_ids
|
671 |
+
|
672 |
def check_index_exists(self):
|
673 |
query = "SHOW INDEXES"
|
674 |
with self.driver.session() as session:
|
|
|
685 |
"""
|
686 |
with self.driver.session() as session:
|
687 |
session.execute_write(lambda tx: tx.run(query).data())
|
688 |
+
|
689 |
def get_entity_related_paper_num(self, entity_name):
|
690 |
query = """
|
691 |
MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
|
|
|
693 |
RETURN PaperCount
|
694 |
"""
|
695 |
with self.driver.session() as session:
|
696 |
+
result = session.execute_read(
|
697 |
+
lambda tx: tx.run(query, entity_name=entity_name).data()
|
698 |
+
)
|
699 |
+
paper_num = result[0]["PaperCount"]
|
700 |
return paper_num
|
701 |
|
702 |
+
def get_entities_related_paper_num(self, entity_names):
|
703 |
+
query = """
|
704 |
+
UNWIND $entity_names AS entity_name
|
705 |
+
MATCH (e:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)
|
706 |
+
WITH entity_name, COUNT(p) AS PaperCount
|
707 |
+
RETURN entity_name, PaperCount
|
708 |
+
"""
|
709 |
+
|
710 |
+
with self.driver.session() as session:
|
711 |
+
result = session.execute_read(
|
712 |
+
lambda tx: tx.run(query, entity_names=entity_names).data()
|
713 |
+
)
|
714 |
+
# 将查询结果转化为字典形式:实体名称 -> 论文数量
|
715 |
+
entity_paper_count = {
|
716 |
+
record["entity_name"]: record["PaperCount"] for record in result
|
717 |
+
}
|
718 |
+
return entity_paper_count
|
719 |
+
|
720 |
def get_entity_text(self):
|
721 |
query = """
|
722 |
MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
|
|
|
726 |
"""
|
727 |
with self.driver.session() as session:
|
728 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
729 |
+
text_list = [record["entity_text"] for record in result]
|
730 |
return text_list
|
731 |
+
|
732 |
def get_entity_combinations(self, venue_name, year):
|
733 |
+
def process_paper_relationships(
|
734 |
+
session, entity_name_1, entity_name_2, abstract
|
735 |
+
):
|
736 |
if entity_name_2 < entity_name_1:
|
737 |
entity_name_1, entity_name_2 = entity_name_2, entity_name_1
|
738 |
query = """
|
|
|
742 |
ON CREATE SET r.strength = 1
|
743 |
ON MATCH SET r.strength = r.strength + 1
|
744 |
"""
|
745 |
+
sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", abstract)
|
746 |
for sentence in sentences:
|
747 |
sentence = sentence.lower()
|
748 |
if entity_name_1 in sentence and entity_name_2 in sentence:
|
749 |
# 如果两个实体在同一句话中出现过,则创建或更新 CONNECT 关系
|
750 |
session.execute_write(
|
751 |
+
lambda tx: tx.run(
|
752 |
+
query,
|
753 |
+
entity_name_1=entity_name_1,
|
754 |
+
entity_name_2=entity_name_2,
|
755 |
+
).data()
|
756 |
)
|
757 |
# logger.debug(f"CONNECT relation created or updated between {entity_name_1} and {entity_name_2} for Paper ID {paper_id}")
|
758 |
break # 如果找到一次出现就可以退出循环
|
|
|
766 |
RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
|
767 |
"""
|
768 |
with self.driver.session() as session:
|
769 |
+
result = session.execute_read(
|
770 |
+
lambda tx: tx.run(query, venue_name=venue_name, year=year).data()
|
771 |
+
)
|
772 |
for record in tqdm(result):
|
773 |
paper_id = record["hash_id"]
|
774 |
+
entity_name_1 = record["entity_name_1"]
|
775 |
+
entity_name_2 = record["entity_name_2"]
|
776 |
abstract = self.get_paper_attribute(paper_id, "abstract")
|
777 |
+
process_paper_relationships(
|
778 |
+
session, entity_name_1, entity_name_2, abstract
|
779 |
+
)
|
780 |
|
781 |
def build_citemap(self):
|
782 |
citemap = defaultdict(set)
|
|
|
787 |
with self.driver.session() as session:
|
788 |
results = session.execute_read(lambda tx: tx.run(query).data())
|
789 |
for result in results:
|
790 |
+
hash_id = result["hash_id"]
|
791 |
+
cite_id_list = result["cite_id_list"]
|
792 |
if cite_id_list:
|
793 |
for cited_id in cite_id_list:
|
794 |
citemap[hash_id].add(cited_id)
|
|
|
801 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
802 |
graph = Graph(URI, auth=AUTH)
|
803 |
# 创建一个字典来保存数据
|
804 |
+
# 定义批次大小
|
805 |
data = {"nodes": [], "relationships": []}
|
806 |
+
# 计算数据的总数(例如查询节点总数)
|
807 |
+
total_papers_query = "MATCH (e:Entity)-[:RELATED_TO]->(p:Paper) RETURN COUNT(DISTINCT p) AS count"
|
808 |
+
total_papers = graph.run(total_papers_query).evaluate()
|
809 |
+
print(f"total paper: {total_papers}")
|
810 |
+
query = f"""
|
811 |
MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
|
|
|
812 |
RETURN p, e, r
|
813 |
"""
|
814 |
+
"""
|
815 |
results = graph.run(query)
|
816 |
# 处理查询结果
|
817 |
for record in tqdm(results):
|
|
|
819 |
entity_node = record["e"]
|
820 |
relationship = record["r"]
|
821 |
# 将节点数据加入字典
|
822 |
+
data["nodes"].append(
|
823 |
+
{
|
824 |
+
"id": paper_node.identity,
|
825 |
+
"label": "Paper",
|
826 |
+
"properties": dict(paper_node),
|
827 |
+
}
|
828 |
+
)
|
829 |
+
data["nodes"].append(
|
830 |
+
{
|
831 |
+
"id": entity_node.identity,
|
832 |
+
"label": "Entity",
|
833 |
+
"properties": dict(entity_node),
|
834 |
+
}
|
835 |
+
)
|
836 |
# 将关系数据加入字典
|
837 |
+
data["relationships"].append(
|
838 |
+
{
|
839 |
+
"start_node": entity_node.identity,
|
840 |
+
"end_node": paper_node.identity,
|
841 |
+
"type": "RELATED_TO",
|
842 |
+
"properties": dict(relationship),
|
843 |
+
}
|
844 |
+
)
|
845 |
+
"""
|
846 |
query = """
|
847 |
MATCH (p:Paper)
|
848 |
WHERE p.venue_name='acl' and p.year='2024'
|
849 |
RETURN p
|
850 |
"""
|
|
|
851 |
results = graph.run(query)
|
852 |
for record in tqdm(results):
|
853 |
paper_node = record["p"]
|
854 |
# 将节点数据加入字典
|
855 |
+
data["nodes"].append(
|
856 |
+
{
|
857 |
+
"id": paper_node.identity,
|
858 |
+
"label": "Paper",
|
859 |
+
"properties": dict(paper_node),
|
860 |
+
}
|
861 |
+
)
|
862 |
# 去除重复节点
|
863 |
# data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
|
864 |
unique_nodes = []
|
|
|
871 |
unique_nodes.append(node)
|
872 |
data["nodes"] = unique_nodes
|
873 |
# 将数据保存为 JSON 文件
|
874 |
+
with open(
|
875 |
+
"./assets/data/scipip_neo4j_clean_backup.json", "w", encoding="utf-8"
|
876 |
+
) as f:
|
877 |
json.dump(data, f, ensure_ascii=False, indent=4)
|
878 |
+
|
879 |
def neo4j_import_data(self):
|
880 |
# clear_database() # 清空数据库,谨慎执行
|
881 |
URI = os.environ["NEO4J_URL"]
|
|
|
884 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
885 |
graph = Graph(URI, auth=AUTH)
|
886 |
# 从 JSON 文件中读取数据
|
887 |
+
with open(
|
888 |
+
"./assets/data/scipip_neo4j_clean_backup.json", "r", encoding="utf-8"
|
889 |
+
) as f:
|
890 |
data = json.load(f)
|
891 |
# 创建节点
|
892 |
nodes = {}
|
src/utils/paper_retriever.py
CHANGED
@@ -59,6 +59,7 @@ class CoCite:
|
|
59 |
|
60 |
def __init__(self) -> None:
|
61 |
if not self._initialized:
|
|
|
62 |
self.paper_client = PaperClient()
|
63 |
citemap = self.paper_client.build_citemap()
|
64 |
self.comap = defaultdict(lambda: defaultdict(int))
|
@@ -101,20 +102,16 @@ class Retriever(object):
|
|
101 |
|
102 |
def retrieve_entities_by_enties(self, entities):
|
103 |
# TODO: KG
|
104 |
-
expand_entities =
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
relation_name=self.config.RETRIEVE.relation_name,
|
111 |
-
)
|
112 |
expand_entities = list(set(entities + expand_entities))
|
113 |
-
entity_paper_num_dict =
|
114 |
-
|
115 |
-
|
116 |
-
self.paper_client.get_entity_related_paper_num(entity)
|
117 |
-
)
|
118 |
new_entities = []
|
119 |
entity_paper_num_dict = {
|
120 |
k: v for k, v in entity_paper_num_dict.items() if v != 0
|
@@ -142,11 +139,7 @@ class Retriever(object):
|
|
142 |
Return:
|
143 |
related_paper: list(dict)
|
144 |
"""
|
145 |
-
related_paper =
|
146 |
-
for paper_id in paper_id_list:
|
147 |
-
paper = {"hash_id": paper_id}
|
148 |
-
self.paper_client.update_paper_from_client(paper)
|
149 |
-
related_paper.append(paper)
|
150 |
return related_paper
|
151 |
|
152 |
def calculate_similarity(self, entities, related_entities_list, use_weight=False):
|
@@ -333,7 +326,6 @@ class Retriever(object):
|
|
333 |
similarity_threshold = self.config.RETRIEVE.similarity_threshold
|
334 |
similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
|
335 |
target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
|
336 |
-
# target_labels = list(range(0, len(target_paper_id_list)))
|
337 |
target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
|
338 |
logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
|
339 |
logger.debug(
|
@@ -672,8 +664,7 @@ class SNKGRetriever(Retriever):
|
|
672 |
)
|
673 |
related_paper = set()
|
674 |
related_paper.update(sn_paper_id_list)
|
675 |
-
|
676 |
-
sn_entities += self.paper_client.find_entities_by_paper(paper_id)
|
677 |
logger.debug("SN entities for retriever: {}".format(sn_entities))
|
678 |
entities = list(set(entities + sn_entities))
|
679 |
new_entities = self.retrieve_entities_by_enties(entities)
|
|
|
59 |
|
60 |
def __init__(self) -> None:
|
61 |
if not self._initialized:
|
62 |
+
logger.debug("init co-cite map begin...")
|
63 |
self.paper_client = PaperClient()
|
64 |
citemap = self.paper_client.build_citemap()
|
65 |
self.comap = defaultdict(lambda: defaultdict(int))
|
|
|
102 |
|
103 |
def retrieve_entities_by_enties(self, entities):
|
104 |
# TODO: KG
|
105 |
+
expand_entities = self.paper_client.find_related_entities_by_entity_list(
|
106 |
+
entities,
|
107 |
+
n=self.config.RETRIEVE.kg_jump_num,
|
108 |
+
k=self.config.RETRIEVE.kg_cover_num,
|
109 |
+
relation_name=self.config.RETRIEVE.relation_name,
|
110 |
+
)
|
|
|
|
|
111 |
expand_entities = list(set(entities + expand_entities))
|
112 |
+
entity_paper_num_dict = self.paper_client.get_entities_related_paper_num(
|
113 |
+
expand_entities
|
114 |
+
)
|
|
|
|
|
115 |
new_entities = []
|
116 |
entity_paper_num_dict = {
|
117 |
k: v for k, v in entity_paper_num_dict.items() if v != 0
|
|
|
139 |
Return:
|
140 |
related_paper: list(dict)
|
141 |
"""
|
142 |
+
related_paper = self.paper_client.update_papers_from_client(paper_id_list)
|
|
|
|
|
|
|
|
|
143 |
return related_paper
|
144 |
|
145 |
def calculate_similarity(self, entities, related_entities_list, use_weight=False):
|
|
|
326 |
similarity_threshold = self.config.RETRIEVE.similarity_threshold
|
327 |
similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
|
328 |
target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
|
|
|
329 |
target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
|
330 |
logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
|
331 |
logger.debug(
|
|
|
664 |
)
|
665 |
related_paper = set()
|
666 |
related_paper.update(sn_paper_id_list)
|
667 |
+
sn_entities += self.paper_client.find_entities_by_paper_list(sn_paper_id_list)
|
|
|
668 |
logger.debug("SN entities for retriever: {}".format(sn_entities))
|
669 |
entities = list(set(entities + sn_entities))
|
670 |
new_entities = self.retrieve_entities_by_enties(entities)
|