lihuigu committed
Commit 02069d7 · 1 Parent(s): 88253fe

reduce neo4j query time in retrieve

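Most of the saving comes from cutting per-entity round trips to Neo4j in src/utils/paper_client.py. As a rough, hypothetical sketch of that batching pattern (the function name and the exact query below are illustrative and not copied from this commit; only the session/driver usage and the Entity/Paper/RELATED_TO/hash_id names mirror the existing PaperClient code):

```python
# Hypothetical sketch: batch N entity lookups into a single UNWIND query
# (one round trip instead of N). Session usage mirrors PaperClient in this repo.
def find_papers_for_entities(driver, entity_names):
    query = """
    UNWIND $names AS name
    MATCH (e:Entity {name: name})-[:RELATED_TO]->(p:Paper)
    RETURN name AS entity, collect(p.hash_id) AS hash_ids
    """
    with driver.session() as session:
        result = session.execute_read(
            lambda tx: tx.run(query, names=entity_names).data()
        )
    # Map each entity to the hash_ids of the papers it is related to.
    return {record["entity"]: record["hash_ids"] for record in result}
```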
app.py CHANGED
@@ -1,25 +1,32 @@
import sys
+
sys.path.append("./src")
import streamlit as st
-from app_pages import button_interface, step_by_step_generation, one_click_generation, homepage
+from app_pages import (
+    button_interface,
+    step_by_step_generation,
+    one_click_generation,
+    homepage,
+)
from app_pages.locale import _
-from utils.hash import check_env, check_embedding

if __name__ == "__main__":
-    check_env()
-    check_embedding()
    backend = button_interface.Backend()
    # backend = None
    st.set_page_config(layout="wide")
-    if "language" not in st.session_state:
-        st.session_state["language"] = "zh"
+
+    # st.logo("./assets/pic/logo.jpg", size="large")
    def fn1():
        one_click_generation.one_click_generation(backend)
+
    def fn2():
        step_by_step_generation.step_by_step_generation(backend)
-    pg = st.navigation([
-        st.Page(homepage.home_page, title=_("🏠️ Homepage")),
-        st.Page(fn1, title=_("💧 One-click Generation")),
-        st.Page(fn2, title=_("💦 Step-by-step Generation")),
-    ])
-    pg.run()
+
+    pg = st.navigation(
+        [
+            st.Page(homepage.home_page, title=_("🏠️ Homepage")),
+            st.Page(fn1, title=_("💧 One-click Generation")),
+            st.Page(fn2, title=_("💦 Step-by-step Generation")),
+        ]
+    )
+    pg.run()
src/app_pages/button_interface.py CHANGED
@@ -2,8 +2,10 @@ import json
from utils.paper_retriever import RetrieverFactory
from utils.llms_api import APIHelper
from utils.header import ConfigReader
+from utils.hash import check_env, check_embedding
from generator import IdeaGenerator

+
class Backend(object):
    def __init__(self) -> None:
        CONFIG_PATH = "./configs/datasets.yaml"
@@ -12,11 +14,14 @@ class Backend(object):
        BRAINSTORM_MODE = "mode_c"

        self.config = ConfigReader.load(CONFIG_PATH)
+        check_env()
+        check_embedding(self.config.DEFAULT.embedding)
        RETRIEVER_NAME = self.config.RETRIEVE.retriever_name
        self.api_helper = APIHelper(self.config)
-        self.retriever_factory = RetrieverFactory.get_retriever_factory().create_retriever(
-            RETRIEVER_NAME,
-            self.config
+        self.retriever_factory = (
+            RetrieverFactory.get_retriever_factory().create_retriever(
+                RETRIEVER_NAME, self.config
+            )
        )
        self.idea_generator = IdeaGenerator(self.config, None)
        self.use_inspiration = USE_INSPIRATION
@@ -33,14 +38,14 @@ class Backend(object):
            return []

    def background2brainstorm_callback(self, background, json_strs=None):
-        if json_strs is not None: # only for DEBUG_MODE
+        if json_strs is not None:  # only for DEBUG_MODE
            json_contents = json.loads(json_strs)
            return json_contents["brainstorm"]
        else:
            return self.api_helper.generate_brainstorm(background)

    def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
-        if json_strs is not None: # only for DEBUG_MODE
+        if json_strs is not None:  # only for DEBUG_MODE
            json_contents = json.loads(json_strs)
            entities_bg = json_contents["entities_bg"]
            entities_bs = json_contents["entities_bs"]
@@ -71,13 +76,17 @@ class Backend(object):
            for i, p in enumerate(result["related_paper"]):
                res.append(str(p))
        else:
-            result = self.retriever_factory.retrieve(background, entities, need_evaluate=False, target_paper_id_list=[])
+            result = self.retriever_factory.retrieve(
+                background, entities, need_evaluate=False, target_paper_id_list=[]
+            )
            res = []
            for i, p in enumerate(result["related_paper"]):
                res.append(f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.')
        return res, result["related_paper"]

-    def literature2initial_ideas_callback(self, background, brainstorms, retrieved_literature, json_strs=None):
+    def literature2initial_ideas_callback(
+        self, background, brainstorms, retrieved_literature, json_strs=None
+    ):
        if json_strs is not None:
            json_contents = json.loads(json_strs)
            return json_contents["median"]["initial_idea"]
@@ -86,15 +95,16 @@ class Backend(object):
        self.idea_generator.brainstorm = brainstorms
        if self.use_inspiration:
            message_input, idea_modified, median = (
-                self.idea_generator.generate_by_inspiration(
-                    background, "new_idea", self.brainstorm_mode, False)
+                self.idea_generator.generate_by_inspiration(
+                    background, "new_idea", self.brainstorm_mode, False
+                )
            )
        else:
            message_input, idea_modified, median = self.idea_generator.generate(
                background, "new_idea", self.brainstorm_mode, False
            )
        return median["initial_idea"], idea_modified
-
+
    def initial2final_callback(self, initial_ideas, final_ideas, json_strs=None):
        if json_strs is not None:
            json_contents = json.loads(json_strs)
@@ -107,6 +117,7 @@ class Backend(object):
            return self.examples[i].get("background", "Background not found.")
        else:
            return "Example not found. Please select a valid index."
+
        # return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
        # "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
        # "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
src/generator.py CHANGED
@@ -10,6 +10,7 @@ import warnings
import time
import os
from utils.hash import check_env, check_embedding
+
warnings.filterwarnings("ignore")


@@ -24,9 +25,14 @@ def extract_problem(problem, background)
        research_problem = background
    return research_problem

+
class IdeaGenerator:
    def __init__(
-        self, config, paper_list: list[dict] = [], cue_words: list = None, brainstorm: str = None
+        self,
+        config,
+        paper_list: list[dict] = [],
+        cue_words: list = None,
+        brainstorm: str = None,
    ) -> None:
        self.api_helper = APIHelper(config)
        self.paper_list = paper_list
@@ -58,7 +64,9 @@ class IdeaGenerator:
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
-        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
+        idea_filtered = self.api_helper.integrate_idea(
+            background, self.brainstorm, idea
+        )
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words_bs(self, background: str):
@@ -66,7 +74,9 @@ class IdeaGenerator:
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
-        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
+        idea_filtered = self.api_helper.integrate_idea(
+            background, self.brainstorm, idea
+        )
        return message_input, problem, idea, idea_filtered

    def generate_with_cue_words_ins(self, background: str):
@@ -93,16 +103,12 @@ class IdeaGenerator:
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
-            inspiration = self.api_helper.generate_inspiration(
-                research_problem, paper
-            )
+            inspiration = self.api_helper.generate_inspiration(research_problem, paper)
            inspirations.append(inspiration)
-        idea = self.api_helper.generate_idea_by_inspiration(
-            problem, inspirations
-        )
+        idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered
-
+
    def generate_with_cue_words_ins_bs(self, background: str):
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
@@ -117,7 +123,9 @@ class IdeaGenerator:
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
-        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
+        idea_filtered = self.api_helper.integrate_idea(
+            background, self.brainstorm, idea
+        )
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins_bs(self, background: str):
@@ -127,14 +135,12 @@ class IdeaGenerator:
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
-            inspiration = self.api_helper.generate_inspiration(
-                research_problem, paper
-            )
+            inspiration = self.api_helper.generate_inspiration(research_problem, paper)
            inspirations.append(inspiration)
-        idea = self.api_helper.generate_idea_by_inspiration(
-            problem, inspirations
-        )
-        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
+        idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
+        idea_filtered = self.api_helper.integrate_idea(
+            background, self.brainstorm, idea
+        )
        return message_input, problem, inspirations, idea, idea_filtered

    def generate(
@@ -151,44 +157,34 @@ class IdeaGenerator:
        mode_name = "Generate new idea"
        if bs_mode == "mode_a":
            if use_cue_words:
-                logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
-                (
-                    message_input,
-                    problem,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_mode_a with cue words.".format(mode_name)
+                )
+                (message_input, problem, idea, idea_filtered) = (
                    self.generate_with_cue_words(background)
                )
            else:
-                logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
-                (
-                    message_input,
-                    problem,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_mode_a without cue words.".format(mode_name)
+                )
+                (message_input, problem, idea, idea_filtered) = (
                    self.generate_without_cue_words(background)
                )
        elif bs_mode == "mode_b" or bs_mode == "mode_c":
            if use_cue_words:
-                logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
-                (
-                    message_input,
-                    problem,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
+                )
+                (message_input, problem, idea, idea_filtered) = (
                    self.generate_with_cue_words_bs(background)
                )
            else:
-                logger.info("{} using brainstorm_{} without cue words.".format(mode_name, bs_mode))
-                (
-                    message_input,
-                    problem,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_{} without cue words.".format(
+                        mode_name, bs_mode
+                    )
+                )
+                (message_input, problem, idea, idea_filtered) = (
                    self.generate_without_cue_words_bs(background)
                )

@@ -214,48 +210,34 @@ class IdeaGenerator:
        mode_name = "Generate new idea"
        if bs_mode == "mode_a":
            if use_cue_words:
-                logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
-                (
-                    message_input,
-                    problem,
-                    inspirations,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_mode_a with cue words.".format(mode_name)
+                )
+                (message_input, problem, inspirations, idea, idea_filtered) = (
                    self.generate_with_cue_words_ins(background)
                )
            else:
-                logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
-                (
-                    message_input,
-                    problem,
-                    inspirations,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_mode_a without cue words.".format(mode_name)
+                )
+                (message_input, problem, inspirations, idea, idea_filtered) = (
                    self.generate_without_cue_words_ins(background)
                )
        elif bs_mode == "mode_b" or bs_mode == "mode_c":
            if use_cue_words:
-                logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
-                (
-                    message_input,
-                    problem,
-                    inspirations,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
+                )
+                (message_input, problem, inspirations, idea, idea_filtered) = (
                    self.generate_with_cue_words_ins_bs(background)
                )
            else:
-                logger.info("{} using brainstorm_{} without cue words.".format(mode_name, bs_mode))
-                (
-                    message_input,
-                    problem,
-                    inspirations,
-                    idea,
-                    idea_filtered
-                ) = (
+                logger.info(
+                    "{} using brainstorm_{} without cue words.".format(
+                        mode_name, bs_mode
+                    )
+                )
+                (message_input, problem, inspirations, idea, idea_filtered) = (
                    self.generate_without_cue_words_ins_bs(background)
                )

@@ -330,9 +312,18 @@ def main(ctx):
    required=False,
    help="The number of papers you want to process",
)
-def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue_words, use_inspiration, num, **kwargs):
+def backtracking(
+    config_path,
+    ids_path,
+    retriever_name,
+    brainstorm_mode,
+    use_cue_words,
+    use_inspiration,
+    num,
+    **kwargs,
+):
    check_env()
-    check_embedding()
+    check_embedding()
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    logger.add(
@@ -349,7 +340,10 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
    batch_size = 2
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
-    output_file = os.path.join(output_dir, f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json")
+    output_file = os.path.join(
+        output_dir,
+        f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json",
+    )
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
@@ -388,7 +382,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
-            entities_all = list(set(entities)|set(entities_bs))
+            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities
@@ -404,8 +398,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
            continue
        # 3. Retrieve related papers
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
-            retriever_name,
-            config
+            retriever_name, config
        )
        result = rt.retrieve(
            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
@@ -438,7 +431,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
                "hash_id": paper["hash_id"],
                "background": bg,
                "entities_bg": entities,
-                "brainstorm" : brainstorm,
+                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
@@ -467,6 +460,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
    ) as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)

+
@main.command()
@click.option(
    "-c",
@@ -512,9 +506,16 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
    required=False,
    help="The number of data you want to process",
)
-def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspiration, num, **kwargs):
+def new_idea(
+    config_path,
+    ids_path,
+    retriever_name,
+    brainstorm_mode,
+    use_inspiration,
+    num,
+    **kwargs,
+):
    check_env()
-    check_embedding()
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
    )  # add a file log sink
@@ -522,6 +523,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
+    check_embedding(config.DEFAULT.embedding)
    eval_data = []
    cur_num = 0
    data_num = 0
@@ -529,7 +531,9 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
    bg_ids = set()
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
-    output_file = os.path.join(output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json")
+    output_file = os.path.join(
+        output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json"
+    )
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
@@ -538,7 +542,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                eval_data = []
-    print(f"{cur_num} datas have been processed.")
+    logger.debug(f"{cur_num} datas have been processed.")
    for line in ids_path:
        # Parse the JSON data on each line
        data = json.loads(line)
@@ -568,16 +572,17 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
-            entities_all = list(set(entities)|set(entities_bs))
+            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities
        # 2. Retrieve related papers
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
-            retriever_name,
-            config
+            retriever_name, config
+        )
+        result = rt.retrieve(
+            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
-        result = rt.retrieve(bg, entities_all, need_evaluate=False, target_paper_id_list=[])
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]
@@ -597,7 +602,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
            {
                "background": bg,
                "entities_bg": entities,
-                "brainstorm" : brainstorm,
+                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
@@ -621,5 +626,6 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)

+
if __name__ == "__main__":
    main()
src/paper_manager.py CHANGED
@@ -389,10 +389,8 @@ class PaperManager:
        )

        if need_summary:
-            print(paper.keys())
            if not self.check_parse(paper):
                logger.error(f"paper {paper['hash_id']} need parse first...")
-
            result = self.api_helper(
                paper["title"], paper["abstract"], paper["introduction"]
            )
@@ -628,9 +626,11 @@ class PaperManager:

    def insert_embedding(self, hash_id=None):
        self.paper_client.add_paper_abstract_embedding(self.embedding_model, hash_id)
-        # self.client.add_paper_bg_embedding(self.embedding_model, hash_id)
-        # self.client.add_paper_contribution_embedding(self.embedding_model, hash_id)
-        # self.client.add_paper_summary_embedding(self.embedding_model, hash_id)
+        # self.paper_client.add_paper_bg_embedding(self.embedding_model, hash_id)
+        # self.paper_client.add_paper_contribution_embedding(
+        #     self.embedding_model, hash_id
+        # )
+        # self.paper_client.add_paper_summary_embedding(self.embedding_model, hash_id)

    def cosine_similarity_search(self, data_type, context, k=1):
        """
@@ -837,8 +837,9 @@ def local(config_path, year, venue_name, output, **kwargs):
        os.makedirs(os.path.dirname(output_path))
    config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
    pm = PaperManager(config, venue_name, year)
+    print("###")
    pm.update_paper_from_json_to_json(
-        need_download=True, need_parse=True, need_summary=True, need_ground_truth=True
+        need_download=True, need_parse=True, need_summary=True
    )

@@ -41,9 +41,9 @@ def main(ctx):
41
  def retrieve(
42
  config_path, ids_path, **kwargs
43
  ):
44
- check_env()
45
- check_embedding()
46
  config = ConfigReader.load(config_path, **kwargs)
 
 
47
  log_dir = config.DEFAULT.log_dir
48
  retriever_name = config.RETRIEVE.retriever_name
49
  cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
 
41
  def retrieve(
42
  config_path, ids_path, **kwargs
43
  ):
 
 
44
  config = ConfigReader.load(config_path, **kwargs)
45
+ check_embedding(config.DEFAULT.embedding)
46
+ check_env()
47
  log_dir = config.DEFAULT.log_dir
48
  retriever_name = config.RETRIEVE.retriever_name
49
  cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
src/utils/api/__init__.py CHANGED
@@ -22,8 +22,10 @@ Creation Date : 2024-10-29

Author : Frank Kang([email protected])
"""
+
from .base_helper import HelperCompany
from .openai_helper import OpenAIHelper  # noqa: F401, ensure autoregister
from .zhipuai_helper import ZhipuAIHelper  # noqa: F401, ensure autoregister
+from .local_helper import LocalHelper  # noqa: F401, ensure autoregister

__all__ = ["HelperCompany"]
src/utils/api/base_helper.py CHANGED
@@ -17,6 +17,9 @@ from abc import ABCMeta
from typing_extensions import Literal, override
from ..base_company import BaseCompany
from typing import Union
+import requests
+import json
+from requests.exceptions import RequestException


class NotGiven:
@@ -109,6 +112,31 @@ class BaseHelper:
        self.base_url = base_url
        self.client = None

+    def apply_for_service(self, data_param, max_attempts=4):
+        attempt = 0
+        while attempt < max_attempts:
+            try:
+                # print(f"Attempt #{attempt + 1}")
+                r = requests.post(
+                    self.base_url + "/llm/generate",
+                    headers={"Content-Type": "application/json"},
+                    data=json.dumps(data_param),
+                )
+                # Check whether the request succeeded
+                if r.status_code == 200:
+                    # print("Service request succeeded.")
+                    response = r.json()["data"]["output"]
+                    return response  # or return something else as needed
+                else:
+                    print("Service request failed, status code:", response.status_code)
+            except RequestException as e:
+                print("Request error:", e)
+
+            attempt += 1
+            if attempt == max_attempts:
+                print("Reached the maximum number of attempts; the service request failed.")
+                return None  # or raise an exception instead, depending on your needs
+
    def create(
        self,
        *args,
@@ -124,7 +152,7 @@ class BaseHelper:
        extra_headers: None | NotGiven = None,
        extra_body: None | NotGiven = None,
        timeout: float | None | NotGiven = None,
-        **kwargs
+        **kwargs,
    ):
        """
        Creates a model response for the given chat conversation.
@@ -187,20 +215,44 @@ class BaseHelper:

        timeout: Override the client-level default timeout for this request, in seconds
        """
-        return self.client.chat.completions.create(
-            *args,
-            model=self.model,
-            messages=messages,
-            stream=stream,
-            temperature=temperature,
-            top_p=top_p,
-            max_tokens=max_tokens,
-            seed=seed,
-            stop=stop,
-            tools=tools,
-            tool_choice=tool_choice,
-            extra_headers=extra_headers,
-            extra_body=extra_body,
-            timeout=timeout,
-            **kwargs
-        )
+        if self.model != "local":
+            return (
+                self.client.chat.completions.create(
+                    *args,
+                    model=self.model,
+                    messages=messages,
+                    stream=stream,
+                    temperature=temperature,
+                    top_p=top_p,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    extra_headers=extra_headers,
+                    extra_body=extra_body,
+                    timeout=timeout,
+                    **kwargs,
+                )
+                .choices[0]
+                .message.content
+            )
+        else:
+            default_system = "You are a helpful assistant."
+            input_content = ""
+            for message in messages:
+                if message["role"] == "system":
+                    default_system = message["content"]
+                else:
+                    input_content += message["content"]
+            data_param = {}
+            data_param["input"] = input_content
+            data_param["serviceParams"] = {"stream": False, "system": default_system}
+            data_param["ModelParams"] = {
+                "temperature": 0.8,
+                "presence_penalty": 2.0,
+                "frequency_penalty": 0.0,
+                "top_p": 0.8,
+            }
+            response = self.apply_for_service(data_param)
+            return response
src/utils/api/local_helper.py ADDED
@@ -0,0 +1,39 @@
+r"""_summary_
+-*- coding: utf-8 -*-
+
+Module : data.utils.api.zhipuai_helper
+
+File Name : zhipuai_helper.py
+
+Description : Helper class for ZhipuAI interface, generally not used directly.
+    For example:
+    ```
+    from data.utils.api import HelperCompany
+    helper = HelperCompany.get()['ZhipuAI']
+    ...
+    ```
+
+Creation Date : 2024-11-28
+
+Author : lihuigu([email protected])
+"""
+
+from .base_helper import register_helper, BaseHelper
+
+
+@register_helper("Local")
+class LocalHelper(BaseHelper):
+    """_summary_
+
+    Helper class for ZhipuAI interface, generally not used directly.
+
+    For example:
+    ```
+    from data.utils.api import HelperCompany
+    helper = HelperCompany.get()['Local']
+    ...
+    ```
+    """
+
+    def __init__(self, api_key, model, base_url=None, timeout=None):
+        super().__init__(api_key, model, base_url)
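Putting the new pieces together, a minimal usage sketch of the Local route follows (illustrative only and not part of the commit: the endpoint URL is invented and the top-level import path is assumed; the environment variable names, the "Local"/"local" values, and the constructor call mirror check_env(), get_helper() in src/utils/llms_api.py, and BaseHelper.create()):

```python
# Illustrative sketch, not part of the diff. BASE_URL below is a made-up example.
import os

os.environ["MODEL_TYPE"] = "Local"   # selects LocalHelper via @register_helper("Local")
os.environ["MODEL_NAME"] = "local"   # makes create() POST to BASE_URL + "/llm/generate"
os.environ["BASE_URL"] = "http://localhost:8000"  # hypothetical local inference service

from utils.api import HelperCompany  # import path assumed from src/utils/api/__init__.py

helper = HelperCompany.get()["Local"]("", "local", os.environ["BASE_URL"], timeout=None)
text = helper.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the abstract in one sentence."},
    ]
)
print(text)  # create() now returns the message text directly rather than a ChatCompletion
```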
src/utils/hash.py CHANGED
@@ -12,18 +12,35 @@ ENV_CHECKED = False
EMBEDDING_CHECKED = False


-def check_embedding():
+def check_embedding(repo_id):
+    print("=== check embedding model ===")
    global EMBEDDING_CHECKED
    if not EMBEDDING_CHECKED:
        # Define the repository and files to download
-        repo_id = "sentence-transformers/all-MiniLM-L6-v2" # "BAAI/bge-small-en-v1.5"
        local_dir = f"./assets/model/{repo_id}"
-        files_to_download = [
-            "config.json",
-            "pytorch_model.bin",
-            "tokenizer_config.json",
-            "vocab.txt",
-        ]
+        if repo_id in [
+            "sentence-transformers/all-MiniLM-L6-v2",
+            "BAAI/bge-small-en-v1.5",
+            "BAAAI/llm_embedder",
+        ]:
+            # repo_id = "sentence-transformers/all-MiniLM-L6-v2"
+            # repo_id = "BAAI/bge-small-en-v1.5"
+            files_to_download = [
+                "config.json",
+                "pytorch_model.bin",
+                "tokenizer_config.json",
+                "vocab.txt",
+            ]
+        elif repo_id in ["Alibaba-NLP/gte-base-en-v1.5"]:
+            files_to_download = [
+                "config.json",
+                "model.safetensors",
+                "modules.json",
+                "tokenizer.json",
+                "sentence_bert_config.json",
+                "tokenizer_config.json",
+                "vocab.txt",
+            ]
        # Download each file and save it to the /model/bge directory
        for file_name in files_to_download:
            if not os.path.exists(os.path.join(local_dir, file_name)):
@@ -47,12 +64,15 @@ def check_env():
        "NEO4J_PASSWD",
        "MODEL_NAME",
        "MODEL_TYPE",
-        "MODEL_API_KEY",
        "BASE_URL",
    ]
    for env_name in env_name_list:
        if env_name not in os.environ or os.environ[env_name] == "":
            raise ValueError(f"{env_name} is not set...")
+    if os.environ["MODEL_TYPE"] != "Local":
+        env_name = "MODEL_API_KEY"
+        if env_name not in os.environ or os.environ[env_name] == "":
+            raise ValueError(f"{env_name} is not set...")
    ENV_CHECKED = True


@@ -61,16 +81,21 @@ class EmbeddingModel:

    def __new__(cls, config):
        if cls._instance is None:
+            local_dir = f"./assets/model/{config.DEFAULT.embedding}"
            cls._instance = super(EmbeddingModel, cls).__new__(cls)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            cls._instance.embedding_model = SentenceTransformer(
-                model_name_or_path=get_dir(config.DEFAULT.embedding),
+                model_name_or_path=get_dir(local_dir),
                device=device,
+                trust_remote_code=True,
            )
            print(f"==== using device {device} ====")
        return cls._instance

+
def get_embedding_model(config):
+    print("=== get embedding model ===")
+    check_embedding(config.DEFAULT.embedding)
    return EmbeddingModel(config).embedding_model

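Since check_embedding() now takes the embedding repo id from the config, a minimal usage sketch follows (assuming the same ConfigReader and YAML path used elsewhere in this commit; the repo id mentioned in the comment is just one of the values the new branches recognize):

```python
# Minimal sketch, assuming the config layout used in this repo (config.DEFAULT.embedding).
from utils.header import ConfigReader
from utils.hash import check_env, check_embedding

config = ConfigReader.load("./configs/datasets.yaml")
check_env()                                # MODEL_API_KEY is only required when MODEL_TYPE != "Local"
check_embedding(config.DEFAULT.embedding)  # e.g. "Alibaba-NLP/gte-base-en-v1.5" downloads model.safetensors etc.
```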
 
src/utils/llms_api.py CHANGED
@@ -49,7 +49,10 @@ class APIHelper(object):
    def get_helper(self):
        MODEL_TYPE = os.environ["MODEL_TYPE"]
        MODEL_NAME = os.environ["MODEL_NAME"]
-        MODEL_API_KEY = os.environ["MODEL_API_KEY"]
+        if MODEL_NAME != "local":
+            MODEL_API_KEY = os.environ["MODEL_API_KEY"]
+        else:
+            MODEL_API_KEY = ""
        BASE_URL = os.environ["BASE_URL"]
        return HelperCompany.get()[MODEL_TYPE](
            MODEL_API_KEY, MODEL_NAME, BASE_URL, timeout=None
@@ -64,6 +67,8 @@ class APIHelper(object):
            "glm4-air",
            "qwen-max",
            "qwen-plus",
+            "gpt-4o-mini",
+            "local",
        ]:
            raise ValueError(f"Check model name...")

@@ -78,13 +83,13 @@ class APIHelper(object):
            response1 = self.generator.create(
                messages=message,
            )
-            summary = clean_text(response1.choices[0].message.content)
+            summary = clean_text(response1)
            message.append({"role": "assistant", "content": summary})
            message.append(self.prompt.queries[1][0]())
            response2 = self.generator.create(
                messages=message,
            )
-            detail = response2.choices[0].message.content
+            detail = response2
            motivation = clean_text(detail.split(TAG_moti)[1].split(TAG_contr)[0])
            contribution = clean_text(detail.split(TAG_contr)[1])
            result = {
@@ -116,7 +121,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            entities = response.choices[0].message.content
+            entities = response
            entity_list = entities.strip().split(", ")
            clean_entity_list = []
            for entity in entity_list:
@@ -151,7 +156,7 @@ class APIHelper(object):
            response_brainstorming = self.generator.create(
                messages=message,
            )
-            brainstorming_ideas = response_brainstorming.choices[0].message.content
+            brainstorming_ideas = response_brainstorming

        except Exception:
            traceback.print_exc()
@@ -178,7 +183,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            problem = response.choices[0].message.content
+            problem = response
        except Exception:
            traceback.print_exc()
            return None
@@ -207,7 +212,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            problem = response.choices[0].message.content
+            problem = response
        except Exception:
            traceback.print_exc()
            return None
@@ -228,7 +233,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            inspiration = response.choices[0].message.content
+            inspiration = response
        except Exception:
            traceback.print_exc()
            return None
@@ -254,7 +259,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            inspiration = response.choices[0].message.content
+            inspiration = response
        except Exception:
            traceback.print_exc()
            return None
@@ -282,7 +287,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea = response.choices[0].message.content
+            idea = response
        except Exception:
            traceback.print_exc()
            return None
@@ -314,7 +319,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea = response.choices[0].message.content
+            idea = response
        except Exception:
            traceback.print_exc()
            return None
@@ -340,7 +345,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea = response.choices[0].message.content
+            idea = response
        except Exception:
            traceback.print_exc()
            return None
@@ -372,7 +377,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea = response.choices[0].message.content
+            idea = response
        except Exception:
            traceback.print_exc()
            return None
@@ -391,7 +396,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea = response.choices[0].message.content
+            idea = response
        except Exception:
            traceback.print_exc()
            return None
@@ -413,7 +418,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea_filtered = response.choices[0].message.content
+            idea_filtered = response
        except Exception:
            traceback.print_exc()
            return None
@@ -435,7 +440,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea_modified = response.choices[0].message.content
+            idea_modified = response
        except Exception:
            traceback.print_exc()
            return None
@@ -454,7 +459,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            ground_truth = response.choices[0].message.content
+            ground_truth = response
        except Exception:
            traceback.print_exc()
        return ground_truth
@@ -469,7 +474,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            idea_norm = response.choices[0].message.content
+            idea_norm = response
        except Exception:
            traceback.print_exc()
            return None
@@ -492,7 +497,7 @@ class APIHelper(object):
                messages=message,
                max_tokens=10,
            )
-            index = response.choices[0].message.content
+            index = response
        except Exception:
            traceback.print_exc()
            return None
@@ -509,7 +514,7 @@ class APIHelper(object):
                messages=message,
                max_tokens=10,
            )
-            score = response.choices[0].message.content
+            score = response
        except Exception:
            traceback.print_exc()
            return None
@@ -548,7 +553,7 @@ class APIHelper(object):
                stop=None,
                seed=0,
            )
-            content = response.choices[0].message.content
+            content = response
            new_msg_history = new_msg_history + [
                {"role": "assistant", "content": content}
            ]
@@ -577,7 +582,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            result = response.choices[0].message.content
+            result = response
        except Exception:
            traceback.print_exc()
            return None
@@ -601,7 +606,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            result = response.choices[0].message.content
+            result = response
        except Exception:
            traceback.print_exc()
            return None
@@ -625,7 +630,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            result = response.choices[0].message.content
+            result = response
        except Exception:
            traceback.print_exc()
            return None
@@ -649,7 +654,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            result = response.choices[0].message.content
+            result = response
        except Exception:
            traceback.print_exc()
            return None
@@ -673,7 +678,7 @@ class APIHelper(object):
            response = self.generator.create(
                messages=message,
            )
-            result = response.choices[0].message.content
+            result = response
        except Exception:
            traceback.print_exc()
            return None
src/utils/paper_client.py CHANGED
@@ -8,6 +8,7 @@ from collections import defaultdict, deque
8
  from py2neo import Graph, Node, Relationship
9
  from loguru import logger
10
 
 
11
  class PaperClient:
12
  _instance = None
13
  _initialized = False
@@ -43,10 +44,28 @@ class PaperClient:
43
  with self.driver.session() as session:
44
  result = session.execute_read(lambda tx: tx.run(query).data())
45
  if result:
46
- paper_from_client = result[0]['p']
47
  if paper_from_client is not None:
48
  paper.update(paper_from_client)
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def get_paper_attribute(self, paper_id, attribute_name):
51
  query = f"""
52
  MATCH (p:Paper {{hash_id: {paper_id}}})
@@ -55,11 +74,11 @@ class PaperClient:
55
  with self.driver.session() as session:
56
  result = session.execute_read(lambda tx: tx.run(query).data())
57
  if result:
58
- return result[0]['attributeValue']
59
  else:
60
  logger.error(f"paper id {paper_id} get {attribute_name} failed.")
61
  return None
62
-
63
  def get_paper_by_attribute(self, attribute_name, anttribute_value):
64
  query = f"""
65
  MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
@@ -68,7 +87,7 @@ class PaperClient:
68
  with self.driver.session() as session:
69
  result = session.execute_read(lambda tx: tx.run(query).data())
70
  if result:
71
- return result[0]['p']
72
  else:
73
  return None
74
 
@@ -81,71 +100,50 @@ class PaperClient:
81
  RETURN p.hash_id as hash_id
82
  """
83
  with self.driver.session() as session:
84
- result = session.execute_read(lambda tx: tx.run(query, entity=entity).data())
 
 
85
  if result:
86
- return [record['hash_id'] for record in result]
87
  else:
88
  return []
89
-
90
- def find_related_entities_by_entity(self, entity_name, n=1, k=3, relation_name="related"):
91
- # relation_name = "related"
92
- def bfs_query(entity_name, n, k):
93
- queue = deque([(entity_name, 0)])
94
- visited = set([entity_name])
95
- related_entities = set()
96
-
97
- while queue:
98
- batch_queue = [queue.popleft() for _ in range(len(queue))]
99
- batch_entities = [item[0] for item in batch_queue]
100
- batch_depths = [item[1] for item in batch_queue]
101
-
102
- if all(depth >= n for depth in batch_depths):
103
- continue
104
- if relation_name == "related":
105
- query = """
106
- UNWIND $batch_entities AS entity_name
107
- MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
108
- WHERE e1 <> e2
109
- WITH e1, e2, COUNT(p) AS common_papers, entity_name
110
- WHERE common_papers > $k
111
- RETURN e2.name AS entities, entity_name AS source_entity, common_papers
112
- """
113
- elif relation_name == "connect":
114
- query = """
115
- UNWIND $batch_entities AS entity_name
116
- MATCH (e1:Entity {name: entity_name})-[r:CONNECT]-(e2:Entity)
117
- WHERE e1 <> e2 and r.strength >= $k
118
- WITH e1, e2, entity_name
119
- RETURN e2.name AS entities, entity_name AS source_entity
120
- """
121
- with self.driver.session() as session:
122
- result = session.execute_read(lambda tx: tx.run(query, batch_entities=batch_entities, k=k).data())
123
-
124
- for record in result:
125
- entity = record['entities']
126
- source_entity = record['source_entity']
127
- if entity not in visited:
128
- visited.add(entity)
129
- queue.append((entity, batch_depths[batch_entities.index(source_entity)] + 1))
130
- related_entities.add(entity)
131
-
132
- return list(related_entities)
133
- related_entities = bfs_query(entity_name, n, k)
134
- if entity_name in related_entities:
135
- related_entities.remove(entity_name)
136
- return related_entities
137
-
138
- def find_entities_by_paper(self, hash_id: int):
139
  query = """
140
- MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: $hash_id})
141
- RETURN e.name AS entity_name
 
 
 
 
142
  """
143
  with self.driver.session() as session:
144
- result = session.execute_read(lambda tx: tx.run(query, hash_id=hash_id).data())
145
- if result:
146
- return [record['entity_name'] for record in result]
147
- else:
148
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  def find_paper_by_entity(self, entity_name):
151
  query = """
@@ -153,18 +151,19 @@ class PaperClient:
153
  RETURN p.hash_id AS hash_id
154
  """
155
  with self.driver.session() as session:
156
- result = session.execute_read(lambda tx: tx.run(query, entity_name=entity_name).data())
 
 
157
  if result:
158
- return [record['hash_id'] for record in result]
159
  else:
160
  return []
161
-
162
  # TODO: @云翔
163
  # 增加通过entity返回包含entity语句的功能
164
  def find_sentence_by_entity(self, entity_name):
165
  # Return: list(str)
166
  return []
167
-
168
 
169
  def find_sentences_by_entity(self, entity_name):
170
  query = """
@@ -178,14 +177,25 @@ class PaperClient:
178
  p.hash_id AS hash_id
179
  """
180
  sentences = []
181
-
182
  with self.driver.session() as session:
183
- result = session.execute_read(lambda tx: tx.run(query, entity_name=entity_name).data())
 
 
184
  for record in result:
185
- for key in ['abstract', 'introduction', 'methodology']:
186
  if record[key]:
187
- filtered_sentences = [sentence.strip() + '.' for sentence in record[key].split('.') if entity_name in sentence]
188
- sentences.extend([f"{record['hash_id']}: {sentence}" for sentence in filtered_sentences])
 
 
189
 
190
  return sentences
191
 
@@ -194,9 +204,11 @@ class PaperClient:
194
  MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
195
  """
196
  with self.driver.session() as session:
197
- result = session.execute_read(lambda tx: tx.run(query, year=year, venue_name=venue_name).data())
 
 
198
  if result:
199
- return [record['n'] for record in result]
200
  else:
201
  return []
202
 
@@ -230,7 +242,26 @@ class PaperClient:
230
  RETURN p
231
  """
232
  with self.driver.session() as session:
233
- result = session.execute_write(lambda tx: tx.run(query, hash_id=paper["hash_id"], venue_name=paper["venue_name"], year=paper["year"], title=paper["title"], pdf_url=paper["pdf_url"], abstract=paper["abstract"], introduction=paper["introduction"], reference=paper["reference"], summary=paper["summary"], motivation=paper["motivation"], contribution=paper["contribution"], methodology=paper["methodology"], ground_truth=paper["ground_truth"], reference_filter=paper["reference_filter"], conclusions=paper["conclusions"]).data())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  def check_entity_node_count(self, hash_id: int):
236
  query_check_count = """
@@ -239,7 +270,9 @@ class PaperClient:
239
  """
240
  with self.driver.session() as session:
241
  # Check the number of related entities
242
- result = session.execute_read(lambda tx: tx.run(query_check_count, hash_id=hash_id).data())
 
 
243
  if result[0]["entity_count"] > 3:
244
  return False
245
  return True
@@ -254,16 +287,30 @@ class PaperClient:
254
  """
255
  with self.driver.session() as session:
256
  for entity_name in entities:
257
- result = session.execute_write(lambda tx: tx.run(query, entity_name=entity_name, hash_id=hash_id).data())
258
-
 
 
 
 
259
  def add_paper_citation(self, paper: dict):
260
  query = """
261
  MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
262
  """
263
  with self.driver.session() as session:
264
- result = session.execute_write(lambda tx: tx.run(query, hash_id=paper["hash_id"], cite_id_list=paper["cite_id_list"], entities=paper["entities"], all_cite_id_list=paper["all_cite_id_list"]).data())
265
-
266
- def add_paper_abstract_embedding(self, embedding_model, hash_id=None):
 
 
267
  if hash_id is not None:
268
  query = """
269
  MATCH (p:Paper {hash_id: $hash_id})
@@ -271,119 +318,302 @@ class PaperClient:
271
  RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
272
  """
273
  with self.driver.session() as session:
274
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
275
- else:
 
 
 
 
 
 
276
  query = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  MATCH (p:Paper)
278
  WHERE p.abstract IS NOT NULL
279
  RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
 
280
  """
281
  with self.driver.session() as session:
282
- results = session.execute_write(lambda tx: tx.run(query).data())
283
- contexts = [result["title"] + result["context"] for result in results]
284
- paper_ids = [result["hash_id"] for result in results]
285
- context_embeddings = embedding_model.encode(contexts, batch_size=512, convert_to_tensor=True, device=self.device)
286
- query = """
287
- MERGE (p:Paper {hash_id: $hash_id})
288
- ON CREATE SET p.abstract_embedding = $embedding
289
- ON MATCH SET p.abstract_embedding = $embedding
290
- """
291
- for idx, hash_id in tqdm(enumerate(paper_ids)):
292
- embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
 
 
293
  with self.driver.session() as session:
294
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
 
 
 
 
295
 
296
- def add_paper_bg_embedding(self, embedding_model, hash_id=None):
297
  if hash_id is not None:
298
  query = """
299
  MATCH (p:Paper {hash_id: $hash_id})
300
  WHERE p.motivation IS NOT NULL
301
- RETURN p.motivation AS context, p.hash_id AS hash_id
302
  """
303
  with self.driver.session() as session:
304
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
305
- else:
 
 
 
 
 
 
306
  query = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  MATCH (p:Paper)
308
  WHERE p.motivation IS NOT NULL
309
- RETURN p.motivation AS context, p.hash_id AS hash_id
 
310
  """
311
  with self.driver.session() as session:
312
- results = session.execute_write(lambda tx: tx.run(query).data())
313
- contexts = [result["context"] for result in results]
314
- paper_ids = [result["hash_id"] for result in results]
315
- context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
316
- query = """
317
- MERGE (p:Paper {hash_id: $hash_id})
318
- ON CREATE SET p.embedding = $embedding
319
- ON MATCH SET p.embedding = $embedding
320
- """
321
- for idx, hash_id in tqdm(enumerate(paper_ids)):
322
- embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
 
 
323
  with self.driver.session() as session:
324
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
325
-
326
- def add_paper_contribution_embedding(self, embedding_model, hash_id=None):
 
 
 
 
 
 
327
  if hash_id is not None:
328
  query = """
329
  MATCH (p:Paper {hash_id: $hash_id})
330
  WHERE p.contribution IS NOT NULL
331
- RETURN p.contribution AS context, p.hash_id AS hash_id
332
  """
333
  with self.driver.session() as session:
334
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
335
- else:
 
 
 
 
 
 
336
  query = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  MATCH (p:Paper)
338
  WHERE p.contribution IS NOT NULL
339
- RETURN p.contribution AS context, p.hash_id AS hash_id
 
340
  """
341
  with self.driver.session() as session:
342
- results = session.execute_write(lambda tx: tx.run(query).data())
343
- contexts = [result["context"] for result in results]
344
- paper_ids = [result["hash_id"] for result in results]
345
- context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
346
- query = """
347
- MERGE (p:Paper {hash_id: $hash_id})
348
- ON CREATE SET p.contribution_embedding = $embedding
349
- ON MATCH SET p.contribution_embedding = $embedding
350
- """
351
- for idx, hash_id in tqdm(enumerate(paper_ids)):
352
- embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
 
 
353
  with self.driver.session() as session:
354
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
355
-
356
-
357
- def add_paper_summary_embedding(self, embedding_model, hash_id=None):
 
 
 
 
 
358
  if hash_id is not None:
359
  query = """
360
  MATCH (p:Paper {hash_id: $hash_id})
361
  WHERE p.summary IS NOT NULL
362
- RETURN p.summary AS context, p.hash_id AS hash_id
363
  """
364
  with self.driver.session() as session:
365
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
366
- else:
 
 
 
 
 
 
367
  query = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  MATCH (p:Paper)
369
  WHERE p.summary IS NOT NULL
370
- RETURN p.summary AS context, p.hash_id AS hash_id
 
371
  """
372
  with self.driver.session() as session:
373
- results = session.execute_write(lambda tx: tx.run(query).data())
374
- contexts = [result["context"] for result in results]
375
- paper_ids = [result["hash_id"] for result in results]
376
- context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
377
- query = """
378
- MERGE (p:Paper {hash_id: $hash_id})
379
- ON CREATE SET p.summary_embedding = $embedding
380
- ON MATCH SET p.summary_embedding = $embedding
381
- """
382
- for idx, hash_id in tqdm(enumerate(paper_ids)):
383
- embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
 
 
384
  with self.driver.session() as session:
385
- results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
386
-
 
 
 
 
387
  def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
388
  query = f"""
389
  MATCH (paper:Paper)
@@ -394,8 +624,10 @@ class PaperClient:
394
  ORDER BY score DESC LIMIT {k}
395
  """
396
  with self.driver.session() as session:
397
- results = session.execute_read(lambda tx: tx.run(query, embedding=embedding).data())
398
- related_paper = []
 
 
399
  for result in results:
400
  related_paper.append(result["paper"]["hash_id"])
401
  return related_paper
@@ -417,7 +649,7 @@ class PaperClient:
417
  """
418
  with self.driver.session() as session:
419
  session.execute_write(lambda tx: tx.run(query).data())
420
-
421
  def filter_paper_id_list(self, paper_id_list, year="2024"):
422
  if not paper_id_list:
423
  return []
@@ -429,12 +661,14 @@ class PaperClient:
429
  RETURN p.hash_id AS hash_id
430
  """
431
  with self.driver.session() as session:
432
- result = session.execute_read(lambda tx: tx.run(query, paper_id_list=paper_id_list, year=year).data())
433
-
434
- existing_paper_ids = [record['hash_id'] for record in result]
 
 
435
  existing_paper_ids = list(set(existing_paper_ids))
436
  return existing_paper_ids
437
-
438
  def check_index_exists(self):
439
  query = "SHOW INDEXES"
440
  with self.driver.session() as session:
@@ -451,7 +685,7 @@ class PaperClient:
451
  """
452
  with self.driver.session() as session:
453
  session.execute_write(lambda tx: tx.run(query).data())
454
-
455
  def get_entity_related_paper_num(self, entity_name):
456
  query = """
457
  MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
@@ -459,10 +693,30 @@ class PaperClient:
459
  RETURN PaperCount
460
  """
461
  with self.driver.session() as session:
462
- result = session.execute_read(lambda tx: tx.run(query, entity_name=entity_name).data())
463
- paper_num = result[0]['PaperCount']
 
 
464
  return paper_num
465
 
 
466
  def get_entity_text(self):
467
  query = """
468
  MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
@@ -472,11 +726,13 @@ class PaperClient:
472
  """
473
  with self.driver.session() as session:
474
  result = session.execute_read(lambda tx: tx.run(query).data())
475
- text_list = [record['entity_text'] for record in result]
476
  return text_list
477
-
478
  def get_entity_combinations(self, venue_name, year):
479
- def process_paper_relationships(session, entity_name_1, entity_name_2, abstract):
 
 
480
  if entity_name_2 < entity_name_1:
481
  entity_name_1, entity_name_2 = entity_name_2, entity_name_1
482
  query = """
@@ -486,13 +742,17 @@ class PaperClient:
486
  ON CREATE SET r.strength = 1
487
  ON MATCH SET r.strength = r.strength + 1
488
  """
489
- sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', abstract)
490
  for sentence in sentences:
491
  sentence = sentence.lower()
492
  if entity_name_1 in sentence and entity_name_2 in sentence:
493
  # If both entities appear in the same sentence, create or update the CONNECT relationship
494
  session.execute_write(
495
- lambda tx: tx.run(query, entity_name_1=entity_name_1, entity_name_2=entity_name_2).data()
 
 
 
 
496
  )
497
  # logger.debug(f"CONNECT relation created or updated between {entity_name_1} and {entity_name_2} for Paper ID {paper_id}")
498
  break # one co-occurrence is enough, exit the loop
@@ -506,13 +766,17 @@ class PaperClient:
506
  RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
507
  """
508
  with self.driver.session() as session:
509
- result = session.execute_read(lambda tx: tx.run(query, venue_name=venue_name, year=year).data())
 
 
510
  for record in tqdm(result):
511
  paper_id = record["hash_id"]
512
- entity_name_1 = record['entity_name_1']
513
- entity_name_2 = record['entity_name_2']
514
  abstract = self.get_paper_attribute(paper_id, "abstract")
515
- process_paper_relationships(session, entity_name_1, entity_name_2, abstract)
 
 
516
 
517
  def build_citemap(self):
518
  citemap = defaultdict(set)
@@ -523,8 +787,8 @@ class PaperClient:
523
  with self.driver.session() as session:
524
  results = session.execute_read(lambda tx: tx.run(query).data())
525
  for result in results:
526
- hash_id = result['hash_id']
527
- cite_id_list = result['cite_id_list']
528
  if cite_id_list:
529
  for cited_id in cite_id_list:
530
  citemap[hash_id].add(cited_id)
@@ -537,12 +801,17 @@ class PaperClient:
537
  AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
538
  graph = Graph(URI, auth=AUTH)
539
  # Create a dictionary to hold the exported data
 
540
  data = {"nodes": [], "relationships": []}
541
- query = """
 
 
 
 
542
  MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
543
- WHERE p.venue_name='iclr' and p.year='2024'
544
  RETURN p, e, r
545
  """
 
546
  results = graph.run(query)
547
  # Process the query results
548
  for record in tqdm(results):
@@ -550,39 +819,46 @@ class PaperClient:
550
  entity_node = record["e"]
551
  relationship = record["r"]
552
  # Add the node data to the dictionary
553
- data["nodes"].append({
554
- "id": paper_node.identity,
555
- "label": "Paper",
556
- "properties": dict(paper_node)
557
- })
558
- data["nodes"].append({
559
- "id": entity_node.identity,
560
- "label": "Entity",
561
- "properties": dict(entity_node)
562
- })
 
 
 
 
563
  # Add the relationship data to the dictionary
564
- data["relationships"].append({
565
- "start_node": entity_node.identity,
566
- "end_node": paper_node.identity,
567
- "type": "RELATED_TO",
568
- "properties": dict(relationship)
569
- })
 
 
 
570
  query = """
571
  MATCH (p:Paper)
572
  WHERE p.venue_name='acl' and p.year='2024'
573
  RETURN p
574
  """
575
- """
576
  results = graph.run(query)
577
  for record in tqdm(results):
578
  paper_node = record["p"]
579
  # Add the node data to the dictionary
580
- data["nodes"].append({
581
- "id": paper_node.identity,
582
- "label": "Paper",
583
- "properties": dict(paper_node)
584
- })
585
- """
 
586
  # 去除重复节点
587
  # data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
588
  unique_nodes = []
@@ -595,9 +871,11 @@ class PaperClient:
595
  unique_nodes.append(node)
596
  data["nodes"] = unique_nodes
597
  # Save the data as a JSON file
598
- with open("./assets/data/scipip_neo4j_clean_backup.json", "w", encoding="utf-8") as f:
 
 
599
  json.dump(data, f, ensure_ascii=False, indent=4)
600
-
601
  def neo4j_import_data(self):
602
  # clear_database() # wipes the database; run with caution
603
  URI = os.environ["NEO4J_URL"]
@@ -606,7 +884,9 @@ class PaperClient:
606
  AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
607
  graph = Graph(URI, auth=AUTH)
608
  # Read the data back from the JSON file
609
- with open("./assets/data/scipip_neo4j_clean_backup.json", "r", encoding="utf-8") as f:
 
 
610
  data = json.load(f)
611
  # Create the nodes
612
  nodes = {}
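
Note on the pattern: the removals above drop per-entity BFS loops and per-paper lookups that issued one Cypher query per item; the added lines listed below batch those lookups with UNWIND so each call becomes a single round trip. A minimal sketch of that read pattern, assuming the official neo4j Python driver and the Entity/Paper schema used in this file; the helper name entities_for_papers and the connection setup are illustrative only (the NEO4J_* environment variables are the ones used elsewhere in this class):

import os
from neo4j import GraphDatabase

# Illustrative connection setup (not part of the diff).
driver = GraphDatabase.driver(
    os.environ["NEO4J_URL"],
    auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWD"]),
)

def entities_for_papers(hash_ids):
    # One UNWIND query for the whole list instead of len(hash_ids) separate queries.
    query = """
    UNWIND $hash_ids AS hash_id
    MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: hash_id})
    RETURN hash_id, e.name AS entity_name
    """
    with driver.session() as session:
        records = session.execute_read(
            lambda tx: tx.run(query, hash_ids=hash_ids).data()
        )
    return [record["entity_name"] for record in records]
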
 
8
  from py2neo import Graph, Node, Relationship
9
  from loguru import logger
10
 
11
+
12
  class PaperClient:
13
  _instance = None
14
  _initialized = False
 
44
  with self.driver.session() as session:
45
  result = session.execute_read(lambda tx: tx.run(query).data())
46
  if result:
47
+ paper_from_client = result[0]["p"]
48
  if paper_from_client is not None:
49
  paper.update(paper_from_client)
50
+
51
+ def update_papers_from_client(self, paper_id_list):
52
+ query = """
53
+ UNWIND $papers AS paper
54
+ MATCH (p:Paper {hash_id: paper.hash_id})
55
+ RETURN p as result
56
+ """
57
+ paper_data = [
58
+ {
59
+ "hash_id": hash_id,
60
+ }
61
+ for hash_id in paper_id_list
62
+ ]
63
+ with self.driver.session() as session:
64
+ result = session.execute_read(
65
+ lambda tx: tx.run(query, papers=paper_data).data()
66
+ )
67
+ return [r["result"] for r in result]
68
+
69
  def get_paper_attribute(self, paper_id, attribute_name):
70
  query = f"""
71
  MATCH (p:Paper {{hash_id: {paper_id}}})
 
74
  with self.driver.session() as session:
75
  result = session.execute_read(lambda tx: tx.run(query).data())
76
  if result:
77
+ return result[0]["attributeValue"]
78
  else:
79
  logger.error(f"paper id {paper_id} get {attribute_name} failed.")
80
  return None
81
+
82
  def get_paper_by_attribute(self, attribute_name, anttribute_value):
83
  query = f"""
84
  MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
 
87
  with self.driver.session() as session:
88
  result = session.execute_read(lambda tx: tx.run(query).data())
89
  if result:
90
+ return result[0]["p"]
91
  else:
92
  return None
93
 
 
100
  RETURN p.hash_id as hash_id
101
  """
102
  with self.driver.session() as session:
103
+ result = session.execute_read(
104
+ lambda tx: tx.run(query, entity=entity).data()
105
+ )
106
  if result:
107
+ return [record["hash_id"] for record in result]
108
  else:
109
  return []
110
+
111
+ def find_related_entities_by_entity_list(
112
+ self, entity_names, n=1, k=3, relation_name="related"
113
+ ):
114
+ related_entities = set()
 
 
115
  query = """
116
+ UNWIND $batch_entities AS entity_name
117
+ MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
118
+ WHERE e1 <> e2
119
+ WITH e1, e2, COUNT(p) AS common_papers, entity_name
120
+ WHERE common_papers > $k
121
+ RETURN e2.name AS entities, entity_name AS source_entity, common_papers
122
  """
123
  with self.driver.session() as session:
124
+ result = session.execute_read(
125
+ lambda tx: tx.run(query, batch_entities=entity_names, k=k).data()
126
+ )
127
+ for record in result:
128
+ entity = record["entities"]
129
+ related_entities.add(entity)
130
+ return list(related_entities)
131
+
132
+ def find_entities_by_paper_list(self, hash_ids: list):
133
+ query = """
134
+ UNWIND $hash_ids AS hash_id
135
+ MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: hash_id})
136
+ RETURN hash_id, e.name AS entity_name
137
+ """
138
+ with self.driver.session() as session:
139
+ result = session.execute_read(
140
+ lambda tx: tx.run(query, hash_ids=hash_ids).data()
141
+ )
142
+ # Group entities by hash_id
143
+ entity_list = []
144
+ for record in result:
145
+ entity_list.append(record["entity_name"])
146
+ return entity_list
147
 
148
  def find_paper_by_entity(self, entity_name):
149
  query = """
 
151
  RETURN p.hash_id AS hash_id
152
  """
153
  with self.driver.session() as session:
154
+ result = session.execute_read(
155
+ lambda tx: tx.run(query, entity_name=entity_name).data()
156
+ )
157
  if result:
158
+ return [record["hash_id"] for record in result]
159
  else:
160
  return []
161
+
162
  # TODO: @云翔
163
  # Add support for returning the sentences that contain a given entity
164
  def find_sentence_by_entity(self, entity_name):
165
  # Return: list(str)
166
  return []
 
167
 
168
  def find_sentences_by_entity(self, entity_name):
169
  query = """
 
177
  p.hash_id AS hash_id
178
  """
179
  sentences = []
180
+
181
  with self.driver.session() as session:
182
+ result = session.execute_read(
183
+ lambda tx: tx.run(query, entity_name=entity_name).data()
184
+ )
185
  for record in result:
186
+ for key in ["abstract", "introduction", "methodology"]:
187
  if record[key]:
188
+ filtered_sentences = [
189
+ sentence.strip() + "."
190
+ for sentence in record[key].split(".")
191
+ if entity_name in sentence
192
+ ]
193
+ sentences.extend(
194
+ [
195
+ f"{record['hash_id']}: {sentence}"
196
+ for sentence in filtered_sentences
197
+ ]
198
+ )
199
 
200
  return sentences
201
 
 
204
  MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
205
  """
206
  with self.driver.session() as session:
207
+ result = session.execute_read(
208
+ lambda tx: tx.run(query, year=year, venue_name=venue_name).data()
209
+ )
210
  if result:
211
+ return [record["n"] for record in result]
212
  else:
213
  return []
214
 
 
242
  RETURN p
243
  """
244
  with self.driver.session() as session:
245
+ result = session.execute_write(
246
+ lambda tx: tx.run(
247
+ query,
248
+ hash_id=paper["hash_id"],
249
+ venue_name=paper["venue_name"],
250
+ year=paper["year"],
251
+ title=paper["title"],
252
+ pdf_url=paper["pdf_url"],
253
+ abstract=paper["abstract"],
254
+ introduction=paper["introduction"],
255
+ reference=paper["reference"],
256
+ summary=paper["summary"],
257
+ motivation=paper["motivation"],
258
+ contribution=paper["contribution"],
259
+ methodology=paper["methodology"],
260
+ ground_truth=paper["ground_truth"],
261
+ reference_filter=paper["reference_filter"],
262
+ conclusions=paper["conclusions"],
263
+ ).data()
264
+ )
265
 
266
  def check_entity_node_count(self, hash_id: int):
267
  query_check_count = """
 
270
  """
271
  with self.driver.session() as session:
272
  # Check the number of related entities
273
+ result = session.execute_read(
274
+ lambda tx: tx.run(query_check_count, hash_id=hash_id).data()
275
+ )
276
  if result[0]["entity_count"] > 3:
277
  return False
278
  return True
 
287
  """
288
  with self.driver.session() as session:
289
  for entity_name in entities:
290
+ result = session.execute_write(
291
+ lambda tx: tx.run(
292
+ query, entity_name=entity_name, hash_id=hash_id
293
+ ).data()
294
+ )
295
+
296
  def add_paper_citation(self, paper: dict):
297
  query = """
298
  MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
299
  """
300
  with self.driver.session() as session:
301
+ result = session.execute_write(
302
+ lambda tx: tx.run(
303
+ query,
304
+ hash_id=paper["hash_id"],
305
+ cite_id_list=paper["cite_id_list"],
306
+ entities=paper["entities"],
307
+ all_cite_id_list=paper["all_cite_id_list"],
308
+ ).data()
309
+ )
310
+
311
+ def add_paper_abstract_embedding(
312
+ self, embedding_model, hash_id=None, batch_size=512
313
+ ):
314
  if hash_id is not None:
315
  query = """
316
  MATCH (p:Paper {hash_id: $hash_id})
 
318
  RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
319
  """
320
  with self.driver.session() as session:
321
+ results = session.execute_write(
322
+ lambda tx: tx.run(query, hash_id=hash_id).data()
323
+ )
324
+ contexts = [result["title"] + result["context"] for result in results]
325
+ paper_ids = [result["hash_id"] for result in results]
326
+ context_embeddings = embedding_model.encode(
327
+ contexts, convert_to_tensor=True, device=self.device
328
+ )
329
  query = """
330
+ MERGE (p:Paper {hash_id: $hash_id})
331
+ ON CREATE SET p.abstract_embedding = $embedding
332
+ ON MATCH SET p.abstract_embedding = $embedding
333
+ """
334
+ for idx, hash_id in tqdm(enumerate(paper_ids)):
335
+ embedding = (
336
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
337
+ )
338
+ with self.driver.session() as session:
339
+ results = session.execute_write(
340
+ lambda tx: tx.run(
341
+ query, hash_id=hash_id, embedding=embedding
342
+ ).data()
343
+ )
344
+ return
345
+ offset = 0
346
+ while True:
347
+ query = f"""
348
  MATCH (p:Paper)
349
  WHERE p.abstract IS NOT NULL
350
  RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
351
+ SKIP $offset LIMIT $batch_size
352
  """
353
  with self.driver.session() as session:
354
+ results = session.execute_write(
355
+ lambda tx: tx.run(
356
+ query, offset=offset, batch_size=batch_size
357
+ ).data()
358
+ )
359
+ if not results:
360
+ break
361
+ contexts = [result["title"] + result["context"] for result in results]
362
+ paper_ids = [result["hash_id"] for result in results]
363
+ context_embeddings = embedding_model.encode(
364
+ contexts,
365
+ batch_size=batch_size,
366
+ convert_to_tensor=True,
367
+ device=self.device,
368
+ )
369
+ write_query = """
370
+ UNWIND $data AS row
371
+ MERGE (p:Paper {hash_id: row.hash_id})
372
+ ON CREATE SET p.abstract_embedding = row.embedding
373
+ ON MATCH SET p.abstract_embedding = row.embedding
374
+ """
375
+ data_to_write = []
376
+ for idx, hash_id in enumerate(paper_ids):
377
+ embedding = (
378
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
379
+ )
380
+ data_to_write.append({"hash_id": hash_id, "embedding": embedding})
381
  with self.driver.session() as session:
382
+ session.execute_write(
383
+ lambda tx: tx.run(write_query, data=data_to_write)
384
+ )
385
+ offset += batch_size
386
+ logger.info(f"== Processed batch starting at offset {offset} ==")
387
 
388
+ def add_paper_bg_embedding(self, embedding_model, hash_id=None, batch_size=512):
389
  if hash_id is not None:
390
  query = """
391
  MATCH (p:Paper {hash_id: $hash_id})
392
  WHERE p.motivation IS NOT NULL
393
+ RETURN p.motivation AS context, p.hash_id AS hash_id, p.title AS title
394
  """
395
  with self.driver.session() as session:
396
+ results = session.execute_write(
397
+ lambda tx: tx.run(query, hash_id=hash_id).data()
398
+ )
399
+ contexts = [result["context"] for result in results]
400
+ paper_ids = [result["hash_id"] for result in results]
401
+ context_embeddings = embedding_model.encode(
402
+ contexts, convert_to_tensor=True, device=self.device
403
+ )
404
  query = """
405
+ MERGE (p:Paper {hash_id: $hash_id})
406
+ ON CREATE SET p.motivation_embedding = $embedding
407
+ ON MATCH SET p.motivation_embedding = $embedding
408
+ """
409
+ for idx, hash_id in tqdm(enumerate(paper_ids)):
410
+ embedding = (
411
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
412
+ )
413
+ with self.driver.session() as session:
414
+ results = session.execute_write(
415
+ lambda tx: tx.run(
416
+ query, hash_id=hash_id, embedding=embedding
417
+ ).data()
418
+ )
419
+ return
420
+ offset = 0
421
+ while True:
422
+ query = f"""
423
  MATCH (p:Paper)
424
  WHERE p.motivation IS NOT NULL
425
+ RETURN p.motivation AS context, p.hash_id AS hash_id, p.title AS title
426
+ SKIP $offset LIMIT $batch_size
427
  """
428
  with self.driver.session() as session:
429
+ results = session.execute_write(
430
+ lambda tx: tx.run(
431
+ query, offset=offset, batch_size=batch_size
432
+ ).data()
433
+ )
434
+ if not results:
435
+ break
436
+ contexts = [result["title"] + result["context"] for result in results]
437
+ paper_ids = [result["hash_id"] for result in results]
438
+ context_embeddings = embedding_model.encode(
439
+ contexts,
440
+ batch_size=batch_size,
441
+ convert_to_tensor=True,
442
+ device=self.device,
443
+ )
444
+ write_query = """
445
+ UNWIND $data AS row
446
+ MERGE (p:Paper {hash_id: row.hash_id})
447
+ ON CREATE SET p.motivation_embedding = row.embedding
448
+ ON MATCH SET p.motivation_embedding = row.embedding
449
+ """
450
+ data_to_write = []
451
+ for idx, hash_id in enumerate(paper_ids):
452
+ embedding = (
453
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
454
+ )
455
+ data_to_write.append({"hash_id": hash_id, "embedding": embedding})
456
  with self.driver.session() as session:
457
+ session.execute_write(
458
+ lambda tx: tx.run(write_query, data=data_to_write)
459
+ )
460
+ offset += batch_size
461
+ logger.info(f"== Processed batch starting at offset {offset} ==")
462
+
463
+ def add_paper_contribution_embedding(
464
+ self, embedding_model, hash_id=None, batch_size=512
465
+ ):
466
  if hash_id is not None:
467
  query = """
468
  MATCH (p:Paper {hash_id: $hash_id})
469
  WHERE p.contribution IS NOT NULL
470
+ RETURN p.contribution AS context, p.hash_id AS hash_id, p.title AS title
471
  """
472
  with self.driver.session() as session:
473
+ results = session.execute_write(
474
+ lambda tx: tx.run(query, hash_id=hash_id).data()
475
+ )
476
+ contexts = [result["context"] for result in results]
477
+ paper_ids = [result["hash_id"] for result in results]
478
+ context_embeddings = embedding_model.encode(
479
+ contexts, convert_to_tensor=True, device=self.device
480
+ )
481
  query = """
482
+ MERGE (p:Paper {hash_id: $hash_id})
483
+ ON CREATE SET p.contribution_embedding = $embedding
484
+ ON MATCH SET p.contribution_embedding = $embedding
485
+ """
486
+ for idx, hash_id in tqdm(enumerate(paper_ids)):
487
+ embedding = (
488
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
489
+ )
490
+ with self.driver.session() as session:
491
+ results = session.execute_write(
492
+ lambda tx: tx.run(
493
+ query, hash_id=hash_id, embedding=embedding
494
+ ).data()
495
+ )
496
+ return
497
+ offset = 0
498
+ while True:
499
+ query = f"""
500
  MATCH (p:Paper)
501
  WHERE p.contribution IS NOT NULL
502
+ RETURN p.contribution AS context, p.hash_id AS hash_id, p.title AS title
503
+ SKIP $offset LIMIT $batch_size
504
  """
505
  with self.driver.session() as session:
506
+ results = session.execute_write(
507
+ lambda tx: tx.run(
508
+ query, offset=offset, batch_size=batch_size
509
+ ).data()
510
+ )
511
+ if not results:
512
+ break
513
+ contexts = [result["context"] for result in results]
514
+ paper_ids = [result["hash_id"] for result in results]
515
+ context_embeddings = embedding_model.encode(
516
+ contexts,
517
+ batch_size=batch_size,
518
+ convert_to_tensor=True,
519
+ device=self.device,
520
+ )
521
+ write_query = """
522
+ UNWIND $data AS row
523
+ MERGE (p:Paper {hash_id: row.hash_id})
524
+ ON CREATE SET p.contribution_embedding = row.embedding
525
+ ON MATCH SET p.contribution_embedding = row.embedding
526
+ """
527
+ data_to_write = []
528
+ for idx, hash_id in enumerate(paper_ids):
529
+ embedding = (
530
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
531
+ )
532
+ data_to_write.append({"hash_id": hash_id, "embedding": embedding})
533
  with self.driver.session() as session:
534
+ session.execute_write(
535
+ lambda tx: tx.run(write_query, data=data_to_write)
536
+ )
537
+ offset += batch_size
538
+ logger.info(f"== Processed batch starting at offset {offset} ==")
539
+
540
+ def add_paper_summary_embedding(
541
+ self, embedding_model, hash_id=None, batch_size=512
542
+ ):
543
  if hash_id is not None:
544
  query = """
545
  MATCH (p:Paper {hash_id: $hash_id})
546
  WHERE p.summary IS NOT NULL
547
+ RETURN p.summary AS context, p.hash_id AS hash_id, p.title AS title
548
  """
549
  with self.driver.session() as session:
550
+ results = session.execute_write(
551
+ lambda tx: tx.run(query, hash_id=hash_id).data()
552
+ )
553
+ contexts = [result["context"] for result in results]
554
+ paper_ids = [result["hash_id"] for result in results]
555
+ context_embeddings = embedding_model.encode(
556
+ contexts, convert_to_tensor=True, device=self.device
557
+ )
558
  query = """
559
+ MERGE (p:Paper {hash_id: $hash_id})
560
+ ON CREATE SET p.summary_embedding = $embedding
561
+ ON MATCH SET p.summary_embedding = $embedding
562
+ """
563
+ for idx, hash_id in tqdm(enumerate(paper_ids)):
564
+ embedding = (
565
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
566
+ )
567
+ with self.driver.session() as session:
568
+ results = session.execute_write(
569
+ lambda tx: tx.run(
570
+ query, hash_id=hash_id, embedding=embedding
571
+ ).data()
572
+ )
573
+ return
574
+ offset = 0
575
+ while True:
576
+ query = f"""
577
  MATCH (p:Paper)
578
  WHERE p.summary IS NOT NULL
579
+ RETURN p.summary AS context, p.hash_id AS hash_id, p.title AS title
580
+ SKIP $offset LIMIT $batch_size
581
  """
582
  with self.driver.session() as session:
583
+ results = session.execute_write(
584
+ lambda tx: tx.run(
585
+ query, offset=offset, batch_size=batch_size
586
+ ).data()
587
+ )
588
+ if not results:
589
+ break
590
+ contexts = [result["context"] for result in results]
591
+ paper_ids = [result["hash_id"] for result in results]
592
+ context_embeddings = embedding_model.encode(
593
+ contexts,
594
+ batch_size=batch_size,
595
+ convert_to_tensor=True,
596
+ device=self.device,
597
+ )
598
+ write_query = """
599
+ UNWIND $data AS row
600
+ MERGE (p:Paper {hash_id: row.hash_id})
601
+ ON CREATE SET p.summary_embedding = row.embedding
602
+ ON MATCH SET p.summary_embedding = row.embedding
603
+ """
604
+ data_to_write = []
605
+ for idx, hash_id in enumerate(paper_ids):
606
+ embedding = (
607
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
608
+ )
609
+ data_to_write.append({"hash_id": hash_id, "embedding": embedding})
610
  with self.driver.session() as session:
611
+ session.execute_write(
612
+ lambda tx: tx.run(write_query, data=data_to_write)
613
+ )
614
+ offset += batch_size
615
+ logger.info(f"== Processed batch starting at offset {offset} ==")
616
+
617
  def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
618
  query = f"""
619
  MATCH (paper:Paper)
 
624
  ORDER BY score DESC LIMIT {k}
625
  """
626
  with self.driver.session() as session:
627
+ results = session.execute_read(
628
+ lambda tx: tx.run(query, embedding=embedding).data()
629
+ )
630
+ related_paper = []
631
  for result in results:
632
  related_paper.append(result["paper"]["hash_id"])
633
  return related_paper
 
649
  """
650
  with self.driver.session() as session:
651
  session.execute_write(lambda tx: tx.run(query).data())
652
+
653
  def filter_paper_id_list(self, paper_id_list, year="2024"):
654
  if not paper_id_list:
655
  return []
 
661
  RETURN p.hash_id AS hash_id
662
  """
663
  with self.driver.session() as session:
664
+ result = session.execute_read(
665
+ lambda tx: tx.run(query, paper_id_list=paper_id_list, year=year).data()
666
+ )
667
+
668
+ existing_paper_ids = [record["hash_id"] for record in result]
669
  existing_paper_ids = list(set(existing_paper_ids))
670
  return existing_paper_ids
671
+
672
  def check_index_exists(self):
673
  query = "SHOW INDEXES"
674
  with self.driver.session() as session:
 
685
  """
686
  with self.driver.session() as session:
687
  session.execute_write(lambda tx: tx.run(query).data())
688
+
689
  def get_entity_related_paper_num(self, entity_name):
690
  query = """
691
  MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
 
693
  RETURN PaperCount
694
  """
695
  with self.driver.session() as session:
696
+ result = session.execute_read(
697
+ lambda tx: tx.run(query, entity_name=entity_name).data()
698
+ )
699
+ paper_num = result[0]["PaperCount"]
700
  return paper_num
701
 
702
+ def get_entities_related_paper_num(self, entity_names):
703
+ query = """
704
+ UNWIND $entity_names AS entity_name
705
+ MATCH (e:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)
706
+ WITH entity_name, COUNT(p) AS PaperCount
707
+ RETURN entity_name, PaperCount
708
+ """
709
+
710
+ with self.driver.session() as session:
711
+ result = session.execute_read(
712
+ lambda tx: tx.run(query, entity_names=entity_names).data()
713
+ )
714
+ # Convert the query result into a dict: entity name -> paper count
715
+ entity_paper_count = {
716
+ record["entity_name"]: record["PaperCount"] for record in result
717
+ }
718
+ return entity_paper_count
719
+
720
  def get_entity_text(self):
721
  query = """
722
  MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
 
726
  """
727
  with self.driver.session() as session:
728
  result = session.execute_read(lambda tx: tx.run(query).data())
729
+ text_list = [record["entity_text"] for record in result]
730
  return text_list
731
+
732
  def get_entity_combinations(self, venue_name, year):
733
+ def process_paper_relationships(
734
+ session, entity_name_1, entity_name_2, abstract
735
+ ):
736
  if entity_name_2 < entity_name_1:
737
  entity_name_1, entity_name_2 = entity_name_2, entity_name_1
738
  query = """
 
742
  ON CREATE SET r.strength = 1
743
  ON MATCH SET r.strength = r.strength + 1
744
  """
745
+ sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", abstract)
746
  for sentence in sentences:
747
  sentence = sentence.lower()
748
  if entity_name_1 in sentence and entity_name_2 in sentence:
749
  # 如果两个实体在同一句话中出现过,则创建或更新 CONNECT 关系
750
  session.execute_write(
751
+ lambda tx: tx.run(
752
+ query,
753
+ entity_name_1=entity_name_1,
754
+ entity_name_2=entity_name_2,
755
+ ).data()
756
  )
757
  # logger.debug(f"CONNECT relation created or updated between {entity_name_1} and {entity_name_2} for Paper ID {paper_id}")
758
  break # one co-occurrence is enough, exit the loop
 
766
  RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
767
  """
768
  with self.driver.session() as session:
769
+ result = session.execute_read(
770
+ lambda tx: tx.run(query, venue_name=venue_name, year=year).data()
771
+ )
772
  for record in tqdm(result):
773
  paper_id = record["hash_id"]
774
+ entity_name_1 = record["entity_name_1"]
775
+ entity_name_2 = record["entity_name_2"]
776
  abstract = self.get_paper_attribute(paper_id, "abstract")
777
+ process_paper_relationships(
778
+ session, entity_name_1, entity_name_2, abstract
779
+ )
780
 
781
  def build_citemap(self):
782
  citemap = defaultdict(set)
 
787
  with self.driver.session() as session:
788
  results = session.execute_read(lambda tx: tx.run(query).data())
789
  for result in results:
790
+ hash_id = result["hash_id"]
791
+ cite_id_list = result["cite_id_list"]
792
  if cite_id_list:
793
  for cited_id in cite_id_list:
794
  citemap[hash_id].add(cited_id)
 
801
  AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
802
  graph = Graph(URI, auth=AUTH)
803
  # Create a dictionary to hold the exported data
804
+ # Define the batch size
805
  data = {"nodes": [], "relationships": []}
806
+ # Count the total amount of data (e.g. the total number of Paper nodes)
807
+ total_papers_query = "MATCH (e:Entity)-[:RELATED_TO]->(p:Paper) RETURN COUNT(DISTINCT p) AS count"
808
+ total_papers = graph.run(total_papers_query).evaluate()
809
+ print(f"total paper: {total_papers}")
810
+ query = f"""
811
  MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
 
812
  RETURN p, e, r
813
  """
814
+ """
815
  results = graph.run(query)
816
  # 处理查询结果
817
  for record in tqdm(results):
 
819
  entity_node = record["e"]
820
  relationship = record["r"]
821
  # Add the node data to the dictionary
822
+ data["nodes"].append(
823
+ {
824
+ "id": paper_node.identity,
825
+ "label": "Paper",
826
+ "properties": dict(paper_node),
827
+ }
828
+ )
829
+ data["nodes"].append(
830
+ {
831
+ "id": entity_node.identity,
832
+ "label": "Entity",
833
+ "properties": dict(entity_node),
834
+ }
835
+ )
836
  # Add the relationship data to the dictionary
837
+ data["relationships"].append(
838
+ {
839
+ "start_node": entity_node.identity,
840
+ "end_node": paper_node.identity,
841
+ "type": "RELATED_TO",
842
+ "properties": dict(relationship),
843
+ }
844
+ )
845
+ """
846
  query = """
847
  MATCH (p:Paper)
848
  WHERE p.venue_name='acl' and p.year='2024'
849
  RETURN p
850
  """
 
851
  results = graph.run(query)
852
  for record in tqdm(results):
853
  paper_node = record["p"]
854
  # Add the node data to the dictionary
855
+ data["nodes"].append(
856
+ {
857
+ "id": paper_node.identity,
858
+ "label": "Paper",
859
+ "properties": dict(paper_node),
860
+ }
861
+ )
862
  # Remove duplicate nodes
863
  # data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
864
  unique_nodes = []
 
871
  unique_nodes.append(node)
872
  data["nodes"] = unique_nodes
873
  # Save the data as a JSON file
874
+ with open(
875
+ "./assets/data/scipip_neo4j_clean_backup.json", "w", encoding="utf-8"
876
+ ) as f:
877
  json.dump(data, f, ensure_ascii=False, indent=4)
878
+
879
  def neo4j_import_data(self):
880
  # clear_database() # wipes the database; run with caution
881
  URI = os.environ["NEO4J_URL"]
 
884
  AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
885
  graph = Graph(URI, auth=AUTH)
886
  # 从 JSON 文件中读取数据
887
+ with open(
888
+ "./assets/data/scipip_neo4j_clean_backup.json", "r", encoding="utf-8"
889
+ ) as f:
890
  data = json.load(f)
891
  # Create the nodes
892
  nodes = {}
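
The embedding helpers in this file now page through papers with SKIP/LIMIT and write each batch back with a single UNWIND ... MERGE statement instead of one MERGE per paper. A rough sketch of that loop, assuming a SentenceTransformer-style encoder exposing .encode(); the function name, driver handle, and device argument are illustrative, not taken from the repository:

def backfill_abstract_embeddings(driver, embedding_model, device="cpu", batch_size=512):
    read_query = """
    MATCH (p:Paper)
    WHERE p.abstract IS NOT NULL
    RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
    SKIP $offset LIMIT $batch_size
    """
    write_query = """
    UNWIND $data AS row
    MERGE (p:Paper {hash_id: row.hash_id})
    ON CREATE SET p.abstract_embedding = row.embedding
    ON MATCH SET p.abstract_embedding = row.embedding
    """
    offset = 0
    while True:
        # Fetch one page of papers that still need an embedding.
        with driver.session() as session:
            rows = session.execute_read(
                lambda tx: tx.run(read_query, offset=offset, batch_size=batch_size).data()
            )
        if not rows:
            break
        texts = [row["title"] + row["context"] for row in rows]
        embeddings = embedding_model.encode(
            texts, batch_size=batch_size, convert_to_tensor=True, device=device
        )
        # Write the whole batch back in one UNWIND ... MERGE round trip.
        data = [
            {
                "hash_id": row["hash_id"],
                "embedding": embeddings[i].detach().cpu().numpy().flatten().tolist(),
            }
            for i, row in enumerate(rows)
        ]
        with driver.session() as session:
            session.execute_write(lambda tx: tx.run(write_query, data=data))
        offset += batch_size
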
src/utils/paper_retriever.py CHANGED
@@ -59,6 +59,7 @@ class CoCite:
59
 
60
  def __init__(self) -> None:
61
  if not self._initialized:
 
62
  self.paper_client = PaperClient()
63
  citemap = self.paper_client.build_citemap()
64
  self.comap = defaultdict(lambda: defaultdict(int))
@@ -101,20 +102,16 @@ class Retriever(object):
101
 
102
  def retrieve_entities_by_enties(self, entities):
103
  # TODO: KG
104
- expand_entities = []
105
- for entity in entities:
106
- expand_entities += self.paper_client.find_related_entities_by_entity(
107
- entity,
108
- n=self.config.RETRIEVE.kg_jump_num,
109
- k=self.config.RETRIEVE.kg_cover_num,
110
- relation_name=self.config.RETRIEVE.relation_name,
111
- )
112
  expand_entities = list(set(entities + expand_entities))
113
- entity_paper_num_dict = {}
114
- for entity in expand_entities:
115
- entity_paper_num_dict[entity] = (
116
- self.paper_client.get_entity_related_paper_num(entity)
117
- )
118
  new_entities = []
119
  entity_paper_num_dict = {
120
  k: v for k, v in entity_paper_num_dict.items() if v != 0
@@ -142,11 +139,7 @@ class Retriever(object):
142
  Return:
143
  related_paper: list(dict)
144
  """
145
- related_paper = []
146
- for paper_id in paper_id_list:
147
- paper = {"hash_id": paper_id}
148
- self.paper_client.update_paper_from_client(paper)
149
- related_paper.append(paper)
150
  return related_paper
151
 
152
  def calculate_similarity(self, entities, related_entities_list, use_weight=False):
@@ -333,7 +326,6 @@ class Retriever(object):
333
  similarity_threshold = self.config.RETRIEVE.similarity_threshold
334
  similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
335
  target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
336
- # target_labels = list(range(0, len(target_paper_id_list)))
337
  target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
338
  logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
339
  logger.debug(
@@ -672,8 +664,7 @@ class SNKGRetriever(Retriever):
672
  )
673
  related_paper = set()
674
  related_paper.update(sn_paper_id_list)
675
- for paper_id in sn_paper_id_list:
676
- sn_entities += self.paper_client.find_entities_by_paper(paper_id)
677
  logger.debug("SN entities for retriever: {}".format(sn_entities))
678
  entities = list(set(entities + sn_entities))
679
  new_entities = self.retrieve_entities_by_enties(entities)
 
59
 
60
  def __init__(self) -> None:
61
  if not self._initialized:
62
+ logger.debug("init co-cite map begin...")
63
  self.paper_client = PaperClient()
64
  citemap = self.paper_client.build_citemap()
65
  self.comap = defaultdict(lambda: defaultdict(int))
 
102
 
103
  def retrieve_entities_by_enties(self, entities):
104
  # TODO: KG
105
+ expand_entities = self.paper_client.find_related_entities_by_entity_list(
106
+ entities,
107
+ n=self.config.RETRIEVE.kg_jump_num,
108
+ k=self.config.RETRIEVE.kg_cover_num,
109
+ relation_name=self.config.RETRIEVE.relation_name,
110
+ )
 
 
111
  expand_entities = list(set(entities + expand_entities))
112
+ entity_paper_num_dict = self.paper_client.get_entities_related_paper_num(
113
+ expand_entities
114
+ )
 
 
115
  new_entities = []
116
  entity_paper_num_dict = {
117
  k: v for k, v in entity_paper_num_dict.items() if v != 0
 
139
  Return:
140
  related_paper: list(dict)
141
  """
142
+ related_paper = self.paper_client.update_papers_from_client(paper_id_list)
 
 
 
 
143
  return related_paper
144
 
145
  def calculate_similarity(self, entities, related_entities_list, use_weight=False):
 
326
  similarity_threshold = self.config.RETRIEVE.similarity_threshold
327
  similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
328
  target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
 
329
  target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
330
  logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
331
  logger.debug(
 
664
  )
665
  related_paper = set()
666
  related_paper.update(sn_paper_id_list)
667
+ sn_entities += self.paper_client.find_entities_by_paper_list(sn_paper_id_list)
 
668
  logger.debug("SN entities for retriever: {}".format(sn_entities))
669
  entities = list(set(entities + sn_entities))
670
  new_entities = self.retrieve_entities_by_enties(entities)
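
With these changes the retriever issues one batched query per step instead of one query per entity or per paper: entity expansion, per-entity paper counts, paper hydration, and paper-to-entity lookups each go through a single UNWIND call on PaperClient. An illustrative usage sketch; the import path and the parameter values below are assumptions rather than something stated in this diff:

from utils.paper_client import PaperClient  # assumed module path

client = PaperClient()
seed_entities = ["contrastive learning", "graph neural network"]

# One round trip expands every seed entity at once...
expanded = client.find_related_entities_by_entity_list(seed_entities, n=1, k=3)
# ...and one more counts the related papers for all expanded entities.
paper_counts = client.get_entities_related_paper_num(list(set(seed_entities + expanded)))
print(paper_counts)  # {entity_name: paper_count, ...}
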