lihuigu committed · Commit a6a5155 · Parent(s): b6336ac

[feat] add example & singleton
configs/datasets.yaml CHANGED
@@ -4,7 +4,6 @@ DEFAULT:
   log_level: "DEBUG"
   log_dir: ./log
   embedding: ./assets/model/sentence-transformers/all-MiniLM-L6-v2
-  device: "cpu" # "cpu"
 
 ARTICLE:
   summarizing_prompt: ./assets/prompt/summarizing.xml

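Note: the `device` key removed here is superseded by runtime detection in the Python changes below; every former `config.DEFAULT.device` call site now derives the device the same way:

    import torch

    # device is now chosen at runtime instead of being read from the config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
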
src/pages/button_interface.py CHANGED
@@ -7,6 +7,7 @@ from generator import IdeaGenerator
 class Backend(object):
     def __init__(self) -> None:
         CONFIG_PATH = "./configs/datasets.yaml"
+        EXAMPLE_PATH = "./assets/data/example.json"
         RETRIEVER_NAME = "SNKG"
         USE_INSPIRATION = True
         BRAINSTORM_MODE = "mode_c"
@@ -22,6 +23,16 @@ class Backend(object):
         self.idea_generator = IdeaGenerator(self.config, None)
         self.use_inspiration = USE_INSPIRATION
         self.brainstorm_mode = BRAINSTORM_MODE
+        self.examples = self.load_examples(EXAMPLE_PATH)
+
+    def load_examples(self, path):
+        try:
+            with open(path, "r") as f:
+                data = json.load(f)
+            return data
+        except (FileNotFoundError, json.JSONDecodeError) as e:
+            print(f"Error loading examples from {path}: {e}")
+            return []
 
     def background2brainstorm_callback(self, background, json_strs=None):
         if json_strs is not None: # only for DEBUG_MODE
@@ -99,12 +110,16 @@ class Backend(object):
         return final_ideas
 
     def get_demo_i(self, i):
-        return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
-                "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
-                "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
-                "how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model "
-                "interpretation: The billions of parameters and nonlinear decision paths within large-scale language "
-                "models make it very difficult to track and interpret specific outputs. The existing interpretation "
-                "methods usually only provide a local perspective and are difficult to systematize. 2. Transparency "
-                "and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring "
-                "the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges.")
+        if 0 <= i < len(self.examples):
+            return self.examples[i].get("background", "Background not found.")
+        else:
+            return "Example not found. Please select a valid index."
+        # return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
+        # "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
+        # "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
+        # "how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model "
+        # "interpretation: The billions of parameters and nonlinear decision paths within large-scale language "
+        # "models make it very difficult to track and interpret specific outputs. The existing interpretation "
+        # "methods usually only provide a local perspective and are difficult to systematize. 2. Transparency "
+        # "and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring "
+        # "the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges.")

src/pages/one_click_generation.py CHANGED
@@ -74,9 +74,11 @@ def genrate_mainpage(backend):
         st.session_state["use_demo_input"] = True
         st.session_state["demo_input"] = demo_input
 
-    cols = st.columns([2, 2])
-    cols[0].button("Example 1", on_click=get_demo_n, args=(1,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
-    cols[1].button("Example 2", on_click=get_demo_n, args=(2,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
+    cols = st.columns([1, 1, 1, 1])
+    cols[0].button("Example 1", on_click=get_demo_n, args=(0,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
+    cols[1].button("Example 2", on_click=get_demo_n, args=(1,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
+    cols[2].button("Example 3", on_click=get_demo_n, args=(2,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
+    cols[3].button("Example 4", on_click=get_demo_n, args=(3,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
 
 def check_intermediate_outputs(id="brainstorms"):
     msg = st.session_state["intermediate_output"].get(id, None)

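The buttons now pass zero-based indices (0–3), matching the list indexing in `Backend.get_demo_i`. The four near-identical calls could also be generated in a loop, mirroring the pattern `step_by_step_generation.py` already uses; a possible refactor, not part of this commit:

    cols = st.columns(4)
    for i, col in enumerate(cols):
        # label is 1-based for the UI, index is 0-based for the callback
        col.button(f"Example {i + 1}", on_click=get_demo_n, args=(i,),
                   use_container_width=True,
                   disabled=not st.session_state.get("enable_submmit", True))
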
src/pages/step_by_step_generation.py CHANGED
@@ -56,7 +56,7 @@ def genrate_mainpage(backend):
     background = st.session_state.get("background", "")
     background = st.text_area("Input your field background", background, placeholder="Input your field background", height=200, label_visibility="collapsed")
 
-    cols = st.columns(2)
+    cols = st.columns(4)
     def click_demo_i(i):
         st.session_state["background"] = backend.get_demo_i(i)
     for i, col in enumerate(cols):

src/paper_manager.py CHANGED
@@ -1,11 +1,11 @@
 import os
 import json
 import re
-from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
+import torch
 from utils.paper_crawling import PaperCrawling
 from utils.paper_client import PaperClient
-from utils.hash import generate_hash_id
+from utils.hash import generate_hash_id, get_embedding_model
 from collections import defaultdict
 from utils.header import get_dir, ConfigReader
 from utils.llms_api import APIHelper
@@ -165,9 +165,8 @@ class PaperManager:
         self.data_type = "train"
         self.paper_client = PaperClient(config)
         self.paper_crawling = PaperCrawling(config, data_type=self.data_type)
-        self.embedding_model = SentenceTransformer(
-            model_name_or_path=get_dir(config.DEFAULT.embedding), device=self.config.DEFAULT.device
-        )
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.embedding_model = get_embedding_model(config)
         self.api_helper = APIHelper(config)
         self.retriever = Retriever(config)
         self.paper_id_map = defaultdict()

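Note: this commit makes `PaperClient.__init__` zero-argument (see `src/utils/paper_client.py` below), but the unchanged context line `self.paper_client = PaperClient(config)` above still passes `config`, which would raise `TypeError` under the new signature; presumably the call site should read:

    self.paper_client = PaperClient()  # singleton; no longer takes config
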
src/utils/hash.py CHANGED
@@ -1,18 +1,23 @@
 import re
 import os
 import hashlib
+import torch
 import struct
 from collections import Counter
 from huggingface_hub import hf_hub_download
+from sentence_transformers import SentenceTransformer
+from .header import get_dir
 
 ENV_CHECKED = False
 EMBEDDING_CHECKED = False
 
+
 def check_embedding():
     global EMBEDDING_CHECKED
     if not EMBEDDING_CHECKED:
         # Define the repository and files to download
         repo_id = "sentence-transformers/all-MiniLM-L6-v2" # "BAAI/bge-small-en-v1.5"
+        local_dir = f"./assets/model/{repo_id}"
         files_to_download = [
             "config.json",
             "pytorch_model.bin",
@@ -21,10 +26,18 @@ def check_embedding():
         ]
         # Download each file and save it to the /model/bge directory
        for file_name in files_to_download:
-            print("Checking for file: ", file_name)
-            hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=f"./assets/model/{repo_id}")
+            if not os.path.exists(os.path.join(local_dir, file_name)):
+                print(
+                    f"file: {file_name} not exist in {local_dir}, try to download from huggingface ..."
+                )
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=file_name,
+                    local_dir=local_dir,
+                )
         EMBEDDING_CHECKED = True
 
+
 def check_env():
     global ENV_CHECKED
     if not ENV_CHECKED:
@@ -43,6 +56,22 @@ def check_env():
         ENV_CHECKED = True
 
 
+class EmbeddingModel:
+    _instance = None
+
+    def __new__(cls, config):
+        if cls._instance is None:
+            cls._instance = super(EmbeddingModel, cls).__new__(cls)
+            cls._instance.embedding_model = SentenceTransformer(
+                model_name_or_path=get_dir(config.DEFAULT.embedding),
+                device="cuda" if torch.cuda.is_available() else "cpu",
+            )
+        return cls._instance
+
+def get_embedding_model(config):
+    return EmbeddingModel(config).embedding_model
+
+
 def generate_hash_id(input_string):
     if input_string is None:
         return None

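With the `EmbeddingModel` singleton, the `SentenceTransformer` is loaded once per process and shared by `paper_manager.py` and `paper_retriever.py`; whichever `config` reaches the first call wins, and the `config` argument of later calls is silently ignored. None of the `__new__`-based singletons in this commit take a lock, so concurrent first calls from multiple threads could still race. A minimal usage sketch, assuming a loaded `config`:

    from utils.hash import get_embedding_model

    model_a = get_embedding_model(config)  # first call: loads the SentenceTransformer
    model_b = get_embedding_model(config)  # later calls: return the cached model
    assert model_a is model_b
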
src/utils/llms_api.py CHANGED
@@ -864,7 +864,7 @@ class APIHelper(object):
     def transfer_form(self, idea: str):
         prompt_template_transfer = """
         ### Task Description:
-        I will give you some ideas, please standardize the output format of the ideas without simplifying or modifying their specific content. Note that the content of each idea includes everything about the idea。
+        I will give you some ideas, please standardize the output format of the ideas without changing any characters in their content. Note that the content of each idea includes everything about the idea。
 
         ### Specific Information:
         I will provide you with specific information now, please use them according to the instructions above:

src/utils/paper_client.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import re
 import json
+import torch
 from tqdm import tqdm
 from neo4j import GraphDatabase
 from collections import defaultdict, deque
@@ -8,18 +9,26 @@ from py2neo import Graph, Node, Relationship
 from loguru import logger
 
 class PaperClient:
-    def __init__(self, config) -> None:
-        self.config = config
-        self.driver = self.get_neo4j_driver()
-        self.teb_model = None
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(PaperClient, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self) -> None:
+        if not self._initialized:
+            self.driver = self.get_neo4j_driver()
+            self.teb_model = None
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            PaperClient._initialized = True
 
     def get_neo4j_driver(self):
-        # configuration info
         URI = os.environ["NEO4J_URL"]
         NEO4J_USERNAME = os.environ["NEO4J_USERNAME"]
         NEO4J_PASSWD = os.environ["NEO4J_PASSWD"]
         AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
-        # connect to the Neo4j database
         driver = GraphDatabase.driver(URI, auth=AUTH)
         return driver
 
@@ -274,7 +283,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["title"] + result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=512, convert_to_tensor=True, device=self.config.DEFAULT.device)
+        context_embeddings = embedding_model.encode(contexts, batch_size=512, convert_to_tensor=True, device=self.device)
         query = """
             MERGE (p:Paper {hash_id: $hash_id})
             ON CREATE SET p.abstract_embedding = $embedding
@@ -304,7 +313,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.config.DEFAULT.device)
+        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
         query = """
             MERGE (p:Paper {hash_id: $hash_id})
             ON CREATE SET p.embedding = $embedding
@@ -334,7 +343,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.config.DEFAULT.device)
+        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
         query = """
             MERGE (p:Paper {hash_id: $hash_id})
             ON CREATE SET p.contribution_embedding = $embedding
@@ -365,7 +374,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.config.DEFAULT.device)
+        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
         query = """
             MERGE (p:Paper {hash_id: $hash_id})
             ON CREATE SET p.summary_embedding = $embedding
@@ -528,13 +537,13 @@ class PaperClient:
         NEO4J_PASSWD = os.environ["NEO4J_PASSWD"]
         AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
         graph = Graph(URI, auth=AUTH)
+        # create a dict to hold the data
+        data = {"nodes": [], "relationships": []}
         query = """
         MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
         RETURN p, e, r
         """
         results = graph.run(query)
-        # create a dict to hold the data
-        data = {"nodes": [], "relationships": []}
         # process the query results
         for record in tqdm(results):
             paper_node = record["p"]
@@ -622,9 +631,6 @@ class PaperClient:
 
 
 if __name__ == "__main__":
-    from header import get_dir, ConfigReader
-    config_path = get_dir("./configs/datasets.yaml")
-    config = ConfigReader.load(config_path)
-    paper_client = PaperClient(config)
+    paper_client = PaperClient()
     # paper_client.neo4j_backup()
     paper_client.neo4j_import_data()

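Because Python runs `__init__` on every `PaperClient()` call even when `__new__` returns the cached instance, the `_initialized` flag is what prevents the Neo4j driver from being recreated. A standalone illustration of the pattern used here (not code from the commit):

    class Cached:
        _instance = None
        _initialized = False

        def __new__(cls, *args, **kwargs):
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

        def __init__(self):
            # runs on every Cached() call, so the one-time setup must be guarded
            if not self._initialized:
                print("one-time setup")
                Cached._initialized = True

    a, b = Cached(), Cached()
    assert a is b  # same object; "one-time setup" is printed once
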
src/utils/paper_retriever.py CHANGED
@@ -2,8 +2,6 @@ import torch
 import itertools
 import threading
 import numpy as np
-from sentence_transformers import SentenceTransformer
-from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 from collections import Counter, defaultdict
 from loguru import logger
@@ -11,7 +9,7 @@ from abc import ABCMeta, abstractmethod
 from .paper_client import PaperClient
 from .paper_crawling import PaperCrawling
 from .llms_api import APIHelper
-from .header import get_dir
+from .hash import get_embedding_model
 
 
 class UnionFind:
@@ -51,18 +49,26 @@ def can_merge(uf, similarity_matrix, i, j, threshold):
 
 
 class CoCite:
-    def __init__(self, config) -> None:
-        self.paper_client = PaperClient(config)
-        citemap = self.paper_client.build_citemap()
-        self.comap = defaultdict(
-            lambda: defaultdict(int)
-        )
-        for paper_id, cited_id in citemap.items():
-            for id0, id1 in itertools.combinations(cited_id, 2):
-                # ensure comap[id0][id1] == comap[id1][id0]
-                self.comap[id0][id1] += 1
-                self.comap[id1][id0] += 1
-        logger.debug("init co-cite map success")
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(CoCite, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self) -> None:
+        if not self._initialized:
+            self.paper_client = PaperClient()
+            citemap = self.paper_client.build_citemap()
+            self.comap = defaultdict(lambda: defaultdict(int))
+            for paper_id, cited_id in citemap.items():
+                for id0, id1 in itertools.combinations(cited_id, 2):
+                    # ensure comap[id0][id1] == comap[id1][id0]
+                    self.comap[id0][id1] += 1
+                    self.comap[id1][id0] += 1
+            logger.debug("init co-cite map success")
+            CoCite._initialized = True
 
     def get_cocite_ids(self, id_, k=1):
         sorted_items = sorted(self.comap[id_].items(), key=lambda x: x[1], reverse=True)
@@ -82,14 +88,12 @@ class Retriever(object):
         self.config = config
         self.use_cocite = use_cocite
         self.use_cluster_to_filter = use_cluster_to_filter
-        self.paper_client = PaperClient(config)
-        self.cocite = CoCite(config)
+        self.paper_client = PaperClient()
+        self.cocite = CoCite()
         self.api_helper = APIHelper(config=config)
-        self.embedding_model = SentenceTransformer(
-            model_name_or_path=get_dir(config.DEFAULT.embedding), device=self.config.DEFAULT.device
-        )
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.embedding_model = get_embedding_model(config)
         self.paper_crawling = PaperCrawling(config=config)
-        self.vectorizer = CountVectorizer()
 
     @abstractmethod
     def retrieve(self, bg, entities, use_evaluate):
@@ -192,7 +196,7 @@ class Retriever(object):
         entities = self.api_helper.generate_entity_list(context)
         logger.debug("get entity from context: {}".format(entities))
         origin_vector = self.embedding_model.encode(
-            context, convert_to_tensor=True, device=self.config.DEFAULT.device
+            context, convert_to_tensor=True, device=self.device
         ).unsqueeze(0)
         related_contexts = [
             self.paper_client.get_paper_attribute(paper_id, type_name)
@@ -200,7 +204,10 @@ class Retriever(object):
         ]
         if len(related_contexts) > 0:
             context_embeddings = self.embedding_model.encode(
-                related_contexts, batch_size=512, convert_to_tensor=True, device=self.config.DEFAULT.device
+                related_contexts,
+                batch_size=512,
+                convert_to_tensor=True,
+                device=self.device,
             )
             score_1 = torch.nn.functional.cosine_similarity(
                 origin_vector, context_embeddings
@@ -208,7 +215,7 @@ class Retriever(object):
             score_1 = score_1.cpu().numpy()
             if self.config.RETRIEVE.need_normalize:
                 score_1 = score_1 / np.max(score_1)
-            # score_2 not enable
+            # score_2 not enable
             # if self.config.RETRIEVE.beta != 0:
             score_sn_dict = dict(zip(related_paper_id_list, score_1))
             score_en_dict = dict(zip(related_paper_id_list, score_2))
@@ -231,28 +238,33 @@ class Retriever(object):
                 else list(score_dict.keys())
             )
             return paper_id_list
-        else:
+        else:
             # clustering filter, ensure that each category the highest score save first
             paper_id_list = list(score_dict.keys())
             paper_embedding_list = [
-                self.paper_client.get_paper_attribute(paper_id, "embedding") for paper_id in paper_id_list
+                self.paper_client.get_paper_attribute(paper_id, "embedding")
+                for paper_id in paper_id_list
             ]
             paper_embedding = np.array(paper_embedding_list)
             paper_embedding_list = [
-                self.paper_client.get_paper_attribute(paper_id, "contribution_embedding") for paper_id in paper_id_list
+                self.paper_client.get_paper_attribute(
+                    paper_id, "contribution_embedding"
+                )
+                for paper_id in paper_id_list
            ]
             paper_contribution_embedding = np.array(paper_embedding_list)
             paper_embedding_list = [
-                self.paper_client.get_paper_attribute(paper_id, "summary_embedding") for paper_id in paper_id_list
+                self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
+                for paper_id in paper_id_list
             ]
             paper_summary_embedding = np.array(paper_embedding_list)
             weight_embedding = self.config.RETRIEVE.s_bg
             weight_contribution = self.config.RETRIEVE.s_contribution
             weight_summary = self.config.RETRIEVE.s_summary
             paper_embedding = (
-                weight_embedding * paper_embedding +
-                weight_contribution * paper_contribution_embedding +
-                weight_summary * paper_summary_embedding
+                weight_embedding * paper_embedding
+                + weight_contribution * paper_contribution_embedding
+                + weight_summary * paper_summary_embedding
             )
             similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
             related_labels = self.cluster_algorithm(paper_id_list, similarity_matrix)
@@ -542,9 +554,7 @@ class SNRetriever(Retriever):
         related_paper_id_list = retrieve_result["paper"]
         retrieve_paper_num = len(related_paper_id_list)
         _, _, score_all_dict = self.cal_related_score(
-            bg,
-            related_paper_id_list=related_paper_id_list,
-            entities=entities
+            bg, related_paper_id_list=related_paper_id_list, entities=entities
         )
         top_k_matrix = {}
         recall = 0
@@ -746,4 +756,4 @@ class SNKGRetriever(Retriever):
             "retrieve_paper_num": retrieve_paper_num,
             "label_num": label_num,
         }
-        return result
+        return result

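For reference, the reformatted weighted-embedding fusion in the clustering filter is a plain linear combination followed by a dot-product similarity matrix; a small NumPy sketch with hypothetical weights and shapes (the real weights come from `config.RETRIEVE.s_bg`, `s_contribution`, and `s_summary`):

    import numpy as np

    w_bg, w_contribution, w_summary = 0.5, 0.3, 0.2  # hypothetical stand-ins
    E_bg = np.random.rand(4, 384)            # one 384-d embedding per paper
    E_contribution = np.random.rand(4, 384)
    E_summary = np.random.rand(4, 384)

    E = w_bg * E_bg + w_contribution * E_contribution + w_summary * E_summary
    similarity_matrix = E @ E.T              # pairwise scores, as in the hunk above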