[feat]add example & singleton
- configs/datasets.yaml +0 -1
- src/pages/button_interface.py +24 -9
- src/pages/one_click_generation.py +5 -3
- src/pages/step_by_step_generation.py +1 -1
- src/paper_manager.py +4 -5
- src/utils/hash.py +31 -2
- src/utils/llms_api.py +1 -1
- src/utils/paper_client.py +22 -16
- src/utils/paper_retriever.py +45 -35
configs/datasets.yaml
CHANGED
@@ -4,7 +4,6 @@ DEFAULT:
   log_level: "DEBUG"
   log_dir: ./log
   embedding: ./assets/model/sentence-transformers/all-MiniLM-L6-v2
-  device: "cpu" # "cpu"
 
 ARTICLE:
   summarizing_prompt: ./assets/prompt/summarizing.xml
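The dropped `device` key has no successor in the config: after this commit every module picks its device at runtime instead of reading it from datasets.yaml, using the same check that appears throughout the diffs below.

    # device is now selected at runtime rather than configured in datasets.yaml
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")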
src/pages/button_interface.py
CHANGED
@@ -7,6 +7,7 @@ from generator import IdeaGenerator
 class Backend(object):
     def __init__(self) -> None:
         CONFIG_PATH = "./configs/datasets.yaml"
+        EXAMPLE_PATH = "./assets/data/example.json"
         RETRIEVER_NAME = "SNKG"
         USE_INSPIRATION = True
         BRAINSTORM_MODE = "mode_c"
@@ -22,6 +23,16 @@ class Backend(object):
         self.idea_generator = IdeaGenerator(self.config, None)
         self.use_inspiration = USE_INSPIRATION
         self.brainstorm_mode = BRAINSTORM_MODE
+        self.examples = self.load_examples(EXAMPLE_PATH)
+
+    def load_examples(self, path):
+        try:
+            with open(path, "r") as f:
+                data = json.load(f)
+            return data
+        except (FileNotFoundError, json.JSONDecodeError) as e:
+            print(f"Error loading examples from {path}: {e}")
+            return []
 
     def background2brainstorm_callback(self, background, json_strs=None):
         if json_strs is not None:  # only for DEBUG_MODE
@@ -99,12 +110,16 @@ class Backend(object):
         return final_ideas
 
     def get_demo_i(self, i):
-        return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
-                "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
-                "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
-                "how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model "
-                "interpretation: The billions of parameters and nonlinear decision paths within large-scale language "
-                "models make it very difficult to track and interpret specific outputs. The existing interpretation "
-                "methods usually only provide a local perspective and are difficult to systematize. 2. Transparency "
-                "and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring "
-                "the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges.")
+        if 0 <= i < len(self.examples):
+            return self.examples[i].get("background", "Background not found.")
+        else:
+            return "Example not found. Please select a valid index."
+        # return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
+        #         "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
+        #         "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
+        #         "how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model "
+        #         "interpretation: The billions of parameters and nonlinear decision paths within large-scale language "
+        #         "models make it very difficult to track and interpret specific outputs. The existing interpretation "
+        #         "methods usually only provide a local perspective and are difficult to systematize. 2. Transparency "
+        #         "and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring "
+        #         "the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges.")
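A note on the data format: `get_demo_i` indexes into the list returned by `load_examples` and reads each entry's `background` key, so `./assets/data/example.json` is expected to be a JSON array of objects carrying `background` strings. A minimal sketch that writes a compatible file (the placeholder texts are illustrative, not the shipped asset):

    # sketch: produce a minimal example.json that Backend.load_examples can read
    import json

    examples = [
        {"background": "Placeholder background for Example 1 ..."},
        {"background": "Placeholder background for Example 2 ..."},
    ]
    with open("./assets/data/example.json", "w") as f:
        json.dump(examples, f, indent=2)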
src/pages/one_click_generation.py
CHANGED
@@ -74,9 +74,11 @@ def genrate_mainpage(backend):
     st.session_state["use_demo_input"] = True
     st.session_state["demo_input"] = demo_input
 
-    cols = st.columns([…
-    cols[0].button("Example 1", on_click=get_demo_n, args=(…
-    cols[1].button("Example 2", on_click=get_demo_n, args=(…
+    cols = st.columns([1, 1, 1, 1])
+    cols[0].button("Example 1", on_click=get_demo_n, args=(0,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
+    cols[1].button("Example 2", on_click=get_demo_n, args=(1,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
+    cols[2].button("Example 3", on_click=get_demo_n, args=(2,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
+    cols[3].button("Example 4", on_click=get_demo_n, args=(3,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
 
 def check_intermediate_outputs(id="brainstorms"):
     msg = st.session_state["intermediate_output"].get(id, None)
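All four buttons share one callback and differ only in the index passed through `args`; Streamlit invokes `on_click` with those positional arguments before rerunning the script, so each button selects its own example. A self-contained sketch of the same wiring (`get_demo_n` here is a stand-in, not the repo's implementation):

    # sketch of the on_click/args pattern; run with `streamlit run demo.py`
    import streamlit as st

    def get_demo_n(i: int) -> None:
        # stand-in callback: the real one stores backend.get_demo_i(i) as demo input
        st.session_state["demo_input"] = f"demo background #{i}"

    cols = st.columns([1, 1, 1, 1])
    for i, col in enumerate(cols):
        # args=(i,) binds the index eagerly, avoiding late-binding closure bugs
        col.button(f"Example {i + 1}", on_click=get_demo_n, args=(i,), use_container_width=True)

    st.write(st.session_state.get("demo_input", ""))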
src/pages/step_by_step_generation.py
CHANGED
@@ -56,7 +56,7 @@ def genrate_mainpage(backend):
     background = st.session_state.get("background", "")
     background = st.text_area("Input your field background", background, placeholder="Input your field background", height=200, label_visibility="collapsed")
 
-    cols = st.columns(…
+    cols = st.columns(4)
     def click_demo_i(i):
         st.session_state["background"] = backend.get_demo_i(i)
     for i, col in enumerate(cols):
src/paper_manager.py
CHANGED
@@ -1,11 +1,11 @@
 import os
 import json
 import re
-from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
+import torch
 from utils.paper_crawling import PaperCrawling
 from utils.paper_client import PaperClient
-from utils.hash import generate_hash_id
+from utils.hash import generate_hash_id, get_embedding_model
 from collections import defaultdict
 from utils.header import get_dir, ConfigReader
 from utils.llms_api import APIHelper
@@ -165,9 +165,8 @@ class PaperManager:
         self.data_type = "train"
         self.paper_client = PaperClient(config)
         self.paper_crawling = PaperCrawling(config, data_type=self.data_type)
-        self.…
-        …
-        )
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.embedding_model = get_embedding_model(config)
         self.api_helper = APIHelper(config)
         self.retriever = Retriever(config)
         self.paper_id_map = defaultdict()
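Because `get_embedding_model` is backed by the singleton added in `src/utils/hash.py`, constructing several `PaperManager`s (or mixing managers and retrievers) now loads the sentence-transformer weights only once per process. A quick check of that property, assuming a valid config at the usual path:

    # sketch: repeated lookups hand back one shared SentenceTransformer
    from utils.header import ConfigReader
    from utils.hash import get_embedding_model

    config = ConfigReader.load("./configs/datasets.yaml")
    assert get_embedding_model(config) is get_embedding_model(config)  # built once, reused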
src/utils/hash.py
CHANGED
@@ -1,18 +1,23 @@
 import re
 import os
 import hashlib
+import torch
 import struct
 from collections import Counter
 from huggingface_hub import hf_hub_download
+from sentence_transformers import SentenceTransformer
+from .header import get_dir
 
 ENV_CHECKED = False
 EMBEDDING_CHECKED = False
 
+
 def check_embedding():
     global EMBEDDING_CHECKED
     if not EMBEDDING_CHECKED:
         # Define the repository and files to download
         repo_id = "sentence-transformers/all-MiniLM-L6-v2"  # "BAAI/bge-small-en-v1.5"
+        local_dir = f"./assets/model/{repo_id}"
         files_to_download = [
             "config.json",
             "pytorch_model.bin",
@@ -21,10 +26,18 @@ def check_embedding():
         ]
         # Download each file and save it to the /model/bge directory
         for file_name in files_to_download:
-            …
-            …
+            if not os.path.exists(os.path.join(local_dir, file_name)):
+                print(
+                    f"file: {file_name} not exist in {local_dir}, try to download from huggingface ..."
+                )
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=file_name,
+                    local_dir=local_dir,
+                )
         EMBEDDING_CHECKED = True
 
+
 def check_env():
     global ENV_CHECKED
     if not ENV_CHECKED:
@@ -43,6 +56,22 @@ def check_env():
         ENV_CHECKED = True
 
 
+class EmbeddingModel:
+    _instance = None
+
+    def __new__(cls, config):
+        if cls._instance is None:
+            cls._instance = super(EmbeddingModel, cls).__new__(cls)
+            cls._instance.embedding_model = SentenceTransformer(
+                model_name_or_path=get_dir(config.DEFAULT.embedding),
+                device="cuda" if torch.cuda.is_available() else "cpu",
+            )
+        return cls._instance
+
+def get_embedding_model(config):
+    return EmbeddingModel(config).embedding_model
+
+
 def generate_hash_id(input_string):
     if input_string is None:
         return None
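`EmbeddingModel` is a classic `__new__`-based singleton: the first construction builds the `SentenceTransformer` and caches the instance on the class, and every later construction returns the cached instance, skipping the expensive model load. One consequence of the pattern is that a different `config` passed on a later call is silently ignored. A self-contained sketch of the mechanism, with a cheap payload standing in for the model:

    # minimal __new__-based singleton, same shape as EmbeddingModel above
    class Heavy:
        _instance = None

        def __new__(cls, tag):
            if cls._instance is None:
                cls._instance = super(Heavy, cls).__new__(cls)
                cls._instance.payload = f"loaded once ({tag})"  # stands in for SentenceTransformer(...)
            return cls._instance

    a, b = Heavy("first"), Heavy("second")
    assert a is b
    assert a.payload == "loaded once (first)"  # the second argument never took effect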
src/utils/llms_api.py
CHANGED
@@ -864,7 +864,7 @@ class APIHelper(object):
     def transfer_form(self, idea: str):
         prompt_template_transfer = """
         ### Task Description:
-        I will give you some ideas, please standardize the output format of the ideas without…
+        I will give you some ideas, please standardize the output format of the ideas without changing any characters in their content. Note that the content of each idea includes everything about the idea。
 
         ### Specific Information:
         I will provide you with specific information now, please use them according to the instructions above:
src/utils/paper_client.py
CHANGED
@@ -1,6 +1,7 @@
 import os
 import re
 import json
+import torch
 from tqdm import tqdm
 from neo4j import GraphDatabase
 from collections import defaultdict, deque
@@ -8,18 +9,26 @@ from py2neo import Graph, Node, Relationship
 from loguru import logger
 
 class PaperClient:
-    …
-    …
-    …
-    …
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(PaperClient, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self) -> None:
+        if not self._initialized:
+            self.driver = self.get_neo4j_driver()
+            self.teb_model = None
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            PaperClient._initialized = True
 
     def get_neo4j_driver(self):
-        # configuration info
         URI = os.environ["NEO4J_URL"]
         NEO4J_USERNAME = os.environ["NEO4J_USERNAME"]
        NEO4J_PASSWD = os.environ["NEO4J_PASSWD"]
         AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
-        # connect to the Neo4j database
         driver = GraphDatabase.driver(URI, auth=AUTH)
         return driver
 
@@ -274,7 +283,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["title"] + result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=512, convert_to_tensor=True, device=self.…
+        context_embeddings = embedding_model.encode(contexts, batch_size=512, convert_to_tensor=True, device=self.device)
         query = """
         MERGE (p:Paper {hash_id: $hash_id})
         ON CREATE SET p.abstract_embedding = $embedding
@@ -304,7 +313,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.…
+        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
         query = """
         MERGE (p:Paper {hash_id: $hash_id})
         ON CREATE SET p.embedding = $embedding
@@ -334,7 +343,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.…
+        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
         query = """
         MERGE (p:Paper {hash_id: $hash_id})
         ON CREATE SET p.contribution_embedding = $embedding
@@ -365,7 +374,7 @@ class PaperClient:
         results = session.execute_write(lambda tx: tx.run(query).data())
         contexts = [result["context"] for result in results]
         paper_ids = [result["hash_id"] for result in results]
-        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.…
+        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.device)
         query = """
         MERGE (p:Paper {hash_id: $hash_id})
         ON CREATE SET p.summary_embedding = $embedding
@@ -528,13 +537,13 @@ class PaperClient:
         NEO4J_PASSWD = os.environ["NEO4J_PASSWD"]
         AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
         graph = Graph(URI, auth=AUTH)
+        # create a dict to hold the exported data
+        data = {"nodes": [], "relationships": []}
         query = """
         MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
         RETURN p, e, r
         """
         results = graph.run(query)
-        # create a dict to hold the exported data
-        data = {"nodes": [], "relationships": []}
         # process the query results
         for record in tqdm(results):
             paper_node = record["p"]
@@ -622,9 +631,6 @@ class PaperClient:
 
 
 if __name__ == "__main__":
-    …
-    config_path = get_dir("./configs/datasets.yaml")
-    config = ConfigReader.load(config_path)
-    paper_client = PaperClient(config)
+    paper_client = PaperClient()
     # paper_client.neo4j_backup()
     paper_client.neo4j_import_data()
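Note the second flag: Python calls `__init__` on every `PaperClient()` construction even when `__new__` returns the cached instance, so `_initialized` is what keeps the Neo4j driver from being rebuilt on each call. A compact sketch of the two-flag pattern, with a stub in place of the driver:

    # __new__ dedupes the instance; _initialized stops __init__ from re-running its body
    class Client:
        _instance = None
        _initialized = False

        def __new__(cls, *args, **kwargs):
            if cls._instance is None:
                cls._instance = super(Client, cls).__new__(cls)
            return cls._instance

        def __init__(self) -> None:
            if not self._initialized:
                self.driver = object()  # stand-in for GraphDatabase.driver(URI, auth=AUTH)
                Client._initialized = True

    c1, c2 = Client(), Client()
    assert c1 is c2 and c1.driver is c2.driver  # one instance, one driver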
src/utils/paper_retriever.py
CHANGED
@@ -2,8 +2,6 @@ import torch
 import itertools
 import threading
 import numpy as np
-from sentence_transformers import SentenceTransformer
-from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 from collections import Counter, defaultdict
 from loguru import logger
@@ -11,7 +9,7 @@ from abc import ABCMeta, abstractmethod
 from .paper_client import PaperClient
 from .paper_crawling import PaperCrawling
 from .llms_api import APIHelper
-from .…
+from .hash import get_embedding_model
 
 
 class UnionFind:
@@ -51,18 +49,26 @@ def can_merge(uf, similarity_matrix, i, j, threshold):
 
 
 class CoCite:
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(CoCite, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self) -> None:
+        if not self._initialized:
+            self.paper_client = PaperClient()
+            citemap = self.paper_client.build_citemap()
+            self.comap = defaultdict(lambda: defaultdict(int))
+            for paper_id, cited_id in citemap.items():
+                for id0, id1 in itertools.combinations(cited_id, 2):
+                    # ensure comap[id0][id1] == comap[id1][id0]
+                    self.comap[id0][id1] += 1
+                    self.comap[id1][id0] += 1
+            logger.debug("init co-cite map success")
+            CoCite._initialized = True
 
     def get_cocite_ids(self, id_, k=1):
         sorted_items = sorted(self.comap[id_].items(), key=lambda x: x[1], reverse=True)
@@ -82,14 +88,12 @@ class Retriever(object):
         self.config = config
         self.use_cocite = use_cocite
         self.use_cluster_to_filter = use_cluster_to_filter
-        self.paper_client = PaperClient(…
-        self.cocite = CoCite(…
+        self.paper_client = PaperClient()
+        self.cocite = CoCite()
         self.api_helper = APIHelper(config=config)
-        self.…
-        …
-        )
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.embedding_model = get_embedding_model(config)
         self.paper_crawling = PaperCrawling(config=config)
-        self.vectorizer = CountVectorizer()
 
     @abstractmethod
     def retrieve(self, bg, entities, use_evaluate):
@@ -192,7 +196,7 @@ class Retriever(object):
         entities = self.api_helper.generate_entity_list(context)
         logger.debug("get entity from context: {}".format(entities))
         origin_vector = self.embedding_model.encode(
-            context, convert_to_tensor=True, device=self.…
+            context, convert_to_tensor=True, device=self.device
         ).unsqueeze(0)
         related_contexts = [
             self.paper_client.get_paper_attribute(paper_id, type_name)
@@ -200,7 +204,10 @@ class Retriever(object):
         ]
         if len(related_contexts) > 0:
             context_embeddings = self.embedding_model.encode(
-                related_contexts,…
+                related_contexts,
+                batch_size=512,
+                convert_to_tensor=True,
+                device=self.device,
             )
             score_1 = torch.nn.functional.cosine_similarity(
                 origin_vector, context_embeddings
@@ -208,7 +215,7 @@ class Retriever(object):
             score_1 = score_1.cpu().numpy()
             if self.config.RETRIEVE.need_normalize:
                 score_1 = score_1 / np.max(score_1)
-        # score_2 not enable
+            # score_2 not enable
         # if self.config.RETRIEVE.beta != 0:
         score_sn_dict = dict(zip(related_paper_id_list, score_1))
         score_en_dict = dict(zip(related_paper_id_list, score_2))
@@ -231,28 +238,33 @@ class Retriever(object):
                 else list(score_dict.keys())
             )
             return paper_id_list
-            else:
+        else:
             # clustering filter, ensure that each category the highest score save first
             paper_id_list = list(score_dict.keys())
             paper_embedding_list = [
-                self.paper_client.get_paper_attribute(paper_id, "embedding")…
+                self.paper_client.get_paper_attribute(paper_id, "embedding")
+                for paper_id in paper_id_list
             ]
             paper_embedding = np.array(paper_embedding_list)
             paper_embedding_list = [
-                self.paper_client.get_paper_attribute(…
+                self.paper_client.get_paper_attribute(
+                    paper_id, "contribution_embedding"
+                )
+                for paper_id in paper_id_list
             ]
             paper_contribution_embedding = np.array(paper_embedding_list)
             paper_embedding_list = [
-                self.paper_client.get_paper_attribute(paper_id, "summary_embedding")…
+                self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
+                for paper_id in paper_id_list
             ]
             paper_summary_embedding = np.array(paper_embedding_list)
             weight_embedding = self.config.RETRIEVE.s_bg
             weight_contribution = self.config.RETRIEVE.s_contribution
             weight_summary = self.config.RETRIEVE.s_summary
             paper_embedding = (
-                weight_embedding * paper_embedding…
-                weight_contribution * paper_contribution_embedding…
-                weight_summary * paper_summary_embedding…
+                weight_embedding * paper_embedding
+                + weight_contribution * paper_contribution_embedding
+                + weight_summary * paper_summary_embedding
             )
             similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
             related_labels = self.cluster_algorithm(paper_id_list, similarity_matrix)
@@ -542,9 +554,7 @@ class SNRetriever(Retriever):
         related_paper_id_list = retrieve_result["paper"]
         retrieve_paper_num = len(related_paper_id_list)
         _, _, score_all_dict = self.cal_related_score(
-            bg,
-            related_paper_id_list=related_paper_id_list,
-            entities=entities
+            bg, related_paper_id_list=related_paper_id_list, entities=entities
         )
         top_k_matrix = {}
         recall = 0
@@ -746,4 +756,4 @@ class SNKGRetriever(Retriever):
             "retrieve_paper_num": retrieve_paper_num,
             "label_num": label_num,
         }
-        return result
\ No newline at end of file
+        return result
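`CoCite.__init__` builds a symmetric co-citation map: for every paper in the citemap, each unordered pair of papers it cites has both `comap[a][b]` and `comap[b][a]` incremented, and `get_cocite_ids` then ranks a paper's neighbours by that count. The counting core, runnable on toy data:

    # toy run of the co-citation counting used by CoCite
    import itertools
    from collections import defaultdict

    citemap = {"p1": ["a", "b", "c"], "p2": ["b", "c"]}  # paper -> papers it cites
    comap = defaultdict(lambda: defaultdict(int))
    for paper_id, cited_id in citemap.items():
        for id0, id1 in itertools.combinations(cited_id, 2):
            comap[id0][id1] += 1  # symmetric update keeps comap[a][b] == comap[b][a]
            comap[id1][id0] += 1

    ranked = sorted(comap["b"].items(), key=lambda x: x[1], reverse=True)
    print(ranked)  # [('c', 2), ('a', 1)]: "b" co-occurs with "c" in both reference lists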