lihuigu committed on
Commit
c8709b2
·
1 Parent(s): 8a27036

update new version

README.md CHANGED
@@ -36,15 +36,16 @@ SciPIP is a scientific paper idea generation tool powered by a large language mo
36
  ![SciPIP](./assets/pic/demo.png)
37
 
38
 
39
- 🤗 Try it on the Hugging Face (Coming Soon... You can deploy it at your own computer now.)
40
 
41
  ## Updates
42
 
43
  - [x] Idea generation in a GUI environment (web app).
44
- - [x] Idea generation for the NLP and multimodal (partial) field.
45
- - [ ] Idea generation for the CV field.
46
  - [ ] Idea generation for other fields.
47
- - [ ] Release the Huggingface demo.
 
48
 
49
  ## Prerequisites
50
 
@@ -64,11 +65,11 @@ The following environments are tested under Ubuntu 22.04 with python>=3.10.3.
64
  ## Install Neo4j database
65
  sudo apt install -y openjdk-17-jre # Install Neo4j required JDK
66
  # cd ~/Downloads # or /your/path/to/download/Neo4j
67
- wget http://dist.neo4j.org/neo4j-community-5.20.0-unix.tar.gz
68
- tar -xvf neo4j-community-5.20.0-unix.tar.gz
69
 
70
  ## Start Neo4j
71
- cd ./neo4j-community-5.20.0
72
  # Uncomment server.default_listen_address=0.0.0.0 in conf/neo4j.conf to visit Neo4j through a browser
73
  sed -i 's/# server.default_listen_address=0.0.0.0/server.default_listen_address=0.0.0.0/g' ./conf/neo4j.conf
74
  ./bin/neo4j start
@@ -93,17 +94,14 @@ The following environments are tested under Ubuntu 22.04 with python>=3.10.3.
93
  ```
94
  3. **Prepare the literature database**
95
 
96
- 1. Download the literature data from [this link](https://drive.google.com/file/d/1NZTDpxKo7bmxwXPI03dgikEemKGLkwne/view?usp=sharing) and save it to `assets/data/scipip_neo4j_clean_backup.json`.
97
- 2. Then, run the following command to load the literature into the Neo4j database (it may take 40-60 minutes):
98
- ```
99
- python src/utils/paper_client.py
100
- ```
101
 
102
- 4. **[Optional] Prepare the embedding model**. Our algorithm uses SentenceBERT and **will automatically download** it from Huggingface the first time the program is run. However, if you're concerned about potential download failures due to network issues, you can download it in advance and place it in the specified directory.
103
  ```bash
104
- cd /root/path/of/SciPIP && mkdir -p assets/model/sentence-transformers
105
- git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 assets/model/sentence-transformers/all-MiniLM-L6-v2 assets/model/sentence-transformers
106
  ```
 
107
 
108
  ## Run In a Browser (Recommended)
109
 
@@ -116,36 +114,20 @@ Then, visit `http://localhost:8501` in your browser with an interactive envirome
116
 
117
  ## Run In a Terminal
118
 
119
- **1. BackTracking of ACL 2024**
120
-
121
- ```
122
- python src/generator.py backtracking --brainstorm-mode mode_c --use-cue-words True --use-inspiration True --num 1
123
- ```
124
-
125
- Results dump in `assets/output_idea/output_backtracking_mode_c_cue_True_ins_True.json`.
126
 
127
- **2. Generate new idea**
128
-
129
- Input your background and cue words in `assets/data/test_background.json`
130
 
131
  ```
132
  python src/generator.py new-idea --brainstorm-mode mode_c --use-inspiration True --num 2
133
  ```
134
 
135
- Results dump in `assets/output_idea/output_new_idea_mode_c_ins_True.json`.
136
 
137
  ## Others
138
 
139
- ### Retrieve Eval
140
-
141
- Generate retrieve eval log result in `./log`.
142
-
143
- ```
144
- bash scripts/retriever_eval.sh
145
- ```
146
-
147
  ### Database Construction
148
- SciPIP uses Neo4j as its database. You can directly import the provided data or add your own research papers.
149
  ```
150
  wget https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
151
  pip install en_core_web_sm-3.7.1-py3-none-any.whl
 
36
  ![SciPIP](./assets/pic/demo.png)
37
 
38
 
39
+ 🤗 Try it on Hugging Face: https://huggingface.co/spaces/lihuigu/SciPIP (The demo temporarily runs the old version of the code and will be updated soon.)
40
 
41
  ## Updates
42
 
43
  - [x] Idea generation in a GUI environment (web app).
44
+ - [x] Idea generation for the NLP and multimodal field.
45
+ - [x] Idea generation for the CV field.
46
  - [ ] Idea generation for other fields.
47
+ - [x] Release the Huggingface demo.
48
+ - [x] Support DeepSeek-v3 as a backend. 🎉 🎉 🎉
49
 
50
  ## Prerequisites
51
 
 
65
  ## Install Neo4j database
66
  sudo apt install -y openjdk-17-jre # Install Neo4j required JDK
67
  # cd ~/Downloads # or /your/path/to/download/Neo4j
68
+ wget http://dist.neo4j.org/neo4j-community-5.25.1-unix.tar.gz
69
+ tar -xvf neo4j-community-5.25.1-unix.tar.gz
70
 
71
  ## Start Neo4j
72
+ cd ./neo4j-community-5.25.1
73
  # Uncomment server.default_listen_address=0.0.0.0 in conf/neo4j.conf to visit Neo4j through a browser
74
  sed -i 's/# server.default_listen_address=0.0.0.0/server.default_listen_address=0.0.0.0/g' ./conf/neo4j.conf
75
  ./bin/neo4j start
 
94
  ```
95
  3. **Prepare the literature database**
96
 
97
+ 1. Download the literature data from [Google Drive](https://drive.google.com/file/d/1kZmJff8am-JGegZZQx0qxlC7o7YgBURg/view?usp=sharing) or [Baidu Disk](https://pan.baidu.com/s/1S22Evi5ReL0MvahFoQ-ipA?pwd=scip). Replace the `/your/path/neo4j-community-5.25.1/data` folder with the provided `data` folder, which contains literature from CV, NLP, ML, *etc.*
98
+ 2. [Optional] Prepare the embedding model. Our algorithm uses **jina-embeddings-v3** and will automatically download it from Hugging Face the first time the program is run. However, if you're concerned about download failures due to network issues, you can download it in advance and place it in the specified directory.
 
 
 
99
 
 
100
  ```bash
101
+ cd /root/path/of/SciPIP && mkdir -p assets/model/
102
+ git clone https://huggingface.co/jinaai/jina-embeddings-v3 assets/model
103
  ```
104
+
105
 
106
  ## Run In a Browser (Recommended)
107
 
 
114
 
115
  ## Run In a Terminal
116
 
117
+ **1. Generate new idea**
 
 
 
 
 
 
118
 
119
+ Input your background in `assets/data/test_background.json`
 
 
120
 
121
  ```
122
  python src/generator.py new-idea --brainstorm-mode mode_c --use-inspiration True --num 2
123
  ```
124
 
125
+ Results are written to `assets/output_idea/output-file.json`.
126
 
127
  ## Others
128
 
 
 
 
 
 
 
 
 
129
  ### Database Construction
130
+ SciPIP uses Neo4j as its database. You can directly import the provided data or add your own research papers.
131
  ```
132
  wget https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
133
  pip install en_core_web_sm-3.7.1-py3-none-any.whl
app.py CHANGED
@@ -22,11 +22,9 @@ if __name__ == "__main__":
22
  def fn2():
23
  step_by_step_generation.step_by_step_generation(backend)
24
 
25
- pg = st.navigation(
26
- [
27
- st.Page(homepage.home_page, title=_("🏠️ Homepage")),
28
- st.Page(fn1, title=_("💧 One-click Generation")),
29
- st.Page(fn2, title=_("💦 Step-by-step Generation")),
30
- ]
31
- )
32
- pg.run()
 
22
  def fn2():
23
  step_by_step_generation.step_by_step_generation(backend)
24
 
25
+ pg = st.navigation([
26
+ st.Page(homepage.home_page, title=_("🏠️ Homepage")),
27
+ st.Page(fn1, title=_("💧 One-click Generation")),
28
+ st.Page(fn2, title=_("💦 Step-by-step Generation")),
29
+ ])
30
+ pg.run()
 
 
assets/data/test_background.json CHANGED
@@ -1,2 +1,6 @@
1
- {"background": "The application scope of large-scale language models such as GPT-4 and LLaMA has rapidly expanded, demonstrating powerful capabilities in natural language processing and multimodal tasks. However, as the size and complexity of the models increase, understanding how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model interpretation: The billions of parameters and nonlinear decision paths within large-scale language models make it very difficult to track and interpret specific outputs. The existing interpretation methods usually only provide a local perspective and are difficult to systematize. 2. Transparency and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges."}
2
- {"background": "Multimodal learning is committed to integrating multiple information sources such as text, images, audio, and video to create more powerful and universal AI models. The research on unified representation aims to find representation methods that can generalize across modalities. Challenge: 1 Modal alignment: There is heterogeneity between different modalities, and how to achieve semantic alignment of these modalities to ensure that the model can comprehensively understand different types of data is a core challenge. 2. Data sparsity and imbalance: There is usually an imbalance in the amount of data in different modalities, such as video and audio data being relatively scarce, while text data is relatively abundant. How to effectively utilize and fuse these modalities to avoid overfitting or underfitting remains a research difficulty.", "cue_words": ["cross-modal embedding", "dat augmentation", "modality-aware fusion", "heterogeneous data integration"]}
 
 
 
 
 
1
+ [{"background": "1. The limitations of existing methods in leveraging nonverbal information for discerning complex semantics in unsupervised scenarios. \n2. The recognition that non-verbal modalities (video and audio) play a critical role in performing unsupervised clustering and can provide useful cues for semantics discovery."},
2
+ {"background": "1. The need to reduce the memory footprint and computational requirements of Large Language Models (LLMs) for deployment on resource-constrained devices. \n2. The challenge of preserving model performance at sub-4-bit quantization levels, where existing methods significantly degrade the fidelity of model weights."},
3
+ {"background": "1. The need to trace back and understand which training data influences specific generations in large language models for tasks like risk assessment, model retraining, and improving explainability. \n2. The challenge of scaling existing influence estimation methods to large language models due to their massive size and the computational expense of processing and storing gradients."},
4
+ {"background": "1. The need for a unifying framework to integrate the multiple dimensions of lexical semantic change detected by historical linguists. \n2. The desire to align theoretical insights from linguistics with the methodological sophistication of natural language processing to better understand social and cultural change."},
5
+ {"background": "The application scope of large-scale language models such as GPT-4 and LLaMA has rapidly expanded, demonstrating powerful capabilities in natural language processing and multimodal tasks. However, as the size and complexity of the models increase, understanding how they make decisions becomes increasingly difficult. Challenge: 1. The complexity of model interpretation: The billions of parameters and nonlinear decision paths within large-scale language models make it very difficult to track and interpret specific outputs. The existing interpretation methods usually only provide a local perspective and are difficult to systematize. 2. Transparency and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges."},
6
+ {"background": "Multimodal learning is committed to integrating multiple information sources such as text, images, audio, and video to create more powerful and universal AI models. The research on unified representation aims to find representation methods that can generalize across modalities. Challenge: 1. Modal alignment: There is heterogeneity between different modalities, and how to achieve semantic alignment of these modalities to ensure that the model can comprehensively understand different types of data is a core challenge. 2. Data sparsity and imbalance: There is usually an imbalance in the amount of data in different modalities, such as video and audio data being relatively scarce, while text data is relatively abundant. How to effectively utilize and fuse these modalities to avoid overfitting or underfitting remains a research difficulty.", "cue_words": ["cross-modal embedding", "data augmentation", "modality-aware fusion", "heterogeneous data integration"]}]
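Note on the format change above: the old file held one JSON object per line, while the new file is a single JSON array, so a reader that parses line by line would no longer work. How `src/generator.py` actually reads the file is not shown in this diff; the following is only a minimal sketch of loading the new layout, with `SAMPLE` and `load_backgrounds` as hypothetical stand-ins.

```python
import json

# Hypothetical stand-in for assets/data/test_background.json in its new form:
# a single JSON array of objects, where "cue_words" is optional.
SAMPLE = '''
[
  {"background": "1. First need. \\n2. Second need."},
  {"background": "Multimodal learning integrates text, images, audio, and video.",
   "cue_words": ["cross-modal embedding", "modality-aware fusion"]}
]
'''


def load_backgrounds(text):
    """Parse the JSON array and normalise each entry to (background, cue_words)."""
    entries = json.loads(text)
    return [(entry["background"], entry.get("cue_words", [])) for entry in entries]


backgrounds = load_backgrounds(SAMPLE)
for background, cue_words in backgrounds:
    print(len(background), cue_words)
```

Entries without `cue_words` fall back to an empty list here; whether the real code treats a missing key the same way is an assumption.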
assets/prompt/expand_background.xml ADDED
@@ -0,0 +1,44 @@
1
+ <?xml version="1.0" encoding="utf-8"?>
2
+ <!DOCTYPE body [
3
+ <!ENTITY warning "Warning: Something bad happened... please refresh and try again.">
4
+ ]>
5
+ <body>
6
+ <query rank="0">
7
+ <title>System Message</title>
8
+ <text>
9
+ You are a teacher in the field of AI, skilled at clearly explaining AI concepts to students. Your student is an undergraduate in AI with a basic understanding of deep learning.
10
+ </text>
11
+ </query>
12
+ <query rank="1">
13
+ <title>User Message</title>
14
+ <text>
15
+ # Task Description:
16
+ You are teaching your undergraduate about a specific subfield of AI research. You have a brief description of the research background, and now you need to explain its meaning and purpose in detail to your undergraduate. Keep in mind that your undergraduate may be completely unfamiliar with the technical terms in the research background. I will give you an example. The example begins with "# Example 1" and includes a Brief Research Background, several Technical Terms, and the corresponding Detailed Research Background. Then, your task starts with "# Your Task", containing "Your Brief Research Background" and "Your Technical Terms". Your job is to expand Your Brief Research Background into a Detailed Research Background by referring to Example 1. Note that the research background in Example 1 is unrelated to yours, so the key focus should be on the relationship between the Brief Research Background and the Detailed Research Background. You should directly start with your response and do not start with a section title like "## Detailed Background".
17
+
18
+ # Example 1
19
+
20
+ ## Brief Research Background
21
+
22
+ During the inference process of large language models, the KV cache grows with the context and the length of the generated content, occupying an increasing amount of GPU memory. How can we minimize the memory usage of the KV cache as much as possible to extend the text length that large language models can handle without increasing the cache size?
23
+
24
+ ## Technical Terms
25
+
26
+ large language models, kv cache, gpu memory
27
+
28
+ ## Detailed Research Background
29
+
30
+ Large language models use a Transformer-based architecture and generate text autoregressively, outputting one token at a time during inference. Within the Transformer, there is a self-attention module where each token is associated with three vectors: Q (query), K (key), and V (value). For each newly generated token, its Q needs to be computed with the K and V of all previously generated tokens. To avoid recalculating K and V repeatedly, the K and V of all tokens are stored in GPU memory, which is referred to as KV-cache. During the inference process of large language models, the KV cache grows with the context and the length of the generated content, occupying an increasing amount of GPU memory. How can we minimize the memory usage of the KV cache as much as possible to extend the text length that large language models can handle without increasing the cache size?
31
+
32
+ # Your Task
33
+
34
+ ## Your Research Background
35
+
36
+ {brief_background}
37
+
38
+ ## Your Technical Terms
39
+
40
+ {keywords}
41
+
42
+ </text>
43
+ </query>
44
+ </body>
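The prompt files in this commit share one layout: a `<body>` holding `<query>` elements ordered by `rank`, each with a `<title>` naming the role and a `<text>` carrying `{placeholder}` fields. The code that loads them is not part of this diff, so the following is a hedged sketch of how such a file could be turned into chat messages. `build_messages`, the role mapping, and the trimmed `TEMPLATE` are all hypothetical.

```python
import xml.etree.ElementTree as ET

# Trimmed, hypothetical stand-in for a prompt file such as expand_background.xml.
TEMPLATE = """<body>
  <query rank="0">
    <title>System Message</title>
    <text>You are a teacher in the field of AI.</text>
  </query>
  <query rank="1">
    <title>User Message</title>
    <text>
## Your Research Background

{brief_background}

## Your Technical Terms

{keywords}
    </text>
  </query>
</body>"""


def build_messages(template, **fields):
    """Return chat messages ordered by rank, with placeholders filled in."""
    root = ET.fromstring(template)
    roles = {"System Message": "system", "User Message": "user"}  # assumed mapping
    messages = []
    for query in sorted(root.findall("query"), key=lambda q: int(q.get("rank"))):
        role = roles[query.findtext("title")]
        content = query.findtext("text").strip().format(**fields)
        messages.append({"role": role, "content": content})
    return messages


messages = build_messages(
    TEMPLATE,
    brief_background="How can the KV cache use less GPU memory?",
    keywords="large language models, kv cache, gpu memory",
)
for message in messages:
    print(message["role"], "->", message["content"][:40])
```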
assets/prompt/expand_idea.xml ADDED
@@ -0,0 +1,75 @@
1
+ <?xml version="1.0" encoding="utf-8"?>
2
+ <!DOCTYPE body [
3
+ <!ENTITY warning "Warning: Something bad happened... please refresh and try again.">
4
+ ]>
5
+ <body>
6
+ <query rank="0">
7
+ <title>System Message</title>
8
+ <text>
9
+ Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at transforming a brief scientific idea into a concrete algorithm.
10
+ </text>
11
+ </query>
12
+ <query rank="1">
13
+ <title>User Message</title>
14
+ <text>
15
+ # Task Description:
16
+ You are an AI researcher conducting studies in a specific domain. Someone has provided you with a brief scientific idea, and your task is to transform it into a detailed, feasible, and concrete algorithm. If necessary, you may incorporate formulas in LaTeX format to elaborate on the algorithm. I will give you an example. The example begins with "# Example 1" and includes a Brief Scientific Idea and its corresponding Detailed Scientific Idea. Then, your task starts with "# Your Task", containing "Your Brief Scientific Idea". Your job is to expand Your Brief Scientific Idea into a Detailed Scientific Idea by referring to Example 1. Note that the ideas in Example 1 are unrelated to your idea, so the key focus should be on the relationship between the Brief Scientific Idea and the Detailed Scientific Idea. You should directly start with your response and do not start with a section title like "## Detailed Scientific Idea".
17
+
18
+ # Example 1
19
+
20
+ ## Example Brief Scientific Idea
21
+
22
+ 1. The use of a dual-stage process that first quantizes the LLM's model weights into 4-bit and then introduces a side network that leverages downsampled outputs and hidden states from the quantized LLM to make task-specific predictions.
23
+ 2. The innovative application of several low-rank adapter methods, such as MaxPooling and AvgPooling, within the side network to perform downsampling and significantly reduce the number of trainable parameters and the memory footprint of optimizer states.
24
+ 3. The aggregation of hidden states from the quantized LLM and the side network using a learnable parameter, which allows for efficient parallel computation of the LLM and side network without increasing inference latency.
25
+
26
+ ## Example Detailed Scientific Idea
27
+
28
+ Building on the concept of Parameter-Efficient Fine-Tuning (PEFT), which aims to adapt Large Language Models (LLMs) to specific tasks without the full computational cost of training all parameters, the following algorithm integrates a memory-efficient strategy enhanced through quantization and an auxiliary side network. This allows for efficient fine-tuning and inference on large models with reduced memory requirements and computational demands.
29
+
30
+ 1. **Quantization of LLM to 4-bit Precision:**
31
+ 1. Utilize 4-bit quantization to reduce the memory footprint of the LLM's weights. Each floating-point parameter in the LLM is converted to a 4-bit representation for efficient memory utilization.
32
+ 2. Begin with the conversion of 16-bit parameters to 4-bit using the relation:
33
+ $$
34
+ X_{{4bit}} = \text{{round}}\left(\frac{{M_{{4bit}}}}{{\text{{Absmax}}(X_{{16bit}})}} \cdot X_{{16bit}}\right)
35
+ $$
36
+ where $M_{{4bit}}$ is the maximum value representable in 4 bits, ensuring quantization minimizes precision loss by managing outliers through block-wise separate quantization.
37
+
38
+ 2. **Side Network for Memory-Efficient Tuning:**
39
+ 1. Implement a side network $g$ with dimensions reduced by a factor $r$ relative to the original model $f$. The side network processes information more economically, storing less data and reducing computational load during training.
40
+ 2. Define the hidden state transformation at the $i$-th layer as:
41
+ $$
42
+ h_{{gi}}^{{16bit}} = (1 - \beta_i) * \text{{downsample}}_i(h_{{fi}}^{{16bit}}) + \beta_i * h_{{gi-1}}^{{16bit}}
43
+ $$
44
+ where $\beta_i$ is a learnable gating parameter and $\text{{downsample}}_i$ reduces dimensionality.
45
+
46
+ 3. **Low-Memory Gradient Calculation:**
47
+ 1. Perform backpropagation limited to the side network $g$, excluding the calculation of gradients for the quantized weights in $f$, leveraging the pre-trained knowledge while focusing computational resources on the task-specific adaptation.
48
+ 2. The gradient computation is detached from the main LLM, avoiding the costly backpropagation through large transformer layers and focusing updates through efficient gradient paths within $g$.
49
+
50
+ 4. **Combining Outputs for Inference:**
51
+ 1. At inference, blend the outputs from the LLM and side network as a weighted sum:
52
+ $$
53
+ h_N^{{16bit}} = \alpha h_{{fN}}^{{16bit}} + (1-\alpha) h_{{gN}}^{{16bit}}
54
+ $$
55
+ where $\alpha$ is a learnable parameter initialized to prioritize the pre-trained model influence gradually allowing task customization through tuning.
56
+
57
+ 5. **Optimized Training Procedure:**
58
+ 1. Integrate efficient downsampling techniques, such as LoRA and Adapter models, reducing parameter size significantly without losing efficacy.
59
+ 2. Maintain a 16-bit floating-point data type for computations in forward and backward passes to balance precision and performance, ensuring that the quantized network remains robust and generalizable.
60
+
61
+ By synthesizing quantization with an auxiliary network, the algorithm achieves a robust parameter-efficient fine-tuning technique, significantly reducing memory overhead, improving inference speed, and maintaining high performance despite the minimal parameters being updated. This approach effectively supports large-scale models, facilitating application in environments with constrained computational resources.
62
+
63
+ # Your Task
64
+
65
+ ## Your Research Background
66
+
67
+ {background}
68
+
69
+ ## Your Brief Scientific Idea
70
+
71
+ {brief_idea}
72
+
73
+ </text>
74
+ </query>
75
+ </body>
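The doubled braces in the LaTeX above (`X_{{4bit}}`, `\text{{round}}`) suggest the template is filled with Python's `str.format`, where `{{` and `}}` are escapes for literal braces and only `{background}` and `{brief_idea}` are real fields. That is an inference from the file rather than something this diff confirms. A small sketch of the behaviour, using a shortened hypothetical template:

```python
# One formula line in the style of the example above, plus a real placeholder.
TEMPLATE = (
    "h_N^{{16bit}} = \\alpha h_{{fN}}^{{16bit}} + (1-\\alpha) h_{{gN}}^{{16bit}}\n"
    "\n"
    "## Your Brief Scientific Idea\n"
    "\n"
    "{brief_idea}\n"
)

# Doubled braces collapse to single literal braces; {brief_idea} is substituted.
filled = TEMPLATE.format(brief_idea="Quantize the weights, then train a small side network.")
print(filled)

# With single braces, str.format treats "4bit" as a field name and fails.
try:
    "X_{4bit} {brief_idea}".format(brief_idea="unused")
    single_brace_ok = True
except (KeyError, IndexError, ValueError):
    single_brace_ok = False
print("single braces accepted:", single_brace_ok)
```

Any new formula added to these prompt files would need the same escaping if the loader does use `str.format`.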
assets/prompt/generate_brainstorm.xml CHANGED
@@ -13,21 +13,21 @@ Now you are a researcher in the field of AI with innovative and pioneering abili
13
  <title>User Message</title>
14
  <text>
15
  ### Task Description:
16
- trunkYou are an AI researcher tasked with brainstorming initial, innovative ideas to address a given research problem in AI. Focus on generating diverse and creative approaches rather than finalized methods. The ideas can be rough and in their infancy but should cover a range of possible directions that could be explored further.
17
 
18
- trunk### Information Provided:
19
- trunk- **Research Background**: {background}
20
 
21
- trunk### Approach:
22
- trunkYour brainstorming should be systematic:
23
- trunk- **Step 1**: Thoroughly understand the research background.
24
- trunk- **Step 2**: Generate a list of 4 to 6 high-level ideas or directions that could potentially solve problems in the given background. Be creative, think outside the box, and avoid merely rephrasing existing methods.
25
 
26
- trunk### Format for Your Response:
27
- trunkPlease present 4 to 6 ideas in the following format:
28
- trunk**Idea 1**: [Brief description of the first idea]
29
- trunk**Idea 2**: [Brief description of the second idea]
30
- trunk...
31
  </text>
32
  </query>
33
  </body>
 
13
  <title>User Message</title>
14
  <text>
15
  ### Task Description:
16
+ You are an AI researcher tasked with brainstorming initial, innovative ideas to address a given research problem in AI. Focus on generating diverse and creative approaches. The ideas should cover a range of possible directions that could be explored further.
17
 
18
+ ### Information Provided:
19
+ - **Research Background**: {background}
20
 
21
+ ### Approach:
22
+ Your brainstorming should be systematic:
23
+ - **Step 1**: Thoroughly understand the research background.
24
+ - **Step 2**: Generate a list of 3 to 4 high-level ideas or directions that could potentially solve problems in the given background. Be creative, think outside the box, and avoid merely rephrasing existing methods.
25
 
26
+ ### Format for Your Response:
27
+ Please present 3 to 4 ideas in the following format:
28
+ **Idea 1**: [Brief description of the first idea]
29
+ **Idea 2**: [Brief description of the second idea]
30
+ ...
31
  </text>
32
  </query>
33
  </body>
assets/prompt/generate_concise_method.xml ADDED
@@ -0,0 +1,67 @@
1
+ <?xml version="1.0" encoding="utf-8"?>
2
+ <!DOCTYPE body [
3
+ <!ENTITY warning "Warning: Something bad happened... please refresh and try again.">
4
+ ]>
5
+ <body>
6
+ <query rank="0">
7
+ <title>System Message</title>
8
+ <text>
9
+ Now you are a researcher in the field of AI with innovative and pioneering abilities.
10
+ </text>
11
+ </query>
12
+ <query rank="1">
13
+ <title>User Message</title>
14
+ <text>
15
+ # Task Description:
16
+ You are an AI researcher conducting studies in a specific domain. Someone has provided you with a methodology section, and your task is to transform it into another style. I will give you an example. The example begins with "# Example 1" and includes an Example Summarized Methods. Then, your task starts with "# Your Task", containing "Your Methodology Section". Your job is to transform Your Methodology Section into a Summarized Methods by referring to Example 1. Note that the ideas in Example 1 are unrelated to your idea, so the key focus should be on the style of Example Summarized Methods. You should directly start with your response and do not start with a section title like "## Your Summarized Methods".
17
+
18
+ # Example 1
19
+
20
+ ## Example Summarized Methods
21
+
22
+ Building on the concept of Parameter-Efficient Fine-Tuning (PEFT), which aims to adapt Large Language Models (LLMs) to specific tasks without the full computational cost of training all parameters, the following algorithm integrates a memory-efficient strategy enhanced through quantization and an auxiliary side network. This allows for efficient fine-tuning and inference on large models with reduced memory requirements and computational demands.
23
+
24
+ 1. **Quantization of LLM to 4-bit Precision:**
25
+ 1. Utilize 4-bit quantization to reduce the memory footprint of the LLM's weights. Each floating-point parameter in the LLM is converted to a 4-bit representation for efficient memory utilization.
26
+ 2. Begin with the conversion of 16-bit parameters to 4-bit using the relation:
27
+ \[
28
+ X_{{4bit}} = \text{{round}}\left(\frac{{M_{{4bit}}}}{{\text{{Absmax}}(X_{{16bit}})}} \cdot X_{{16bit}}\right)
29
+ \]
30
+ where \(M_{{4bit}}\) is the maximum value representable in 4 bits, ensuring quantization minimizes precision loss by managing outliers through block-wise separate quantization.
31
+
32
+ 2. **Side Network for Memory-Efficient Tuning:**
33
+ 1. Implement a side network \(g\) with dimensions reduced by a factor \(r\) relative to the original model \(f\). The side network processes information more economically, storing less data and reducing computational load during training.
34
+ 2. Define the hidden state transformation at the \(i\)-th layer as:
35
+ \[
36
+ h_{{gi}}^{{16bit}} = (1 - \beta_i) * \text{{downsample}}_i(h_{{fi}}^{{16bit}}) + \beta_i * h_{{gi-1}}^{{16bit}}
37
+ \]
38
+ where \(\beta_i\) is a learnable gating parameter and \(\text{{downsample}}_i\) reduces dimensionality.
39
+
40
+ 3. **Low-Memory Gradient Calculation:**
41
+ 1. Perform backpropagation limited to the side network \(g\), excluding the calculation of gradients for the quantized weights in \(f\), leveraging the pre-trained knowledge while focusing computational resources on the task-specific adaptation.
42
+ 2. The gradient computation is detached from the main LLM, avoiding the costly backpropagation through large transformer layers and focusing updates through efficient gradient paths within \(g\).
43
+
44
+ 4. **Combining Outputs for Inference:**
45
+ 1. At inference, blend the outputs from the LLM and side network as a weighted sum:
46
+ \[
47
+ h_N^{{16bit}} = \alpha h_{{fN}}^{{16bit}} + (1-\alpha) h_{{gN}}^{{16bit}}
48
+ \]
49
+ where \(\alpha\) is a learnable parameter initialized to prioritize the pre-trained model influence gradually allowing task customization through tuning.
50
+
51
+ 5. **Optimized Training Procedure:**
52
+ 1. Integrate efficient downsampling techniques, such as LoRA and Adapter models, reducing parameter size significantly without losing efficacy.
53
+ 2. Maintain a 16-bit floating-point data type for computations in forward and backward passes to balance precision and performance, ensuring that the quantized network remains robust and generalizable.
54
+
55
+ By synthesizing quantization with an auxiliary network, the algorithm achieves a robust parameter-efficient fine-tuning technique, significantly reducing memory overhead, improving inference speed, and maintaining high performance despite the minimal parameters being updated. This approach effectively supports large-scale models, facilitating application in environments with constrained computational resources.
56
+
57
+ # Your Task
58
+
59
+ ## Your Methodology Section
60
+
61
+ {methodology}
62
+
63
+ ## Your Summarized Methods
64
+
65
+ </text>
66
+ </query>
67
+ </body>
assets/prompt/generate_idea_by_inspiration.xml CHANGED
@@ -12,30 +12,19 @@ Now you are a researcher in the field of AI with innovative and pioneering abili
12
  <query rank="1">
13
  <title>User Message</title>
14
  <text>
15
- ### Task Description:
16
- You will be provided with a research problem and its rationales, along with inspirations and their rationales extracted from related papers. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem.
17
 
18
- ### Information Provided:
19
- 1. **Research problem &amp; Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
20
- 2. **Inspirations**: Insights and ideas extracted from related papers that may provide valuable perspectives or techniques applicable to the research problem.
21
 
22
- ### Approach:
23
- Your approach should be systematic:
24
- - **Step 1**: Thoroughly read and understand the research problem to identify your primary focus.
25
- - **Step 2**: Review the inspirations extracted from the related papers to gain a broader perspective and insights relevant to the research topic.
26
- - **Step 3**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.
27
 
28
- ### Specific Information:
29
- I will provide you with specific information now, please use them according to the instructions above:
30
- 1. **Research problem &amp; Rationales**: {problem}
31
- 2. **Inspirations**: {inspirations_text}
 
32
 
33
- ### Format for Your Response:
34
- Please ensure that your final ideas include about 10 entries, presented in the following format:
35
- **Idea 1**: [The first method idea]
36
- **Idea 2**: [The second method idea]
37
- **Idea 3**: [The third method idea]
38
- ...
39
  </text>
40
  </query>
41
  </body>
 
12
  <query rank="1">
13
  <title>User Message</title>
14
  <text>
15
+ # Task Description:
16
+ You will be provided with a research problem, along with inspirations extracted from related papers. Your task is to identify and combine these inspirations to propose 3 to 4 different ideas to solve the research problem. The ideas should be innovative and try to avoid using the same inspirations repeatedly. Each idea should start with "**Idea **:".
17
 
18
+ # Information Provided:
 
 
19
 
20
+ ## Research problem
 
 
 
 
21
 
22
+ {background}
23
+
24
+ ## Inspirations
25
+
26
+ {inspirations}
27
 
 
 
 
 
 
 
28
  </text>
29
  </query>
30
  </body>
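Because the new prompt asks every idea to begin with a `**Idea **:` marker, the model's reply can be split back into individual ideas on that marker. The following is a minimal sketch under that assumption; `split_ideas` is a hypothetical helper, not a function from this repository, and it also tolerates numbered markers such as `**Idea 1**:`.

```python
import re

# Hypothetical helper: split an LLM reply into ideas on the "**Idea **:" marker.
# The optional digits allow for replies that number the ideas ("**Idea 2**:").
IDEA_MARKER = re.compile(r"\*\*Idea\s*\d*\s*\*\*\s*:")


def split_ideas(response: str) -> list[str]:
    """Return the text of each idea, without the marker and surrounding whitespace."""
    parts = IDEA_MARKER.split(response)
    # parts[0] is whatever precedes the first marker (usually empty), so skip it.
    return [part.strip() for part in parts[1:] if part.strip()]


reply = "**Idea 1**: Use contrastive decoding.\n**Idea **: Distill rationales."
ideas = split_ideas(reply)
```

Text before the first marker is discarded, so a reply that ignores the format yields an empty list rather than one malformed idea.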
assets/prompt/generate_inspiration_with_detail_method.xml ADDED
@@ -0,0 +1,77 @@
1
+ <?xml version="1.0" encoding="utf-8"?>
2
+ <!DOCTYPE body [
3
+ <!ENTITY warning "Warning: Something bad happened... please refresh and try again.">
4
+ ]>
5
+ <body>
6
+ <query rank="0">
7
+ <title>System Message</title>
8
+ <text>
9
+ Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at extracting novel and valuable inspirations from papers.
10
+ </text>
11
+ </query>
12
+ <query rank="1">
13
+ <title>User Message</title>
14
+ <text>
15
+ # Task Description:
16
+ You will be provided with a research problem, as well as some reference materials. Your task is to extract novel, effective, and specific inspirations from the materials that can help address the research problem. I will give you an example. The example begins with "# Example 1" and includes an Example Problem, Example Materials, and Example Inspiration. Your task is to read Your Problem and Your Materials, and extract Inspirations by referring to Example 1. Note that the contents of Example 1 are unrelated to yours, so the key focus should be on the relationship among the Problem, Materials, and Inspiration. Only output the three most significant inspirations, and summarize each inspiration in three sentences. Start directly with your response; do not begin with a section title like "## Your Inspirations". Further, if you believe the materials do not contribute to solving the problem described in Your Problem, you may simply reply with "None" and provide no further response.
17
+
18
+ # Example 1
19
+
20
+ ## Example Problem
21
+
22
+ 1. The need to ensure that chain-of-thought (CoT) rationales generated by large language models are consistent with their predictions and faithfully justify those decisions.
23
+ 2. The desire to distill the reasoning capabilities of large LMs into smaller models without losing the quality and faithfulness of the rationales.
24
+
25
+ ## Example Materials
26
+
27
+ In addressing the challenges of generating faithful Chain-Of-Thought (CoT) rationales and consistent student outputs in knowledge distillation, we introduce the Self-Consistent Chain-Of-Thought Distillation (SCOTT) method. SCOTT is designed to enhance consistency in rationale generation and counter the pitfalls of hallucination and reasoning shortcuts in language models. This approach involves training a smaller student model to produce rationales that align with its predictions, learning from a larger teacher model to achieve this. Our approach leverages contrastive decoding and counterfactual reasoning to improve the quality and faithfulness of rationales.
28
+
29
+ 1. **Contrastive Decoding for Teacher Model:**
30
+ 1. Employ contrastive decoding to generate more relevant and answer-grounded rationales by the teacher model, which mitigates issues related to hallucination common in language models.
31
+ 2. This is achieved by introducing perturbed answers and evaluating the plausibility shift of each token to ensure rationales support the intended answers more distinctly:
32
+ $$
33
+ G(t_i | a^*) = \log \frac{{P(t_i | p, q, A, a^*, t_{{&lt;i}})}}{{P(t_i | p, q, A, a', t_{{&lt;i}})}}
34
+ $$
35
+ Here, $a'$ represents a perturbed answer, used to fine-tune the rationale's specificity towards the correct answer.
36
+
37
+ 2. **Counterfactual Reasoning for Student Model:**
38
+ 1. Train the student model to validate its rationale against its predictions through counterfactual reasoning, requiring the student to adjust predictions when confronted with altered context or rationale.
39
+ 2. Implement this by incorporating variabilities in rationales that lead to different answers and ensuring the model understands and adjusts accordingly:
40
+ $$
41
+ L_{{\text{{counterfactual}}}} = - \sum \log P(t_i | q, r', t_{{&lt;i}})
42
+ $$
43
+ where $r'$ is a rationale leading to a perturbed answer, encouraging the student model to reflect such dependencies in its decision-making process.
44
+
45
+ 3. **Holistic Training Approach:**
46
+ 1. Integrate the contrastive decoding outputs and counterfactual reasoning objective into the student’s training to simultaneously focus on consistency in rationale generation and alignment with the predictions.
47
+ 2. By incorporating more on-topic rationale-answer pairs and utilizing both factual and counterfactual losses, the student model's faithfulness and performance are improved:
48
+ $$
49
+ L_{{\text{{total}}}} = L_{{\text{{factual}}}} + L_{{\text{{counterfactual}}}}
50
+ $$
51
+
52
+ 4. **Experimentation and Validation:**
53
+ 1. Conduct experiments on open-domain QA tasks where knowledge-intensive reasoning is essential, assessing both rationale consistency and the student's alignment between rationale and prediction.
54
+ 2. Results indicate the proposed SCOTT method leads to improved student faithfulness, maintaining competitive performance with additional advantages in rationale justification consistency compared to baseline models.
55
+ 3. Additional ablation studies reveal robustness across various student model sizes, ensuring consistent rationale fidelity irrespective of model capacity.
56
+
57
+ Through the integration of contrastive decoding and counterfactual reasoning, SCOTT offers a novel and robust approach to improve rationale consistency and model faithfulness in natural language processing tasks, enhancing interpretability and performance.
58
+
59
+ ## Example Inspiration
60
+
61
+ 1. When prompting large language models to generate rationales, the faithfulness of the rationale to the prediction can be enhanced using contrastive decoding. Specifically, for a given prediction, the model-generated rationale should differ as much as possible from the rationale generated for other predictions.
62
+
63
+ 2. The chain-of-thought (CoT) reasoning ability of smaller language models can be improved through chain-of-thought distillation. During distillation, for the same question, when the chain-of-thought content differs, the model's predictions should also differ.
64
+
65
+ # Your Task
66
+
67
+ ## Your Problem
68
+
69
+ {background}
70
+
71
+ ## Your Materials
72
+
73
+ {detail_method}
74
+
75
+ </text>
76
+ </query>
77
+ </body>
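The doubled braces in the example formulas (for instance `t_{{&lt;i}}`) suggest that this template is filled with Python's `str.format`, where `{{` and `}}` are literal braces while `{background}` and `{detail_method}` are the fields to substitute. Below is a minimal sketch under that assumption; `render_user_message` and the shortened template are hypothetical and are not the project's actual loader.

```python
import xml.etree.ElementTree as ET

# Shortened, hypothetical stand-in for the prompt file above.
PROMPT_XML = """<body>
<query rank="1">
<title>User Message</title>
<text>
## Your Problem

{background}

## Your Materials

{detail_method}

Formula: $t_{{&lt;i}}$
</text>
</query>
</body>"""


def render_user_message(xml_text: str, **fields: str) -> str:
    """Extract the rank-1 query text and fill its str.format placeholders."""
    root = ET.fromstring(xml_text)
    # The XML parser turns &lt; back into "<" before formatting happens.
    template = root.find("./query[@rank='1']/text").text
    return template.format(**fields)


message = render_user_message(
    PROMPT_XML,
    background="How can rationales stay faithful?",
    detail_method="A distillation method.",
)
```

If this reading is right, any literal brace added to the prompt later has to be doubled as well, otherwise `format` raises a `KeyError` or `ValueError`.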
configs/datasets.yaml CHANGED
@@ -3,30 +3,40 @@ DEFAULT:
3
  ignore_paper_id_list: ./assets/data/ignore_paper_id_list.json
4
  log_level: "DEBUG"
5
  log_dir: ./log
6
- embedding: sentence-transformers/all-MiniLM-L6-v2
 
 
 
 
 
7
 
8
  ARTICLE:
9
  summarizing_prompt: ./assets/prompt/summarizing.xml
10
 
11
  RETRIEVE:
12
  retriever_name: "SNKG"
13
- use_cocite: False
14
- use_cluster_to_filter: False # use a clustering algorithm in the filter
15
- cite_type: "all_cite_id_list"
 
 
16
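  limit_num: 100 # limit on the number of papers per entity
17
- sn_num_for_entity: 3 # number of papers found by SN search, used to expand the entities
18
- kg_jump_num: 1 # number of hops
19
- kg_cover_num: 3 # entity overlap count
20
- sum_paper_num: 50 # maximum number of papers retrieved
21
- sn_retrieve_paper_num: 55 # papers retrieved via SN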
 
22
  cocite_top_k: 1
23
  need_normalize: True
24
  alpha: 1
25
  beta: 0
26
  relation_name: "related" # "connect"
27
  top_p_list: [0.1, 0.2, 0.3, 0.4, 0.5]
28
- top_k_list: [10, 20, 30, 40, 50]
29
- s_bg: 0
30
- s_contribution: 0.5
31
- s_summary: 0.5
32
- similarity_threshold: 0.55
 
 
 
3
  ignore_paper_id_list: ./assets/data/ignore_paper_id_list.json
4
  log_level: "DEBUG"
5
  log_dir: ./log
6
+ # embedding: sentence-transformers/all-MiniLM-L6-v2
7
+ # embedding: BAAI/llm-embedder
8
+ embedding: jina-embeddings-v3
9
+ embedding_task: text-matching # ONLY FOR JINA_v3, retrieval.passage, text-matching, retrieval.query
10
+ embedding_database: text-matching # ONLY FOR JINA_v3, retrieval.passage, text-matching, retrieval.query
11
+
12
 
13
  ARTICLE:
14
  summarizing_prompt: ./assets/prompt/summarizing.xml
15
 
16
  RETRIEVE:
17
  retriever_name: "SNKG"
18
+ # retriever_name: "SN"
19
+ SN_field_name: "background"
20
+ use_cocite: True
21
+ use_cluster_to_filter: True # use a clustering algorithm in the filter
22
+ cite_type: "cite_id_list"
23
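  limit_num: 100 # limit on the number of papers per entity
24
+ sn_num_for_entity: 5 # number of papers found by SN search, used to expand the entities
25
+ kg_jump_num: 1 # number of hops; this parameter is unused and defaults to one hop
26
+ kg_cover_num: 7 # entity overlap count, i.e. two entities appear together in kg_cover_num papers
27
+ sum_paper_num: 100 # maximum number of papers retrieved, i.e. papers retrieved via entities
28
+ sn_retrieve_paper_num: 100 # papers retrieved via SN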
29
+ all_retrieve_paper_num: 10
30
  cocite_top_k: 1
31
  need_normalize: True
32
  alpha: 1
33
  beta: 0
34
  relation_name: "related" # "connect"
35
  top_p_list: [0.1, 0.2, 0.3, 0.4, 0.5]
36
+ top_k_list: [10, 20, 30, 40, 50, 60, 80, 100, 120, 150]
37
+ s_bg: 1.0
38
+ s_contribution: 0.0
39
+ s_summary: 0.0
40
+ s_abstract: 0.0
41
+ similarity_threshold: 0.95
42
+ # similarity_threshold: 0.55
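The comment on `kg_cover_num` describes an entity co-occurrence count: two entities appearing together in `kg_cover_num` papers. A minimal sketch of that idea follows, assuming the threshold is inclusive (`>=`) and that each paper is a plain list of entity strings. `related_entity_pairs` is a hypothetical name, and the real retriever works against Neo4j rather than in-memory lists.

```python
from collections import Counter
from itertools import combinations


def related_entity_pairs(paper_entities, kg_cover_num):
    """Return entity pairs that co-occur in at least kg_cover_num papers.

    Assumption: the threshold is inclusive; the config comment does not say.
    """
    counts = Counter()
    for entities in paper_entities:
        # Deduplicate and sort so each unordered pair has one canonical key.
        for pair in combinations(sorted(set(entities)), 2):
            counts[pair] += 1
    return {pair for pair, n in counts.items() if n >= kg_cover_num}


papers = [["a", "b"], ["a", "b", "c"], ["b", "c"]]
pairs = related_entity_pairs(papers, kg_cover_num=2)
```

Raising the value from 3 to 7, as this commit does, would make such a filter stricter, so fewer entity pairs count as related.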
scripts/env.sh CHANGED
@@ -1,6 +1,13 @@
1
  export NEO4J_URL="bolt://127.0.0.1:7687"
2
  export NEO4J_USERNAME="neo4j" # default neo4j
3
  export NEO4J_PASSWD="****" # your passwd
 
 
 
 
 
 
 
4
  ## Use Qwen
5
  export MODEL_NAME="qwen-turbo"
6
  export MODEL_TYPE="OpenAI"
 
1
  export NEO4J_URL="bolt://127.0.0.1:7687"
2
  export NEO4J_USERNAME="neo4j" # default neo4j
3
  export NEO4J_PASSWD="****" # your passwd
4
+
5
+ ## Use GPT-4o
6
+ # export MODEL_NAME="gpt-4o"
7
+ # export MODEL_TYPE="OpenAI"
8
+ # export MODEL_API_KEY="sk-************************************************"
9
+ # export BASE_URL="https://********************************"
10
+
11
  ## Use Qwen
12
  export MODEL_NAME="qwen-turbo"
13
  export MODEL_TYPE="OpenAI"
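Since the application depends on these exported variables, it can help to check them before starting a session. The sketch below is only illustrative: `missing_env` and the `REQUIRED` tuple are hypothetical and say nothing about how the repository's own `check_env` behaves. The variable names are the ones visible in this hunk.

```python
import os

# Variable names taken from the visible part of scripts/env.sh; the script may
# export more (for example MODEL_API_KEY and BASE_URL), which this list omits.
REQUIRED = ("NEO4J_URL", "NEO4J_USERNAME", "NEO4J_PASSWD", "MODEL_NAME", "MODEL_TYPE")


def missing_env(env=None):
    """Return the required variable names that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED if not env.get(name)]


example = {"NEO4J_URL": "bolt://127.0.0.1:7687", "NEO4J_USERNAME": "neo4j"}
missing = missing_env(example)
```

Calling `missing_env()` with no argument inspects the real process environment, which is the intended use after `source scripts/env.sh`.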
scripts/retriever_eval.sh CHANGED
@@ -1,5 +1,5 @@
1
  #!/bin/bash
2
- python src/retriever.py retrieve \
3
  -c configs/datasets.yaml \
4
  --ids-path assets/data/test_acl_2024.json
5
 
 
1
  #!/bin/bash
2
+ CUDA_VISIBLE_DEVICES=0 python src/retriever.py retrieve \
3
  -c configs/datasets.yaml \
4
  --ids-path assets/data/test_acl_2024.json
5
 
src/app_pages/button_interface.py CHANGED
@@ -4,6 +4,7 @@ from utils.llms_api import APIHelper
4
  from utils.header import ConfigReader
5
  from utils.hash import check_env, check_embedding
6
  from generator import IdeaGenerator
 
7
 
8
 
9
  class Backend(object):
@@ -36,94 +37,53 @@ class Backend(object):
36
  except (FileNotFoundError, json.JSONDecodeError) as e:
37
  print(f"Error loading examples from {path}: {e}")
38
  return []
 
 
 
39
 
40
- def background2brainstorm_callback(self, background, json_strs=None):
41
- if json_strs is not None: # only for DEBUG_MODE
42
- json_contents = json.loads(json_strs)
43
- return json_contents["brainstorm"]
44
- else:
45
- return self.api_helper.generate_brainstorm(background)
46
 
47
- def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
48
- if json_strs is not None: # only for DEBUG_MODE
49
- json_contents = json.loads(json_strs)
50
- entities_bg = json_contents["entities_bg"]
51
- entities_bs = json_contents["entities_bs"]
52
- entities_all = entities_bg + entities_bs
53
- # return gr.CheckboxGroup(choices=entities, value=entities, label="Expanded key words", visible=True)
54
- return entities_all
55
- else:
56
- entities_bg = self.api_helper.generate_entity_list(background)
57
- entities_bs = self.api_helper.generate_entity_list(brainstorm, 10)
58
- entities_all = list(set(entities_bg) | set(entities_bs))
59
- # return extracted_entities
60
- # return gr.CheckboxGroup(choices=entities_all, value=entities_all, label="Expanded key words", visible=True)
61
- return entities_all
62
 
63
  def upload_json_callback(self, input):
64
- # print(type(input))
65
- # print(len(input))
66
- # print(input) # temp file path
67
  with open(input, "r") as json_file:
68
  contents = json_file.read()
69
  json_contents = json.loads(contents)
70
  return [json_contents["background"], contents]
71
 
72
- def entities2literature_callback(self, background, entities, json_strs=None):
73
- if json_strs is not None:
74
- result = json.loads(json_strs)
75
- res = []
76
- for i, p in enumerate(result["related_paper"]):
77
- res.append(str(p))
78
- else:
79
- result = self.retriever_factory.retrieve(
80
- background, entities, need_evaluate=False, target_paper_id_list=[]
81
- )
82
- res = []
83
- for i, p in enumerate(result["related_paper"]):
84
- res.append(f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.')
85
  return res, result["related_paper"]
86
 
87
  def literature2initial_ideas_callback(
88
- self, background, brainstorms, retrieved_literature, json_strs=None
89
  ):
90
- if json_strs is not None:
91
- json_contents = json.loads(json_strs)
92
- return json_contents["median"]["initial_idea"]
93
- else:
94
- self.idea_generator.paper_list = retrieved_literature
95
- self.idea_generator.brainstorm = brainstorms
96
- if self.use_inspiration:
97
- message_input, idea_modified, median = (
98
- self.idea_generator.generate_by_inspiration(
99
- background, "new_idea", self.brainstorm_mode, False
100
- )
101
- )
102
- else:
103
- message_input, idea_modified, median = self.idea_generator.generate(
104
- background, "new_idea", self.brainstorm_mode, False
105
- )
106
- return median["initial_idea"], idea_modified
107
 
108
- def initial2final_callback(self, initial_ideas, final_ideas, json_strs=None):
109
- if json_strs is not None:
110
- json_contents = json.loads(json_strs)
111
- return json_contents["median"]["modified_idea"]
112
- else:
113
- return final_ideas
114
 
115
  def get_demo_i(self, i):
116
  if 0 <= i < len(self.examples):
117
  return self.examples[i].get("background", "Background not found.")
118
  else:
119
  return "Example not found. Please select a valid index."
120
-
121
- # return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
122
- # "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
123
- # "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
124
- # "how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model "
125
- # "interpretation: The billions of parameters and nonlinear decision paths within large-scale language "
126
- # "models make it very difficult to track and interpret specific outputs. The existing interpretation "
127
- # "methods usually only provide a local perspective and are difficult to systematize. 2. Transparency "
128
- # "and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring "
129
- # "the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges.")
 
4
  from utils.header import ConfigReader
5
  from utils.hash import check_env, check_embedding
6
  from generator import IdeaGenerator
7
+ import functools
8
 
9
 
10
  class Backend(object):
 
37
  except (FileNotFoundError, json.JSONDecodeError) as e:
38
  print(f"Error loading examples from {path}: {e}")
39
  return []
40
+
41
+ def background2entities_callback(self, background):
42
+ return self.api_helper.generate_entity_list(background)
43
 
44
+ def background2expandedbackground_callback(self, background, entities):
45
+ keywords_str = functools.reduce(lambda x, y: f"{x}, {y}", entities)
46
+ expanded_background = self.api_helper.expand_background(background, keywords_str)
47
+ return expanded_background
 
 
48
 
49
+ def background2brainstorm_callback(self, expanded_background):
50
+ return self.api_helper.generate_brainstorm(expanded_background)
51
+
52
+ def brainstorm2entities_callback(self, brainstorm, entities):
53
+ entities_bs = self.api_helper.generate_entity_list(brainstorm, 10)
54
+ entities_all = list(set(entities) | set(entities_bs))
55
+ return entities_all
 
 
 
 
 
 
 
 
56
 
57
  def upload_json_callback(self, input):
 
 
 
58
  with open(input, "r") as json_file:
59
  contents = json_file.read()
60
  json_contents = json.loads(contents)
61
  return [json_contents["background"], contents]
62
 
63
+ def entities2literature_callback(self, expanded_background, entities):
64
+ result = self.retriever_factory.retrieve(
65
+ expanded_background, entities, need_evaluate=False, target_paper_id_list=[]
66
+ )
67
+ res = []
68
+ for i, p in enumerate(result["related_paper"]):
69
+ res.append(f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.')
 
 
 
 
 
 
70
  return res, result["related_paper"]
71
 
72
  def literature2initial_ideas_callback(
73
+ self, expanded_background, brainstorms, retrieved_literature
74
  ):
75
+ self.idea_generator.paper_list = retrieved_literature
76
+ self.idea_generator.brainstorm = brainstorms
77
+ _, _, inspirations, initial_ideas, idea_filtered, final_ideas = (
78
+ self.idea_generator.generate_ins_bs(expanded_background)
79
+ )
80
+ return idea_filtered, final_ideas
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ def initial2final_callback(self, initial_ideas, final_ideas):
83
+ return final_ideas
 
 
 
 
84
 
85
  def get_demo_i(self, i):
86
  if 0 <= i < len(self.examples):
87
  return self.examples[i].get("background", "Background not found.")
88
  else:
89
  return "Example not found. Please select a valid index."
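Two details of the new callbacks are worth a note. `functools.reduce` without an initial value raises `TypeError` when the entity list is empty, and `list(set(a) | set(b))` does not preserve the order of the entities. The sketch below shows one possible alternative. `join_keywords` and `merge_entities` are hypothetical helpers, not part of this commit.

```python
import functools


def join_keywords_reduce(entities):
    """The joining style used in the commit: fails on an empty list."""
    return functools.reduce(lambda x, y: f"{x}, {y}", entities)


def join_keywords(entities):
    """Same output for non-empty input, and an empty string for no entities."""
    return ", ".join(entities)


def merge_entities(entities_bg, entities_bs):
    """Union of both lists, keeping first-seen order (dicts preserve insertion order)."""
    return list(dict.fromkeys([*entities_bg, *entities_bs]))


keywords = join_keywords(["LLM", "distillation", "rationale"])
merged = merge_entities(["LLM", "distillation"], ["distillation", "faithfulness"])
```

Stable ordering mainly matters for reproducibility, for example when the merged entities are logged or shown to the user.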
 
 
 
 
 
 
 
 
 
 
src/app_pages/homepage.py CHANGED
@@ -40,7 +40,7 @@ def generate_mainpage():
40
  st.header("Resources")
41
  st.markdown("Our paper: [https://arxiv.org/abs/2410.23166](https://arxiv.org/abs/2410.23166)")
42
  st.markdown("Our github repository: [https://github.com/cheerss/SciPIP](https://github.com/cheerss/SciPIP)")
43
- st.markdown("Our Huggingface demo: Coming soon...")
44
  # st.page_link("https://arxiv.org/abs/2410.23166", label="Our paper: https://arxiv.org/abs/2410.23166", icon=None)
45
  # st.page_link("https://github.com/cheerss/SciPIP", label="Our github repository: https://github.com/cheerss/SciPIP", icon=None)
46
 
@@ -71,7 +71,7 @@ def generate_mainpage():
71
  st.header("相关资源")
72
  st.markdown("论文: [https://arxiv.org/abs/2410.23166](https://arxiv.org/abs/2410.23166)")
73
  st.markdown("Github仓库: [https://github.com/cheerss/SciPIP](https://github.com/cheerss/SciPIP)")
74
- st.markdown("Huggingface演示: 敬请期待...")
75
  # st.page_link("https://arxiv.org/abs/2410.23166", label="Our paper: https://arxiv.org/abs/2410.23166", icon=None)
76
  # st.page_link("https://github.com/cheerss/SciPIP", label="Our github repository: https://github.com/cheerss/SciPIP", icon=None)
77
 
 
40
  st.header("Resources")
41
  st.markdown("Our paper: [https://arxiv.org/abs/2410.23166](https://arxiv.org/abs/2410.23166)")
42
  st.markdown("Our github repository: [https://github.com/cheerss/SciPIP](https://github.com/cheerss/SciPIP)")
43
+ st.markdown("Our Huggingface demo: [https://huggingface.co/spaces/lihuigu/SciPIP](https://huggingface.co/spaces/lihuigu/SciPIP)")
44
  # st.page_link("https://arxiv.org/abs/2410.23166", label="Our paper: https://arxiv.org/abs/2410.23166", icon=None)
45
  # st.page_link("https://github.com/cheerss/SciPIP", label="Our github repository: https://github.com/cheerss/SciPIP", icon=None)
46
 
 
71
  st.header("相关资源")
72
  st.markdown("论文: [https://arxiv.org/abs/2410.23166](https://arxiv.org/abs/2410.23166)")
73
  st.markdown("Github仓库: [https://github.com/cheerss/SciPIP](https://github.com/cheerss/SciPIP)")
74
+ st.markdown("Huggingface演示: [https://huggingface.co/spaces/lihuigu/SciPIP](https://huggingface.co/spaces/lihuigu/SciPIP)")
75
  # st.page_link("https://arxiv.org/abs/2410.23166", label="Our paper: https://arxiv.org/abs/2410.23166", icon=None)
76
  # st.page_link("https://github.com/cheerss/SciPIP", label="Our github repository: https://github.com/cheerss/SciPIP", icon=None)
77
 
src/app_pages/one_click_generation.py CHANGED
@@ -1,9 +1,8 @@
1
  import streamlit as st
 
2
  from .locale import _
3
  from .sidebar_components import get_sidebar_header, get_sidebar_supported_fields, get_help_us_improve, get_language_select
4
 
5
- # st.set_page_config(layout="wide", page_title="🦜🔗 Generate Idea Step-by-step")
6
-
7
  ## Pipeline global state
8
  # 1.0: Input background is in progress
9
  # 2.0: Brainstorming is in progress
@@ -37,7 +36,7 @@ def generate_sidebar():
37
  get_help_us_improve()
38
 
39
  def generate_mainpage(backend):
40
- st.title(_("💧 One-click Generation"))
41
 
42
  if "messages" not in st.session_state:
43
  st.session_state["messages"] = [{"role": "assistant", "content": "Please give me some key words or a background"}]
@@ -95,42 +94,47 @@ def generate_mainpage(backend):
95
  cols[3].button(_("Reset Chat"), on_click=reset, use_container_width=True, type="primary")
96
 
97
  def generate_ideas(backend, background):
 
 
 
 
 
 
 
98
  with st.spinner(text=("Brainstorming...")):
99
- brainstorms = backend.background2brainstorm_callback(background)
100
  st.session_state["intermediate_output"]["brainstorms"] = {"role": "assistant", "content": brainstorms}
101
- # st.chat_message("assistant").write(brainstorms)
102
- st.session_state["global_state_one_click"] = 2.5
103
 
104
- with st.spinner(text=("Extracting entities...")):
105
- entities = backend.brainstorm2entities_callback(background, brainstorms)
106
- st.session_state["intermediate_output"]["entities"] = {"role": "assistant", "content": entities}
107
  # st.chat_message("assistant").write(entities)
108
- st.session_state["global_state_one_click"] = 3.5
109
 
110
  with st.spinner(text=("Retrieving related works...")):
111
- msg = "My initial ideas are:"
112
- related_works, related_works_intact = backend.entities2literature_callback(background, entities)
113
  st.session_state["intermediate_output"]["related_works"] = {"role": "assistant", "content": related_works}
114
  # st.chat_message("assistant").write(related_works)
115
- st.session_state["global_state_one_click"] = 4.5
116
 
117
- with st.spinner(text="Generating initial ideas..."):
118
- msg = "My initial ideas are:"
119
  initial_ideas, final_ideas = backend.literature2initial_ideas_callback(background, brainstorms, related_works_intact)
120
- st.session_state.messages.append({"role": "assistant", "content": msg})
121
- st.chat_message("assistant").write(msg)
122
- st.session_state.messages.append({"role": "assistant", "content": initial_ideas})
123
- st.chat_message("assistant").write(initial_ideas)
124
- st.session_state["global_state_one_click"] = 5.5
125
-
126
- with st.spinner(text=("Generating final ideas...")):
127
- msg = "My final ideas after refinement are:"
128
- final_ideas = backend.initial2final_callback(initial_ideas, final_ideas)
129
- st.session_state.messages.append({"role": "assistant", "content": msg})
130
  st.chat_message("assistant").write(msg)
131
- st.session_state.messages.append({"role": "assistant", "content": final_ideas})
132
- st.chat_message("assistant").write(final_ideas)
133
- st.session_state["global_state_one_click"] = 6.5
 
 
 
 
 
 
 
 
134
 
135
  def one_click_generation(backend):
136
  generate_sidebar()
 
1
  import streamlit as st
2
+ from loguru import logger
3
  from .locale import _
4
  from .sidebar_components import get_sidebar_header, get_sidebar_supported_fields, get_help_us_improve, get_language_select
5
 
 
 
6
  ## Pipeline global state
7
  # 1.0: Input background is in progress
8
  # 2.0: Brainstorming is in progress
 
36
  get_help_us_improve()
37
 
38
  def generate_mainpage(backend):
39
+ st.title(_("One-click Generation"))
40
 
41
  if "messages" not in st.session_state:
42
  st.session_state["messages"] = [{"role": "assistant", "content": "Please give me some key words or a background"}]
 
94
  cols[3].button(_("Reset Chat"), on_click=reset, use_container_width=True, type="primary")
95
 
96
  def generate_ideas(backend, background):
97
+ with st.spinner(text=("Extracting entities from the user's input...")):
98
+ entities_bg = backend.background2entities_callback(background)
99
+
100
+ with st.spinner(text=("Understanding the user's input...")):
101
+ expanded_background = backend.background2expandedbackground_callback(background, entities_bg)
102
+ st.session_state["intermediate_output"]["expanded_background"] = {"role": "assistant", "content": expanded_background}
103
+
104
  with st.spinner(text=("Brainstorming...")):
105
+ brainstorms = backend.background2brainstorm_callback(expanded_background)
106
  st.session_state["intermediate_output"]["brainstorms"] = {"role": "assistant", "content": brainstorms}
107
+ st.chat_message("assistant").write("I have the following thoughts, but I'll search the literature to further consolidate and improve the ideas.")
108
+ st.chat_message("assistant").write(brainstorms)
109
 
110
+ with st.spinner(text=("Extracting entities for literature retrieval...")):
111
+ entities_all = backend.brainstorm2entities_callback(brainstorms, entities_bg)
112
+ st.session_state["intermediate_output"]["entities"] = {"role": "assistant", "content": entities_all}
113
  # st.chat_message("assistant").write(entities)
 
114
 
115
  with st.spinner(text=("Retrieving related works...")):
116
+ msg = "The retrieved works include:"
117
+ related_works, related_works_intact = backend.entities2literature_callback(expanded_background, entities_all)
118
  st.session_state["intermediate_output"]["related_works"] = {"role": "assistant", "content": related_works}
119
  # st.chat_message("assistant").write(related_works)
 
120
 
121
+ with st.spinner(text="Generating ideas... (This may take up to 5 minutes)"):
 
122
  initial_ideas, final_ideas = backend.literature2initial_ideas_callback(background, brainstorms, related_works_intact)
123
+ logger.info(f"Num of initial ideas: {len(initial_ideas)}, num of final ideas: {len(final_ideas)}")
124
+ # assert len(initial_ideas) == len(final_ideas)
125
+ msg = f"I have {len(initial_ideas)} ideas:"
 
 
 
 
 
 
 
126
  st.chat_message("assistant").write(msg)
127
+ for i in range(len(initial_ideas)):
128
+ output = f"""### Concise Idea
129
+ {initial_ideas[i]}
130
+
131
+ ### Idea in Detail:
132
+
133
+ {final_ideas[i]}
134
+
135
+ """
136
+ st.session_state.messages.append({"role": "assistant", "content": output})
137
+ st.chat_message("assistant").write(output)
138
 
139
  def one_click_generation(backend):
140
  generate_sidebar()
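In `generate_ideas`, the loop indexes `final_ideas[i]` for every initial idea while the length assertion is commented out, so a shorter `final_ideas` list would raise `IndexError`. A minimal sketch of a more tolerant pairing is shown below, assuming both values are lists of strings. `format_ideas` is a hypothetical helper, and silently dropping unpaired ideas is a design choice that may not suit the app.

```python
def format_ideas(initial_ideas, final_ideas):
    """Pair concise and detailed ideas; zip stops at the shorter list."""
    outputs = []
    for concise, detailed in zip(initial_ideas, final_ideas):
        outputs.append(
            f"### Concise Idea\n{concise}\n\n### Idea in Detail:\n\n{detailed}\n\n"
        )
    return outputs


blocks = format_ideas(["idea A", "idea B"], ["details for A"])
```

Each returned string could then be appended to `st.session_state.messages` and written to the chat as in the original loop.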
src/app_pages/sidebar_components.py CHANGED
@@ -11,9 +11,9 @@ def get_sidebar_supported_fields():
11
  st.sidebar.caption(_("The supported fields are temporarily limited because we only collect literature "
12
  "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress."))
13
  st.sidebar.checkbox(_("Natural Language Processing (NLP)"), value=True, disabled=True)
14
- st.sidebar.checkbox(_("Computer Vision (CV)"), value=False, disabled=True)
15
 
16
- st.sidebar.checkbox(_("[Partial] Multimodal"), value=True, disabled=True)
17
  st.sidebar.checkbox(_("Incoming Other Fields"), value=False, disabled=True)
18
 
19
  def get_help_us_improve():
 
11
  st.sidebar.caption(_("The supported fields are temporarily limited because we only collect literature "
12
  "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress."))
13
  st.sidebar.checkbox(_("Natural Language Processing (NLP)"), value=True, disabled=True)
14
+ st.sidebar.checkbox(_("Computer Vision (CV)"), value=True, disabled=True)
15
 
16
+ st.sidebar.checkbox(_("Multimodal"), value=True, disabled=True)
17
  st.sidebar.checkbox(_("Incoming Other Fields"), value=False, disabled=True)
18
 
19
  def get_help_us_improve():
src/app_pages/step_by_step_generation.py CHANGED
@@ -44,8 +44,8 @@ def get_textarea_height(text_content):
44
  return max(count * 23 + 20, 100) # 23 is a magic number
45
 
46
  def generate_mainpage(backend):
47
- st.title(_("💦 Step-by-step Generation"))
48
- st.header(_("🐳 Background"))
49
  with st.form('background_form') as bg_form:
50
  background = st.session_state.get("background", "")
51
  background = st.text_area("Input your field background", background, placeholder="Input your field background", height=200, label_visibility="collapsed")
@@ -60,112 +60,101 @@ def generate_mainpage(backend):
60
  submitted = col1.form_submit_button(_("Submit"), type="primary")
61
  if submitted:
62
  st.session_state["global_state_step"] = 2.0
63
- with st.spinner(text="Brainstorming..."):
64
- st.session_state["brainstorms"] = backend.background2brainstorm_callback(background)
 
 
 
 
 
 
65
  # st.session_state["brainstorms"] = "Test text"
66
  st.session_state["brainstorms_expand"] = True
67
  st.session_state["global_state_step"] = 2.5
68
 
69
  ## Brainstorms
70
- st.header(_("👻 Brainstorms"))
71
- with st.expander("", expanded=st.session_state.get("brainstorms_expand", False)):
72
- # st.write("<div class='myclass'>")
73
- col1, col2 = st.columns(2)
74
- widget_height = get_textarea_height(st.session_state.get("brainstorms", ""))
75
- brainstorms = col1.text_area(label="brainstorms", value=st.session_state.get("brainstorms", ""),
76
- label_visibility="collapsed", height=widget_height)
77
- st.session_state["brainstorms"] = brainstorms
78
- if brainstorms:
79
- col2.markdown(f"{brainstorms}")
80
- else:
81
- col2.markdown(_("Please input the brainstorms on the left."))
82
- # st.write("</div>")
83
- col1, col2 = st.columns([2, 20])
84
- submitted = col1.button(_("Submit"), type="primary")
85
- if submitted:
86
- st.session_state["global_state_step"] = 3.0
87
- with st.spinner(text="Extracting entities..."):
88
- st.session_state["entities"] = backend.brainstorm2entities_callback(background, brainstorms)
89
- # st.session_state["entities"] = "entities"
90
- st.session_state["global_state_step"] = 3.5
91
- st.session_state["entities_expand"] = True
 
92
 
93
  ## Entities
94
- st.header(_("🐱 Extracted Entities"))
95
- with st.expander("", expanded=st.session_state.get("entities_expand", False)):
96
- ## pills
97
- def update_entities():
98
- return
99
- ori_entities = st.session_state.get("entities", [])
100
- entities_updated = st.session_state.get("entities_updated", ori_entities)
101
- entities_updated = st.pills(label="entities", options=ori_entities, selection_mode="multi",
102
- default=ori_entities, label_visibility="collapsed", on_change=update_entities)
103
- st.session_state["entities_updated"] = entities_updated
 
104
 
105
- submitted = st.button(_("Submit"), key="entities_button", type="primary")
106
- if submitted:
107
- st.session_state["global_state_step"] = 4.0
108
- with st.spinner(text="Retrieving related works..."):
109
- st.session_state["related_works"], st.session_state["related_works_intact"] = backend.entities2literature_callback(background, entities_updated)
110
- st.session_state["related_works_use_state"] = [True] * len(st.session_state["related_works"])
111
- st.session_state["global_state_step"] = 4.5
112
- st.session_state["related_works_expand"] = True
 
113
 
114
  ## Retrieved related works
115
- st.header(_("📖 Retrieved Related Works"))
116
- with st.expander("", expanded=st.session_state.get("related_works_expand", False)):
117
- related_works = st.session_state.get("related_works", [])
118
- for i, rw in enumerate(related_works):
119
- checked = st.checkbox(rw, value=st.session_state.get("related_works_use_state")[i])
120
- st.session_state.get("related_works_use_state")[i] = checked
 
121
 
122
- submitted = st.button(_("Submit"), key="related_works_button", type="primary")
123
- if submitted:
124
- st.session_state["global_state_step"] = 5.0
125
- with st.spinner(text="Generating initial ideas..."):
126
- st.session_state["selected_related_works_intact"] = []
127
- for s, p in zip(st.session_state.get("related_works_use_state"), st.session_state["related_works_intact"]):
128
- if s:
129
- st.session_state["selected_related_works_intact"].append(p)
130
- res = backend.literature2initial_ideas_callback(background, brainstorms, st.session_state["selected_related_works_intact"])
131
- st.session_state["initial_ideas"] = res[0]
132
- st.session_state["final_ideas"] = res[1]
133
- # st.session_state["initial_ideas"] = "initial ideas"
134
- st.session_state["global_state_step"] = 5.5
135
- st.session_state["initial_ideas_expand"] = True
136
 
137
  ## Initial ideas
138
- st.header(_("😼 Generated Initial Ideas"))
139
- with st.expander("", expanded=st.session_state.get("initial_ideas_expand", False)):
140
- col1, col2 = st.columns(2, )
141
- widget_height = get_textarea_height(st.session_state.get("initial_ideas", ""))
142
- initial_ideas = col1.text_area(label="initial_ideas", value=st.session_state.get("initial_ideas", ""),
143
- label_visibility="collapsed", height=widget_height)
144
- if initial_ideas:
145
- col2.markdown(f"{initial_ideas}")
146
- else:
147
- col2.markdown(_("Please input the initial ideas on the left."))
148
- submitted = col1.button(_("Submit"), key="initial_ideas_button", type="primary")
149
- if submitted:
150
- st.session_state["global_state_step"] = 6.0
151
- with st.spinner(text="Generating final ideas..."):
152
- st.session_state["final_ideas"] = backend.initial2final_callback(initial_ideas, st.session_state["final_ideas"])
153
- # st.session_state["final_ideas"] = "final ideas"
154
- st.session_state["global_state_step"] = 6.5
155
- st.session_state["final_ideas_expand"] = True
156
-
157
- ## Final ideas
158
- st.header(_("😸 Generated Final Ideas"))
159
- with st.expander("", expanded=st.session_state.get("final_ideas_expand", False)):
160
- col1, col2 = st.columns(2, )
161
- widget_height = get_textarea_height(st.session_state.get("final_ideas", ""))
162
- user_input = col1.text_area(label="final_ideas", value=st.session_state.get("final_ideas", ""),
163
- label_visibility="collapsed", height=widget_height)
164
- if user_input:
165
- col2.markdown(f"{user_input}")
166
- else:
167
- col2.markdown(_("Please input the final ideas on the left."))
168
- submitted = col1.button(_("Submit"), key="final_ideas_button", type="primary")
169
 
170
  def step_by_step_generation(backend):
171
  ## Pipeline global state
 
44
  return max(count * 23 + 20, 100) # 23 is a magic number
45
 
46
  def generate_mainpage(backend):
47
+ st.title(_("Step-by-step Generation"))
48
+ st.header(_("Background"))
49
  with st.form('background_form') as bg_form:
50
  background = st.session_state.get("background", "")
51
  background = st.text_area("Input your field background", background, placeholder="Input your field background", height=200, label_visibility="collapsed")
 
60
  submitted = col1.form_submit_button(_("Submit"), type="primary")
61
  if submitted:
62
  st.session_state["global_state_step"] = 2.0
63
+ with st.spinner(text="Let me first brainstorm some ideas..."):
64
+ st.session_state["entities_bg"] = backend.background2entities_callback(background)
65
+ st.session_state["expanded_background"] = backend.background2expandedbackground_callback(
66
+ background, st.session_state["entities_bg"]
67
+ )
68
+ st.session_state["brainstorms"] = backend.background2brainstorm_callback(
69
+ st.session_state["expanded_background"]
70
+ )
71
  # st.session_state["brainstorms"] = "Test text"
72
  st.session_state["brainstorms_expand"] = True
73
  st.session_state["global_state_step"] = 2.5
74
 
75
  ## Brainstorms
76
+ if st.session_state["global_state_step"] >= 2.5:
77
+ st.header(_("Brainstorms"))
78
+ with st.expander("", expanded=st.session_state.get("brainstorms_expand", False)):
79
+ # st.write("<div class='myclass'>")
80
+ col1, col2 = st.columns(2)
81
+ widget_height = get_textarea_height(st.session_state.get("brainstorms", ""))
82
+ brainstorms = col1.text_area(label="brainstorms", value=st.session_state.get("brainstorms", ""),
83
+ label_visibility="collapsed", height=widget_height)
84
+ st.session_state["brainstorms"] = brainstorms
85
+ if brainstorms:
86
+ col2.markdown(f"{brainstorms}")
87
+ else:
88
+ col2.markdown(_("Please input the brainstorms on the left."))
89
+ # st.write("</div>")
90
+ col1, col2 = st.columns([2, 20])
91
+ submitted = col1.button(_("Submit"), type="primary")
92
+ if submitted:
93
+ st.session_state["global_state_step"] = 3.0
94
+ with st.spinner(text="I am extracting keywords in the background and brainstorming ideas"):
95
+ st.session_state["entities"] = backend.brainstorm2entities_callback(brainstorms, st.session_state["entities_bg"])
96
+ # st.session_state["entities"] = "entities"
97
+ st.session_state["global_state_step"] = 3.5
98
+ st.session_state["entities_expand"] = True
99
 
100
  ## Entities
101
+ if st.session_state["global_state_step"] >= 3.5:
102
+ st.header(_("Extracted Entities"))
103
+ with st.expander("", expanded=st.session_state.get("entities_expand", False)):
104
+ ## pills
105
+ def update_entities():
106
+ return
107
+ ori_entities = st.session_state.get("entities", [])
108
+ entities_updated = st.session_state.get("entities_updated", ori_entities)
109
+ entities_updated = st.pills(label="entities", options=ori_entities, selection_mode="multi",
110
+ default=ori_entities, label_visibility="collapsed", on_change=update_entities)
111
+ st.session_state["entities_updated"] = entities_updated
112
 
113
+ submitted = st.button(_("Submit"), key="entities_button", type="primary")
114
+ if submitted:
115
+ st.session_state["global_state_step"] = 4.0
116
+ with st.spinner(text="I am retrieving related works for more ideas..."):
117
+ st.session_state["related_works"], st.session_state["related_works_intact"] = \
118
+ backend.entities2literature_callback(st.session_state["expanded_background"], entities_updated)
119
+ st.session_state["related_works_use_state"] = [True] * len(st.session_state["related_works"])
120
+ st.session_state["global_state_step"] = 4.5
121
+ st.session_state["related_works_expand"] = True
122
 
123
  ## Retrieved related works
124
+ if st.session_state["global_state_step"] >= 4.5:
125
+ st.header(_("Retrieved Related Works"))
126
+ with st.expander("", expanded=st.session_state.get("related_works_expand", False)):
127
+ related_works = st.session_state.get("related_works", [])
128
+ for i, rw in enumerate(related_works):
129
+ checked = st.checkbox(rw, value=st.session_state.get("related_works_use_state")[i])
130
+ st.session_state.get("related_works_use_state")[i] = checked
131
 
132
+ submitted = st.button(_("Submit"), key="related_works_button", type="primary")
133
+ if submitted:
134
+ st.session_state["global_state_step"] = 5.0
135
+ with st.spinner(text="I am generating final ideas..."):
136
+ st.session_state["selected_related_works_intact"] = []
137
+ for s, p in zip(st.session_state.get("related_works_use_state"), st.session_state["related_works_intact"]):
138
+ if s:
139
+ st.session_state["selected_related_works_intact"].append(p)
140
+ res = backend.literature2initial_ideas_callback(background, brainstorms, st.session_state["selected_related_works_intact"])
141
+ st.session_state["initial_ideas"] = res[0]
142
+ st.session_state["final_ideas"] = res[1]
143
+ # st.session_state["initial_ideas"] = "initial ideas"
144
+ st.session_state["global_state_step"] = 5.5
145
+ st.session_state["initial_ideas_expand"] = True
146
 
147
  ## Initial ideas
148
+ if st.session_state["global_state_step"] >= 5.5:
149
+ st.header(_("Generated Ideas"))
150
+ with st.expander("", expanded=st.session_state.get("initial_ideas_expand", False)):
151
+ for initial_idea, final_idea in zip(st.session_state.get("initial_ideas", ""), st.session_state.get("final_ideas", "")):
152
+ st.divider()
153
+ st.markdown("### Concise Idea")
154
+ st.markdown(initial_idea)
155
+ st.markdown("### Idea in Detail")
156
+ st.markdown(final_idea)
157
+ st.divider()
158
 
159
  def step_by_step_generation(backend):
160
  ## Pipeline global state
src/config/reader.py CHANGED
@@ -159,5 +159,5 @@ class ConfigReader:
159
  """
160
  config = ConfigReader(file_, included).config
161
  for k, v in kwargs.items():
162
- config[k] = v
163
  return config
 
159
  """
160
  config = ConfigReader(file_, included).config
161
  for k, v in kwargs.items():
162
+ config.get(k, {}).update(v)
163
  return config
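Review note on the `ConfigReader.load` change above: with plain dicts, `config.get(k, {}).update(v)` merges into an existing section but silently drops the override when `k` is missing, because the temporary `{}` is never stored. A minimal sketch of a merge that keeps new sections follows. It assumes plain `dict` semantics (the real config object may behave differently), and `merge_overrides` and the `top_k` key are hypothetical names used only for illustration.

```python
from collections.abc import Mapping


def merge_overrides(config: dict, overrides: Mapping) -> dict:
    """Merge section-level overrides into a nested config dict, in place."""
    for section, values in overrides.items():
        if isinstance(values, Mapping):
            # setdefault stores a new section if it is missing,
            # unlike config.get(section, {}) which discards it.
            config.setdefault(section, {}).update(values)
        else:
            # Non-mapping overrides replace the value outright.
            config[section] = values
    return config


# Hypothetical config contents, for illustration only.
base = {"DEFAULT": {"log_level": "INFO", "embedding": "sbert"}}
merged = merge_overrides(
    base,
    {"DEFAULT": {"log_level": "DEBUG"}, "RETRIEVE": {"top_k": 5}},
)
print(merged)
```

Existing keys in `DEFAULT` that are not overridden are preserved, and the previously absent `RETRIEVE` section is created rather than lost.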
src/generator.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from utils.paper_retriever import RetrieverFactory
2
  from utils.paper_client import PaperClient
3
  from utils.llms_api import APIHelper
@@ -10,6 +11,7 @@ import warnings
10
  import time
11
  import os
12
  from utils.hash import check_env, check_embedding
 
13
 
14
  warnings.filterwarnings("ignore")
15
 
@@ -26,30 +28,40 @@ def extract_problem(problem, background):
26
  return research_problem
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  class IdeaGenerator:
30
  def __init__(
31
  self,
32
  config,
33
  paper_list: list[dict] = [],
34
- cue_words: list = None,
35
  brainstorm: str = None,
36
  ) -> None:
37
  self.api_helper = APIHelper(config)
38
  self.paper_list = paper_list
39
- self.cue_words = cue_words
40
  self.brainstorm = brainstorm
41
 
42
- def generate_with_cue_words(self, background: str):
43
- problem, message_input = self.api_helper.generate_problem_with_cue_words(
44
- background, self.paper_list, self.cue_words
45
- )
46
- idea = self.api_helper.generate_idea_with_cue_words(
47
- problem, self.paper_list, self.cue_words
48
- )
49
- idea_filtered = self.api_helper.filter_idea(idea, background)
50
- return message_input, problem, idea, idea_filtered
51
-
52
  def generate_without_cue_words(self, background: str):
 
 
53
  problem, message_input = self.api_helper.generate_problem(
54
  background, self.paper_list
55
  )
@@ -57,19 +69,9 @@ class IdeaGenerator:
57
  idea_filtered = self.api_helper.filter_idea(idea, background)
58
  return message_input, problem, idea, idea_filtered
59
 
60
- def generate_with_cue_words_bs(self, background: str):
61
- problem, message_input = self.api_helper.generate_problem_with_cue_words(
62
- background, self.paper_list, self.cue_words
63
- )
64
- idea = self.api_helper.generate_idea_with_cue_words(
65
- problem, self.paper_list, self.cue_words
66
- )
67
- idea_filtered = self.api_helper.integrate_idea(
68
- background, self.brainstorm, idea
69
- )
70
- return message_input, problem, idea, idea_filtered
71
-
72
  def generate_without_cue_words_bs(self, background: str):
 
 
73
  problem, message_input = self.api_helper.generate_problem(
74
  background, self.paper_list
75
  )
@@ -79,24 +81,9 @@ class IdeaGenerator:
79
  )
80
  return message_input, problem, idea, idea_filtered
81
 
82
- def generate_with_cue_words_ins(self, background: str):
83
- problem, message_input = self.api_helper.generate_problem_with_cue_words(
84
- background, self.paper_list, self.cue_words
85
- )
86
- research_problem = extract_problem(problem, background)
87
- inspirations = []
88
- for paper in self.paper_list:
89
- inspiration = self.api_helper.generate_inspiration_with_cue_words(
90
- research_problem, paper, self.cue_words
91
- )
92
- inspirations.append(inspiration)
93
- idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
94
- problem, inspirations, self.cue_words
95
- )
96
- idea_filtered = self.api_helper.filter_idea(idea, background)
97
- return message_input, problem, inspirations, idea, idea_filtered
98
-
99
  def generate_without_cue_words_ins(self, background: str):
 
 
100
  problem, message_input = self.api_helper.generate_problem(
101
  background, self.paper_list
102
  )
@@ -109,26 +96,9 @@ class IdeaGenerator:
109
  idea_filtered = self.api_helper.filter_idea(idea, background)
110
  return message_input, problem, inspirations, idea, idea_filtered
111
 
112
- def generate_with_cue_words_ins_bs(self, background: str):
113
- problem, message_input = self.api_helper.generate_problem_with_cue_words(
114
- background, self.paper_list, self.cue_words
115
- )
116
- research_problem = extract_problem(problem, background)
117
- inspirations = []
118
- for paper in self.paper_list:
119
- inspiration = self.api_helper.generate_inspiration_with_cue_words(
120
- research_problem, paper, self.cue_words
121
- )
122
- inspirations.append(inspiration)
123
- idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
124
- problem, inspirations, self.cue_words
125
- )
126
- idea_filtered = self.api_helper.integrate_idea(
127
- background, self.brainstorm, idea
128
- )
129
- return message_input, problem, inspirations, idea, idea_filtered
130
-
131
  def generate_without_cue_words_ins_bs(self, background: str):
 
 
132
  problem, message_input = self.api_helper.generate_problem(
133
  background, self.paper_list
134
  )
@@ -143,6 +113,52 @@ class IdeaGenerator:
143
  )
144
  return message_input, problem, inspirations, idea, idea_filtered
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def generate(
147
  self,
148
  background: str,
@@ -156,37 +172,21 @@ class IdeaGenerator:
156
  elif mode == "new_idea":
157
  mode_name = "Generate new idea"
158
  if bs_mode == "mode_a":
159
- if use_cue_words:
160
- logger.info(
161
- "{} using brainstorm_mode_a with cue words.".format(mode_name)
162
- )
163
- (message_input, problem, idea, idea_filtered) = (
164
- self.generate_with_cue_words(background)
165
- )
166
- else:
167
- logger.info(
168
- "{} using brainstorm_mode_a without cue words.".format(mode_name)
169
- )
170
- (message_input, problem, idea, idea_filtered) = (
171
- self.generate_without_cue_words(background)
172
- )
173
  elif bs_mode == "mode_b" or bs_mode == "mode_c":
174
- if use_cue_words:
175
- logger.info(
176
- "{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
177
- )
178
- (message_input, problem, idea, idea_filtered) = (
179
- self.generate_with_cue_words_bs(background)
180
- )
181
- else:
182
- logger.info(
183
- "{} using brainstorm_{} without cue words.".format(
184
- mode_name, bs_mode
185
- )
186
- )
187
- (message_input, problem, idea, idea_filtered) = (
188
- self.generate_without_cue_words_bs(background)
189
  )
 
 
 
 
190
 
191
  idea_modified = self.api_helper.modify_idea(background, idea_filtered)
192
  median = {
@@ -209,37 +209,21 @@ class IdeaGenerator:
209
  elif mode == "new_idea":
210
  mode_name = "Generate new idea"
211
  if bs_mode == "mode_a":
212
- if use_cue_words:
213
- logger.info(
214
- "{} using brainstorm_mode_a with cue words.".format(mode_name)
215
- )
216
- (message_input, problem, inspirations, idea, idea_filtered) = (
217
- self.generate_with_cue_words_ins(background)
218
- )
219
- else:
220
- logger.info(
221
- "{} using brainstorm_mode_a without cue words.".format(mode_name)
222
- )
223
- (message_input, problem, inspirations, idea, idea_filtered) = (
224
- self.generate_without_cue_words_ins(background)
225
- )
226
  elif bs_mode == "mode_b" or bs_mode == "mode_c":
227
- if use_cue_words:
228
- logger.info(
229
- "{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
230
- )
231
- (message_input, problem, inspirations, idea, idea_filtered) = (
232
- self.generate_with_cue_words_ins_bs(background)
233
- )
234
- else:
235
- logger.info(
236
- "{} using brainstorm_{} without cue words.".format(
237
- mode_name, bs_mode
238
- )
239
- )
240
- (message_input, problem, inspirations, idea, idea_filtered) = (
241
- self.generate_without_cue_words_ins_bs(background)
242
  )
 
 
 
 
243
 
244
  idea_modified = self.api_helper.modify_idea(background, idea_filtered)
245
  median = {
@@ -271,209 +255,22 @@ def main(ctx):
271
  )
272
  @click.option(
273
  "--ids-path",
274
- default="./assets/data/test_acl_2024.json",
275
  type=click.File(),
276
  required=True,
277
  help="Dataset configuration file in YAML",
278
  )
279
  @click.option(
280
- "-r",
281
- "--retriever-name",
282
- default="SNKG",
283
- type=str,
284
- required=True,
285
- help="Retrieve method",
286
- )
287
- @click.option(
288
- "--brainstorm-mode",
289
- default="mode_c",
290
  type=str,
291
  required=True,
292
- help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
293
- )
294
- @click.option(
295
- "--use-cue-words",
296
- default=False,
297
- type=bool,
298
- required=True,
299
- help="Use cue words in generation",
300
- )
301
- @click.option(
302
- "--use-inspiration",
303
- default=False,
304
- type=bool,
305
- required=True,
306
- help="Use inspiration in generation",
307
- )
308
- @click.option(
309
- "--num",
310
- default=100,
311
- type=int,
312
- required=False,
313
- help="The number of papers you want to process",
314
- )
315
- def backtracking(
316
- config_path,
317
- ids_path,
318
- retriever_name,
319
- brainstorm_mode,
320
- use_cue_words,
321
- use_inspiration,
322
- num,
323
- **kwargs,
324
- ):
325
- check_env()
326
- check_embedding()
327
- # Configuration
328
- config = ConfigReader.load(config_path, **kwargs)
329
- logger.add(
330
- "log/generate_{}_{}.log".format(time.time(), retriever_name),
331
- level=config.DEFAULT.log_level,
332
- )
333
- logger.info("\nretrieve name : {}".format(retriever_name))
334
- logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
335
- api_helper = APIHelper(config)
336
- paper_client = PaperClient()
337
- eval_data = []
338
- processed_ids = set()
339
- cur_num = 0
340
- batch_size = 2
341
- output_dir = "./assets/output_idea/"
342
- os.makedirs(output_dir, exist_ok=True)
343
- output_file = os.path.join(
344
- output_dir,
345
- f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json",
346
- )
347
- if os.path.exists(output_file):
348
- with open(output_file, "r", encoding="utf-8") as f:
349
- try:
350
- eval_data = json.load(f)
351
- processed_ids = {paper["hash_id"] for paper in eval_data}
352
- cur_num = len(eval_data)
353
- except json.JSONDecodeError:
354
- print("Failed to decode JSON, initializing eval_data as an empty list.")
355
- print(f"{cur_num} papers have been processed.")
356
- for line in ids_path:
357
- # Parse the JSON data on each line
358
- paper = json.loads(line)
359
- if paper["hash_id"] in processed_ids:
360
- print(f"Skipping already processed paper: {paper_id}")
361
- continue
362
- logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
363
- # if "entities" in paper.keys():
364
- # entities = paper["entities"]
365
- # else:
366
- # 1. Get the background information
367
- paper = paper_client.get_paper_by_id(paper["hash_id"])
368
- if "motivation" in paper.keys():
369
- bg = paper["motivation"]
370
- else:
371
- print(f"Paper hash_id {paper['hash_id']} doesn't have background...")
372
- continue
373
- if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
374
- brainstorm = api_helper.generate_brainstorm(bg)
375
- else:
376
- brainstorm = None
377
- if "entities" in paper.keys():
378
- entities = paper["entities"]
379
- else:
380
- entities = api_helper.generate_entity_list(bg)
381
- logger.debug("Original entities from background: {}".format(entities))
382
- if brainstorm_mode == "mode_c":
383
- entities_bs = api_helper.generate_entity_list(brainstorm, 10)
384
- logger.debug("Original entities from brainstorm: {}".format(entities_bs))
385
- entities_all = list(set(entities) | set(entities_bs))
386
- else:
387
- entities_bs = None
388
- entities_all = entities
389
- # 2. Get the actually cited papers (for evaluation)
390
- cite_type = "cite_id_list"
391
- # cite_type = config.RETRIEVE.cite_type
392
- if cite_type in paper and len(paper[cite_type]) >= 5:
393
- target_paper_id_list = paper[cite_type]
394
- else:
395
- logger.warning(
396
- "Hash ID {} cited paper num less than 5...".format(paper["hash_id"])
397
- )
398
- continue
399
- # 3. Retrieve related papers
400
- rt = RetrieverFactory.get_retriever_factory().create_retriever(
401
- retriever_name, config
402
- )
403
- result = rt.retrieve(
404
- bg, entities_all, need_evaluate=False, target_paper_id_list=[]
405
- )
406
- related_paper = result["related_paper"]
407
- logger.info("Find {} related papers...".format(len(related_paper)))
408
- entities_rt = result["entities"]
409
- # 4. Generate ideas
410
- if use_cue_words:
411
- if "contribution" in paper.keys():
412
- cue_words = api_helper.generate_entity_list(paper["contribution"])
413
- else:
414
- print(f"Paper hash_id {paper['hash_id']} doesn't have contribution...")
415
- cue_words = None
416
- else:
417
- cue_words = None
418
- idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
419
- if not use_inspiration:
420
- message_input, idea_modified, median = idea_generator.generate(
421
- bg, "backtracking", brainstorm_mode, use_cue_words
422
- )
423
- else:
424
- message_input, idea_modified, median = (
425
- idea_generator.generate_by_inspiration(
426
- bg, "backtracking", brainstorm_mode, use_cue_words
427
- )
428
- )
429
- eval_data.append(
430
- {
431
- "hash_id": paper["hash_id"],
432
- "background": bg,
433
- "entities_bg": entities,
434
- "brainstorm": brainstorm,
435
- "entities_bs": entities_bs,
436
- "entities_rt": entities_rt,
437
- "related_paper": [p["hash_id"] for p in related_paper],
438
- "input": message_input,
439
- "cue_words": cue_words,
440
- "median": median,
441
- "pred": idea_modified,
442
- "ground_truth": paper["ground_truth"],
443
- }
444
- )
445
- cur_num += 1
446
- if cur_num % batch_size == 0:
447
- with open(
448
- output_file,
449
- "w",
450
- encoding="utf-8",
451
- ) as f:
452
- json.dump(eval_data, f, ensure_ascii=False, indent=4)
453
- if cur_num >= num:
454
- break
455
- logger.info("=== Finish ===")
456
- with open(
457
- output_file,
458
- "w",
459
- encoding="utf-8",
460
- ) as f:
461
- json.dump(eval_data, f, ensure_ascii=False, indent=4)
462
-
463
-
464
- @main.command()
465
- @click.option(
466
- "-c",
467
- "--config-path",
468
- default="./configs/datasets.yaml",
469
- type=click.File(),
470
- required=True,
471
  help="Dataset configuration file in YAML",
472
  )
473
  @click.option(
474
- "--ids-path",
475
- default="./assets/data/test_background.json",
476
- type=click.File(),
477
  required=True,
478
  help="Dataset configuration file in YAML",
479
  )
@@ -499,6 +296,12 @@ def backtracking(
499
  required=True,
500
  help="Use inspiration in generation",
501
  )
 
 
 
 
 
 
502
  @click.option(
503
  "--num",
504
  default=100,
@@ -509,9 +312,12 @@ def backtracking(
509
  def new_idea(
510
  config_path,
511
  ids_path,
 
 
512
  retriever_name,
513
  brainstorm_mode,
514
  use_inspiration,
 
515
  num,
516
  **kwargs,
517
  ):
@@ -523,16 +329,16 @@ def new_idea(
523
  # Configuration
524
  config = ConfigReader.load(config_path, **kwargs)
525
  api_helper = APIHelper(config)
 
526
  check_embedding(config.DEFAULT.embedding)
527
  eval_data = []
528
  cur_num = 0
529
  data_num = 0
530
- batch_size = 2
531
  bg_ids = set()
532
- output_dir = "./assets/output_idea/"
533
- os.makedirs(output_dir, exist_ok=True)
534
  output_file = os.path.join(
535
- output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json"
536
  )
537
  if os.path.exists(output_file):
538
  with open(output_file, "r", encoding="utf-8") as f:
@@ -543,10 +349,12 @@ def new_idea(
543
  except json.JSONDecodeError:
544
  eval_data = []
545
  logger.debug(f"{cur_num} datas have been processed.")
546
- for line in ids_path:
 
547
  # Parse the JSON data on each line
548
- data = json.loads(line)
549
- # 1. Get the background information
 
550
  if "background" in data.keys():
551
  bg = data["background"]
552
  else:
@@ -557,17 +365,28 @@ def new_idea(
557
  data_num += 1
558
  print(f"Skipping already processed data_{data_num}.")
559
  continue
 
 
 
 
 
 
 
 
 
560
  if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
561
- brainstorm = api_helper.generate_brainstorm(bg)
 
 
 
 
 
 
 
562
  else:
563
  brainstorm = None
564
- if "cue_words" in data.keys():
565
- use_cue_words = True
566
- cue_words = data["cue_words"]
567
- else:
568
- use_cue_words = False
569
- cue_words = None
570
- entities = api_helper.generate_entity_list(bg)
571
  logger.debug("Original entities from background: {}".format(entities))
572
  if brainstorm_mode == "mode_c":
573
  entities_bs = api_helper.generate_entity_list(brainstorm, 10)
@@ -576,49 +395,54 @@ def new_idea(
576
  else:
577
  entities_bs = None
578
  entities_all = entities
579
- # 2. Retrieve related papers
 
580
  rt = RetrieverFactory.get_retriever_factory().create_retriever(
581
  retriever_name, config
582
  )
583
  result = rt.retrieve(
584
- bg, entities_all, need_evaluate=False, target_paper_id_list=[]
585
  )
586
  related_paper = result["related_paper"]
587
  logger.info("Find {} related papers...".format(len(related_paper)))
588
  entities_rt = result["entities"]
589
- # 3. Generate ideas
590
- idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
591
- if not use_inspiration:
592
- message_input, idea_modified, median = idea_generator.generate(
593
- bg, "new_idea", brainstorm_mode, use_cue_words
594
- )
595
- else:
596
- message_input, idea_modified, median = (
597
- idea_generator.generate_by_inspiration(
598
- bg, "new_idea", brainstorm_mode, use_cue_words
599
- )
600
- )
 
 
 
 
601
  eval_data.append(
602
  {
603
  "background": bg,
 
604
  "entities_bg": entities,
605
  "brainstorm": brainstorm,
 
606
  "entities_bs": entities_bs,
607
  "entities_rt": entities_rt,
608
- "related_paper": [p["hash_id"] for p in related_paper],
609
- "input": message_input,
610
- "cue_words": cue_words,
611
- "median": median,
612
- "pred": idea_modified,
 
 
613
  }
614
  )
615
  cur_num += 1
616
  if cur_num % batch_size == 0:
617
- with open(
618
- output_file,
619
- "w",
620
- encoding="utf-8",
621
- ) as f:
622
  json.dump(eval_data, f, ensure_ascii=False, indent=4)
623
  if cur_num >= num:
624
  break
 
1
+ import functools
2
  from utils.paper_retriever import RetrieverFactory
3
  from utils.paper_client import PaperClient
4
  from utils.llms_api import APIHelper
 
11
  import time
12
  import os
13
  from utils.hash import check_env, check_embedding
14
+ import threading
15
 
16
  warnings.filterwarnings("ignore")
17
 
 
28
  return research_problem
29
 
30
 
+def extract_ideas(idea_str):
+    if idea_str is None:
+        return ""
+    ideas = []
+    for i in range(1, 100):  # 100 is a magic number
+        start_word = f"**Idea {i}"
+        end_word = f"**Idea {i+1}"
+        start_index = idea_str.find(start_word)
+        end_index = idea_str.find(end_word)
+        if start_index != -1 and end_index != -1:
+            ideas.append(idea_str[start_index:end_index].strip())
+            # idea_str = idea_str[start_index+end_index+1:]
+        elif start_index != -1:
+            ideas.append(idea_str[start_index:].strip())
+            break
+        else:
+            break
+    return ideas if ideas else [idea_str]
+
+
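Review note on `extract_ideas`: the function slices the LLM output at successive `**Idea N` markers. The sketch below shows the same splitting idea with a regular expression instead of the numbered `find` loop. It is an illustrative alternative, not the code in this commit; `split_ideas` is a hypothetical name, and unlike the committed function it does not require the ideas to be numbered consecutively from 1.

```python
import re


def split_ideas(text: str) -> list[str]:
    """Split text into chunks that each start at a '**Idea <n>' marker."""
    starts = [m.start() for m in re.finditer(r"\*\*Idea \d+", text)]
    if not starts:
        # No markers found: return the whole text as a single idea.
        return [text] if text else []
    bounds = starts + [len(text)]
    # Each chunk runs from one marker to the next (or to the end of the text).
    return [text[a:b].strip() for a, b in zip(bounds, bounds[1:])]


sample = "Intro\n**Idea 1:** Use graphs.\n**Idea 2:** Use retrieval.\n"
ideas = split_ideas(sample)
print(ideas)
```

Any text before the first marker (here `Intro`) is discarded, which matches the committed function's behavior of starting at the first `**Idea 1` occurrence.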
51
  class IdeaGenerator:
52
  def __init__(
53
  self,
54
  config,
55
  paper_list: list[dict] = [],
 
56
  brainstorm: str = None,
57
  ) -> None:
58
  self.api_helper = APIHelper(config)
59
  self.paper_list = paper_list
 
60
  self.brainstorm = brainstorm
61
 
 
 
 
 
 
 
 
 
 
 
62
  def generate_without_cue_words(self, background: str):
63
+ """Generate ideas without cue words and brainstorm
64
+ """
65
  problem, message_input = self.api_helper.generate_problem(
66
  background, self.paper_list
67
  )
 
69
  idea_filtered = self.api_helper.filter_idea(idea, background)
70
  return message_input, problem, idea, idea_filtered
71
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def generate_without_cue_words_bs(self, background: str):
73
+ """Generate ideas without cue words, but brainstorm
74
+ """
75
  problem, message_input = self.api_helper.generate_problem(
76
  background, self.paper_list
77
  )
 
81
  )
82
  return message_input, problem, idea, idea_filtered
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def generate_without_cue_words_ins(self, background: str):
85
+ """Generate ideas without cue words and brainstorm, but inspiration
86
+ """
87
  problem, message_input = self.api_helper.generate_problem(
88
  background, self.paper_list
89
  )
 
96
  idea_filtered = self.api_helper.filter_idea(idea, background)
97
  return message_input, problem, inspirations, idea, idea_filtered
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def generate_without_cue_words_ins_bs(self, background: str):
100
+ """Generate ideas without cue words, but inspiration and brainstorm
101
+ """
102
  problem, message_input = self.api_helper.generate_problem(
103
  background, self.paper_list
104
  )
 
113
  )
114
  return message_input, problem, inspirations, idea, idea_filtered
115
 
+    def generate_ins_bs(self, detail_background: str):
+        """Generate ideas with inspiration and brainstorm
+        """
+        inspirations = []
+
+        ## generate inspirations
+        processes = []
+        def generate_inspiration(paper, i):
+            detail_method = self.api_helper.generate_concise_method(paper["methodology"])
+            inspiration = self.api_helper.generate_inspiration_with_detail_method(detail_background, detail_method)
+            logger.info(f"Generate inspiration for related paper {i} succeed")
+            if not(inspiration.startswith("None") or (len(inspiration) < 100 and "None" in inspiration)):
+                inspirations.append(inspiration)
+
+        for i, paper in enumerate(self.paper_list):
+            p = threading.Thread(target=generate_inspiration, args=(paper, i))
+            processes.append(p)
+            p.start()
+        for p in processes:
+            p.join(120)
+
+        ## generate ideas through all inspirations
+        logger.info("Generate inspirations for all related papers succeed")
+        idea = self.api_helper.generate_idea_by_inspiration(detail_background, inspirations)
+        initial_ideas = extract_ideas(idea)
+        logger.info("Generate ideas from inspirations succeed")
+        idea_filtered = self.api_helper.integrate_idea(detail_background, self.brainstorm, idea)
+        logger.info("Idea integration succeed")
+
+        ## expand ideas
+        ideas_filtered = extract_ideas(idea_filtered)
+        final_ideas = ["None"] * len(ideas_filtered)
+        def expand_idea(detail_background: str, idea: str, i):
+            final_ideas[i] = self.api_helper.expand_idea(detail_background, idea)
+            logger.info(f"Expand the {i}th idea succeed")
+        processes = []
+        for i, idea in enumerate(ideas_filtered):
+            p = threading.Thread(target=expand_idea, args=(detail_background, idea, i))
+            processes.append(p)
+            p.start()
+        for p in processes:
+            p.join(120)
+
+        ## return
+        return None, None, inspirations, initial_ideas, ideas_filtered, final_ideas
+
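Review note on `generate_ins_bs`: the method starts one thread per related paper and appends results to a shared list, so the order of `inspirations` depends on which thread finishes first. The sketch below shows one way to keep results in input order with `concurrent.futures`. It is an assumption-laden alternative rather than the committed approach; `fan_out` is a hypothetical helper and the worker is a stand-in for the real API calls.

```python
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout


def fan_out(worker, items, timeout=120.0):
    """Run worker(item) for every item in parallel; results keep input order."""
    results = [None] * len(items)
    if not items:
        return results
    with ThreadPoolExecutor(max_workers=len(items)) as pool:
        futures = [pool.submit(worker, item) for item in items]
        for i, fut in enumerate(futures):
            try:
                # Each result lands in the slot of the item that produced it.
                results[i] = fut.result(timeout=timeout)
            except FutureTimeout:
                # Slow workers leave None in their slot; note that leaving
                # the with-block still waits for running workers to finish.
                results[i] = None
    return results


# Stand-in worker: the real code would call the LLM API here.
out = fan_out(lambda x: x * 2, [1, 2, 3])
print(out)
```

Exceptions raised inside a worker propagate from `fut.result()`, so a caller that wants the committed behavior of skipping failed papers would need to catch them and filter out `None` entries.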
162
  def generate(
163
  self,
164
  background: str,
 
172
  elif mode == "new_idea":
173
  mode_name = "Generate new idea"
174
  if bs_mode == "mode_a":
175
+ logger.info(
176
+ "{} using brainstorm_mode_a without cue words.".format(mode_name)
177
+ )
178
+ (message_input, problem, idea, idea_filtered) = (
179
+ self.generate_without_cue_words(background)
180
+ )
 
 
 
 
 
 
 
 
181
  elif bs_mode == "mode_b" or bs_mode == "mode_c":
182
+ logger.info(
183
+ "{} using brainstorm_{} without cue words.".format(
184
+ mode_name, bs_mode
185
  )
186
+ )
187
+ (message_input, problem, idea, idea_filtered) = (
188
+ self.generate_without_cue_words_bs(background)
189
+ )
190
 
191
  idea_modified = self.api_helper.modify_idea(background, idea_filtered)
192
  median = {
 
209
  elif mode == "new_idea":
210
  mode_name = "Generate new idea"
211
  if bs_mode == "mode_a":
212
+ logger.info(
213
+ "{} using brainstorm_mode_a without cue words.".format(mode_name)
214
+ )
215
+ (message_input, problem, inspirations, idea, idea_filtered) = (
216
+ self.generate_without_cue_words_ins(background)
217
+ )
 
 
 
 
 
 
 
 
218
  elif bs_mode == "mode_b" or bs_mode == "mode_c":
219
+ logger.info(
220
+ "{} using brainstorm_{} without cue words.".format(
221
+ mode_name, bs_mode
222
  )
223
+ )
224
+ (message_input, problem, inspirations, idea, idea_filtered) = (
225
+ self.generate_without_cue_words_ins_bs(background)
226
+ )
227
 
228
  idea_modified = self.api_helper.modify_idea(background, idea_filtered)
229
  median = {
 
255
  )
256
  @click.option(
257
  "--ids-path",
258
+ default="./assets/data/test_background.json",
259
  type=click.File(),
260
  required=True,
261
  help="Dataset configuration file in YAML",
262
  )
263
  @click.option(
264
+ "--out-path",
265
+ default="./assets/output_idea/",
 
 
 
 
 
 
 
 
266
  type=str,
267
  required=True,
268
  help="Directory where the generated ideas are written",
269
  )
270
  @click.option(
271
+ "--out-file",
272
+ default="out-file.json",
273
+ type=str,
274
  required=True,
275
  help="Name of the output JSON file",
276
  )
 
296
  required=True,
297
  help="Use inspiration in generation",
298
  )
299
+ @click.option(
300
+ "--expand-intermediate",
301
+ default=False,
302
+ type=bool,
303
+ help="Whether to also expand the intermediate brainstorms and initial ideas",
304
+ )
305
  @click.option(
306
  "--num",
307
  default=100,
 
312
  def new_idea(
313
  config_path,
314
  ids_path,
315
+ out_path,
316
+ out_file,
317
  retriever_name,
318
  brainstorm_mode,
319
  use_inspiration,
320
+ expand_intermediate,
321
  num,
322
  **kwargs,
323
  ):
 
329
  # Configuration
330
  config = ConfigReader.load(config_path, **kwargs)
331
  api_helper = APIHelper(config)
332
+ paper_client = PaperClient()
333
  check_embedding(config.DEFAULT.embedding)
334
  eval_data = []
335
  cur_num = 0
336
  data_num = 0
337
+ batch_size = 1
338
  bg_ids = set()
339
+ os.makedirs(out_path, exist_ok=True)
 
340
  output_file = os.path.join(
341
+ out_path, out_file
342
  )
343
  if os.path.exists(output_file):
344
  with open(output_file, "r", encoding="utf-8") as f:
 
349
  except json.JSONDecodeError:
350
  eval_data = []
351
  logger.debug(f"{cur_num} datas have been processed.")
352
+ all_input = json.load(ids_path)
353
+ for line in all_input:
354
  # Parse the JSON data of each entry
355
+ # data = json.loads(line)
356
+ data = line
357
+ ### 1. Get the background information
358
  if "background" in data.keys():
359
  bg = data["background"]
360
  else:
 
365
  data_num += 1
366
  print(f"Skipping already processed data_{data_num}.")
367
  continue
368
+
369
+ ## extract entities from background
370
+ entities = api_helper.generate_entity_list(bg)
371
+
372
+ ## expand background to a detailed version
373
+ keywords_str = functools.reduce(lambda x, y: f"{x}, {y}", entities)
374
+ expanded_background = api_helper.expand_background(bg, keywords_str)
375
+
376
+ ## brainstorm according to the background
377
  if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
378
+ brainstorm = api_helper.generate_brainstorm(expanded_background)
379
+ seperate_brainstorm = extract_ideas(brainstorm)
380
+ ## expand the brainstorms to a detailed version
381
+ expanded_brainstorms = []
382
+ if expand_intermediate:
383
+ for i, sb in enumerate(seperate_brainstorm):
384
+ expanded_brainstorms.append(api_helper.expand_idea(expanded_background, sb))
385
+ logger.info(f"Expanded brainstorm {i} successfully")
386
  else:
387
  brainstorm = None
388
+
389
+ ## Extract entities from the brainstorm result
 
 
 
 
 
390
  logger.debug("Original entities from background: {}".format(entities))
391
  if brainstorm_mode == "mode_c":
392
  entities_bs = api_helper.generate_entity_list(brainstorm, 10)
 
395
  else:
396
  entities_bs = None
397
  entities_all = entities
398
+
399
+ ### 2. Retrieve related papers
400
  rt = RetrieverFactory.get_retriever_factory().create_retriever(
401
  retriever_name, config
402
  )
403
  result = rt.retrieve(
404
+ expanded_background, entities_all, need_evaluate=False, target_paper_id_list=[]
405
  )
406
  related_paper = result["related_paper"]
407
  logger.info("Find {} related papers...".format(len(related_paper)))
408
  entities_rt = result["entities"]
409
+ for paper in related_paper:
410
+ if not ("detail_method" in paper):
411
+ paper["detail_method"] = api_helper.generate_concise_method(paper["methodology"])
412
+ if isinstance(paper["detail_method"], str):
413
+ paper_client.insert_new_field(paper["hash_id"], "detail_method", paper["detail_method"])
414
+ logger.info(f"Added new field detail_method to paper: {paper['hash_id']}")
415
+ logger.info("Generated detail methods for all related papers")
416
+
417
+ ### 3. Generate ideas
418
+ idea_generator = IdeaGenerator(config, related_paper, brainstorm)
419
+ _, _, inspirations, initial_ideas, idea_filtered, final_ideas = idea_generator.generate_ins_bs(expanded_background)
420
+ expanded_initial_ideas = []
421
+ if expand_intermediate:
422
+ for i, initial_idea in enumerate(initial_ideas):
423
+ expanded_initial_ideas.append(api_helper.expand_idea(expanded_background, initial_idea))
424
+ logger.info(f"Expanded initial idea {i} successfully")
425
  eval_data.append(
426
  {
427
  "background": bg,
428
+ "expanded_background": expanded_background,
429
  "entities_bg": entities,
430
  "brainstorm": brainstorm,
431
+ "seperate_brainstorm": seperate_brainstorm,
432
  "entities_bs": entities_bs,
433
  "entities_rt": entities_rt,
434
+ "related_paper": [p["title"] for p in related_paper],
435
+ "inspirations": inspirations,
436
+ "initial_ideas": initial_ideas,
437
+ "filtered_ideas": idea_filtered,
438
+ "expanded_final_ideas": final_ideas,
439
+ "expanded_brainstorms": expanded_brainstorms,
440
+ "expanded_initial_ideas": expanded_initial_ideas,
441
  }
442
  )
443
  cur_num += 1
444
  if cur_num % batch_size == 0:
445
+ with open(output_file, "w", encoding="utf-8") as f:
 
 
 
 
446
  json.dump(eval_data, f, ensure_ascii=False, indent=4)
447
  if cur_num >= num:
448
  break
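`keywords_str` in `new_idea` is built with `functools.reduce`, which raises `TypeError` when it is given an empty list and no initial value. The sketch below compares it with `str.join`, which returns an empty string in that case. Both helper names are hypothetical, and whether `generate_entity_list` can actually return an empty list is not visible in this diff.

```python
import functools


def join_keywords_reduce(entities):
    """Same expression as in the commit: fails on an empty list."""
    return functools.reduce(lambda x, y: f"{x}, {y}", entities)


def join_keywords(entities):
    """Equivalent for non-empty lists of strings, and safe for empty ones."""
    return ", ".join(str(e) for e in entities)
```

For non-empty lists of strings the two produce the same text, so swapping one for the other would only change the empty-list behaviour.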
src/paper_manager.py CHANGED
@@ -22,6 +22,12 @@ unicode_pattern = r"\u00c0-\u00ff\u0100-\u017f\u0180-\u024f\u4e00-\u9fff\u3040-\
22
 
23
 
24
  def find_methodology(article_dict):
 
 
 
 
 
 
25
  def find_section_index(keywords):
26
  for i, section in enumerate(article_dict["sections"], 1):
27
  heading = section["heading"].lower()
@@ -70,10 +76,14 @@ def find_methodology(article_dict):
70
 
71
 
72
  def count_sb_pairs(text):
 
 
73
  return len(re.findall(r"\[.*?\]", text))
74
 
75
 
76
  def count_rb_pairs(text):
 
 
77
  return len(re.findall(r"\(.*?\)", text))
78
 
79
 
@@ -85,17 +95,18 @@ def find_cite_paper(introduction, methodology, references):
85
  text = introduction + methodology
86
  rb_count = count_rb_pairs(introduction)
87
  sb_count = count_sb_pairs(introduction)
88
- pattern = (
89
- r"\b[A-Z"
90
- + unicode_pattern
91
- + r"][a-zA-Z"
92
- + unicode_pattern
93
- + r"]+(?: and [A-Z"
94
- + unicode_pattern
95
- + r"][a-zA-Z"
96
- + unicode_pattern
97
- + r"]+)?(?: et al\.)?, \d{4}[a-z]?\b"
98
- )
 
99
  pattern = (
100
  r"\b[A-Z"
101
  + unicode_pattern
@@ -199,11 +210,22 @@ class PaperManager:
199
  self.paper_client.create_vector_index()
200
 
201
  def clean_entity(self, entity):
 
 
 
 
 
 
202
  if entity is None:
203
  return None
 
204
  cleaned_entity = re.sub(r"\([^)]*\)", "", entity)
 
205
  cleaned_entity = re.sub(r"[^\w\s]", "", cleaned_entity)
 
206
  cleaned_entity = re.sub(r"_", " ", cleaned_entity)
 
 
207
  cleaned_entity = re.sub(r"\s+", " ", cleaned_entity).strip()
208
  return cleaned_entity
209
 
@@ -347,6 +369,8 @@ class PaperManager:
347
  need_get_entities=False,
348
  need_ground_truth=False,
349
  ):
 
 
350
  if paper["pdf_url"] in self.ignore_paper_pdf_url:
351
  logger.warning(
352
  "hash_id: {}, pdf_url: {} ignore".format(
@@ -511,6 +535,8 @@ class PaperManager:
511
  need_get_entities=False,
512
  need_ground_truth=False,
513
  ):
 
 
514
  result = []
515
  if self.year != "all":
516
  logger.info(
@@ -632,6 +658,45 @@ class PaperManager:
632
  # )
633
  # self.paper_client.add_paper_summary_embedding(self.embedding_model, hash_id)
634
 
635
  def cosine_similarity_search(self, data_type, context, k=1):
636
  """
637
  return related paper: list
@@ -641,6 +706,12 @@ class PaperManager:
641
  return result
642
 
643
  def generate_paper_list(self):
 
 
 
 
 
 
644
  folder_path = f"./assets/paper/{self.venue_name}"
645
  if not os.path.exists(folder_path):
646
  os.makedirs(folder_path)
@@ -721,6 +792,14 @@ def main(ctx):
721
  help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
722
  )
723
  def crawling(config_path, year, venue_name, **kwargs):
 
 
 
 
 
 
 
 
724
  # Configuration
725
  config = ConfigReader.load(config_path, **kwargs)
726
  pm = PaperManager(config, venue_name, year)
@@ -772,6 +851,9 @@ def crawling(config_path, year, venue_name, **kwargs):
772
  help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
773
  )
774
  def update(config_path, year, venue_name, **kwargs):
 
 
 
775
  # Configuration
776
  config = ConfigReader.load(config_path, **kwargs)
777
  pm = PaperManager(config, venue_name, year)
@@ -831,6 +913,8 @@ def update(config_path, year, venue_name, **kwargs):
831
  help="Dataset configuration file in YAML",
832
  )
833
  def local(config_path, year, venue_name, output, **kwargs):
 
 
834
  # Configuration
835
  output_path = output.name
836
  if not os.path.exists(os.path.dirname(output_path)):
@@ -842,7 +926,6 @@ def local(config_path, year, venue_name, output, **kwargs):
842
  need_download=True, need_parse=True, need_summary=True
843
  )
844
 
845
-
846
  @main.command()
847
  @click.option(
848
  "-c",
@@ -853,10 +936,93 @@ def local(config_path, year, venue_name, output, **kwargs):
853
  help="Dataset configuration file in YAML",
854
  )
855
  def embedding(config_path):
 
 
856
  # Configuration
857
  config = ConfigReader.load(config_path)
858
  PaperManager(config).insert_embedding()
859
 
860
 
861
  if __name__ == "__main__":
862
  main()
 
22
 
23
 
24
  def find_methodology(article_dict):
25
+ """For an article dict (representing an article), return the methodology part
26
+ Args:
27
+ article_dict
28
+ Returns:
29
+ methodology: str
30
+ """
31
  def find_section_index(keywords):
32
  for i, section in enumerate(article_dict["sections"], 1):
33
  heading = section["heading"].lower()
 
76
 
77
 
78
  def count_sb_pairs(text):
79
+ """Find the number of square brackets (possible citations)
80
+ """
81
  return len(re.findall(r"\[.*?\]", text))
82
 
83
 
84
  def count_rb_pairs(text):
85
+ """Find the number of round brackets (possible citations)
86
+ """
87
  return len(re.findall(r"\(.*?\)", text))
88
 
89
 
 
95
  text = introduction + methodology
96
  rb_count = count_rb_pairs(introduction)
97
  sb_count = count_sb_pairs(introduction)
98
+ ## Seems redundant, remove repeated definition of pattern
99
+ # pattern = (
100
+ # r"\b[A-Z"
101
+ # + unicode_pattern
102
+ # + r"][a-zA-Z"
103
+ # + unicode_pattern
104
+ # + r"]+(?: and [A-Z"
105
+ # + unicode_pattern
106
+ # + r"][a-zA-Z"
107
+ # + unicode_pattern
108
+ # + r"]+)?(?: et al\.)?, \d{4}[a-z]?\b"
109
+ # )
110
  pattern = (
111
  r"\b[A-Z"
112
  + unicode_pattern
 
210
  self.paper_client.create_vector_index()
211
 
212
  def clean_entity(self, entity):
213
+ """The extracted entities may be noisy, remove all noisy characters
214
+ Args:
215
+ entity (str): an entity
216
+ Returns:
217
+ cleaned_entity (str): entity after cleaning
218
+ """
219
  if entity is None:
220
  return None
221
+ # remove all () and their contents
222
  cleaned_entity = re.sub(r"\([^)]*\)", "", entity)
223
+ # remove non-word characters (e.g., punctuation)
224
  cleaned_entity = re.sub(r"[^\w\s]", "", cleaned_entity)
225
+ # replace _ with a space
226
  cleaned_entity = re.sub(r"_", " ", cleaned_entity)
227
+ # remove multiple continuous blanks, remove leading and trailing spaces
228
+ # (\s means blank characters)
229
  cleaned_entity = re.sub(r"\s+", " ", cleaned_entity).strip()
230
  return cleaned_entity
231
 
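The cleaning steps documented in `clean_entity` can be tried on their own. The sketch below copies the four regular expressions from the method into a free function; the sample string is made up for illustration.

```python
import re


def clean_entity(entity):
    """Free-function copy of the four cleaning steps shown in the diff."""
    if entity is None:
        return None
    # remove all () and their contents
    cleaned = re.sub(r"\([^)]*\)", "", entity)
    # remove non-word characters; the underscore counts as a word character and survives
    cleaned = re.sub(r"[^\w\s]", "", cleaned)
    # replace _ with a space
    cleaned = re.sub(r"_", " ", cleaned)
    # collapse runs of whitespace and trim both ends
    return re.sub(r"\s+", " ", cleaned).strip()


print(clean_entity("Graph_Neural  Network (GNN)!"))  # Graph Neural Network
```

Note that hyphens are removed without inserting a space, so a hyphenated term is fused into one word.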
 
369
  need_get_entities=False,
370
  need_ground_truth=False,
371
  ):
372
+ """Parse a paper, dump the result into a json file
373
+ """
374
  if paper["pdf_url"] in self.ignore_paper_pdf_url:
375
  logger.warning(
376
  "hash_id: {}, pdf_url: {} ignore".format(
 
535
  need_get_entities=False,
536
  need_ground_truth=False,
537
  ):
538
+ """Parse a paper and dump into the json file
539
+ """
540
  result = []
541
  if self.year != "all":
542
  logger.info(
 
658
  # )
659
  # self.paper_client.add_paper_summary_embedding(self.embedding_model, hash_id)
660
 
661
+ def add_new_embedding(self, hash_id=None, to="all"):
662
+ """add new embeddings for abstract, background, contribution, and summary
663
+ """
664
+ postfix_set = {
665
+ "sentence-transformers/all-MiniLM-L6-v2": "",
666
+ "BAAI/llm-embedder": "_llm_embedder",
667
+ "jina-embeddings-v3": "_jina_v3"
668
+ }
669
+ postfix = postfix_set[self.config.DEFAULT.embedding]
670
+ if "jina" in postfix:
671
+ if self.config.DEFAULT.embedding_task == "text-matching":
672
+ postfix += "_text_matching"
673
+ elif self.config.DEFAULT.embedding_task == "retrieval.query":
674
+ postfix += "_query"
675
+ elif self.config.DEFAULT.embedding_task == "retrieval.passage":
676
+ postfix += "_passage"
677
+ else:
678
+ assert False
679
+ if to == "all" or to == "abstract":
680
+ self.paper_client.update_paper_embedding(
681
+ self.embedding_model, hash_id,
682
+ name="abstract", postfix=postfix
683
+ )
684
+ if to == "all" or to == "background":
685
+ self.paper_client.update_paper_embedding(
686
+ self.embedding_model, hash_id,
687
+ name="background", postfix=postfix
688
+ )
689
+ if to == "all" or to == "contribution":
690
+ self.paper_client.update_paper_embedding(
691
+ self.embedding_model, hash_id,
692
+ name="contribution", postfix=postfix
693
+ )
694
+ if to == "all" or to == "summary":
695
+ self.paper_client.update_paper_embedding(
696
+ self.embedding_model, hash_id,
697
+ name="summary", postfix=postfix
698
+ )
699
+
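The suffix logic in `add_new_embedding` maps the configured embedding model (and, for Jina, the embedding task) to a postfix. A standalone sketch of that mapping, using the same table as the method, is given below. `embedding_postfix` is a hypothetical helper, and raising `ValueError` instead of `assert False` is a substitution made here so that the failure carries a message.

```python
POSTFIX_BY_MODEL = {
    "sentence-transformers/all-MiniLM-L6-v2": "",
    "BAAI/llm-embedder": "_llm_embedder",
    "jina-embeddings-v3": "_jina_v3",
}

TASK_SUFFIX = {
    "text-matching": "_text_matching",
    "retrieval.query": "_query",
    "retrieval.passage": "_passage",
}


def embedding_postfix(embedding, task=None):
    """Return the postfix for an embedding model name and optional Jina task."""
    postfix = POSTFIX_BY_MODEL[embedding]  # KeyError for unknown models, as in the diff
    if "jina" in postfix:
        if task not in TASK_SUFFIX:
            # assumption of this sketch: explicit error instead of `assert False`
            raise ValueError(f"unsupported embedding task: {task!r}")
        postfix += TASK_SUFFIX[task]
    return postfix
```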
700
  def cosine_similarity_search(self, data_type, context, k=1):
701
  """
702
  return related paper: list
 
706
  return result
707
 
708
  def generate_paper_list(self):
709
+ """Dump paper list into a json file, the json is saved at "folder_path"
710
+ Args:
711
+ None
712
+ Return:
713
+ None
714
+ """
715
  folder_path = f"./assets/paper/{self.venue_name}"
716
  if not os.path.exists(folder_path):
717
  os.makedirs(folder_path)
 
792
  help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
793
  )
794
  def crawling(config_path, year, venue_name, **kwargs):
795
+ """Download the paper list into a JSON file
796
+ Args:
797
+ config_path (str):
798
+ year (int): the paper's publication year
799
+ venue_name (str): CVPR, etc.
800
+ Returns:
801
+ None
802
+ """
803
  # Configuration
804
  config = ConfigReader.load(config_path, **kwargs)
805
  pm = PaperManager(config, venue_name, year)
 
851
  help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
852
  )
853
  def update(config_path, year, venue_name, **kwargs):
854
+ """Read paper lists from assets/paper directory, insert them into database,
855
+ including downloading, parsing, etc., but not embedding
856
+ """
857
  # Configuration
858
  config = ConfigReader.load(config_path, **kwargs)
859
  pm = PaperManager(config, venue_name, year)
 
913
  help="Dataset configuration file in YAML",
914
  )
915
  def local(config_path, year, venue_name, output, **kwargs):
916
+ """Parse papers and dump them into json files
917
+ """
918
  # Configuration
919
  output_path = output.name
920
  if not os.path.exists(os.path.dirname(output_path)):
 
926
  need_download=True, need_parse=True, need_summary=True
927
  )
928
 
 
929
  @main.command()
930
  @click.option(
931
  "-c",
 
936
  help="Dataset configuration file in YAML",
937
  )
938
  def embedding(config_path):
939
+ """Insert embedding for papers in the database
940
+ """
941
  # Configuration
942
  config = ConfigReader.load(config_path)
943
  PaperManager(config).insert_embedding()
944
 
945
+ @main.command()
946
+ @click.option(
947
+ "-c",
948
+ "--config-path",
949
+ default=get_dir("./configs/datasets.yaml"),
950
+ type=click.File(),
951
+ required=True,
952
+ help="Dataset configuration file in YAML",
953
+ )
954
+ def add_new_embedding(config_path):
955
+ """Insert another new embedding for papers in the database
956
+ """
957
+ # Configuration
958
+ config = ConfigReader.load(config_path)
959
+ PaperManager(config).add_new_embedding(to="all")
960
+
961
+ @main.command()
962
+ @click.option(
963
+ "-c",
964
+ "--config-path",
965
+ default=get_dir("./configs/datasets.yaml"),
966
+ type=click.File(),
967
+ required=True,
968
+ help="Dataset configuration file in YAML",
969
+ )
970
+ @click.option(
971
+ "--llms-api",
972
+ default=None,
973
+ type=str,
974
+ required=False,
975
+ help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
976
+ )
977
+ @click.option(
978
+ "--sum-api",
979
+ default=None,
980
+ type=str,
981
+ required=False,
982
+ help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
983
+ )
984
+ @click.option(
985
+ "--gen-api",
986
+ default=None,
987
+ type=str,
988
+ required=False,
989
+ help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
990
+ )
991
+ @click.option(
992
+ "--year",
993
+ default="2013",
994
+ type=str,
995
+ required=True,
996
+ help="Venue year",
997
+ )
998
+ @click.option(
999
+ "--venue-name",
1000
+ default="acl",
1001
+ type=str,
1002
+ required=True,
1003
+ help="Venue name",
1004
+ )
1005
+ @click.option(
1006
+ "-o",
1007
+ "--output",
1008
+ default=get_dir("./output/out.json"),
1009
+ type=click.File("wb"),
1010
+ required=True,
1011
+ help="Dataset configuration file in YAML",
1012
+ )
1013
+ def parse_papers_to_json(config_path, venue_name, year, output, **kwargs):
1014
+ """Read json files and download papers, then parse them and dump into jsons
1015
+ """
1016
+ # Configuration
1017
+ output_path = output.name
1018
+ if not os.path.exists(os.path.dirname(output_path)):
1019
+ os.makedirs(os.path.dirname(output_path))
1020
+ config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
1021
+ pm = PaperManager(config, venue_name=venue_name, year=year)
1022
+ pm.update_paper_from_json_to_json(
1023
+ need_download=True, need_parse=True, need_summary=True
1024
+ )
1025
+
1026
 
1027
  if __name__ == "__main__":
1028
  main()
src/retriever.py CHANGED
@@ -22,7 +22,10 @@ def main(ctx):
22
  print("Mode:", ctx.invoked_subcommand)
23
 
24
 
25
- @main.command()
 
 
 
26
  @click.option(
27
  "-c",
28
  "--config-path",
@@ -38,9 +41,21 @@ def main(ctx):
38
  required=True,
39
  help="Dataset configuration file in YAML",
40
  )
41
- def retrieve(
42
- config_path, ids_path, **kwargs
43
- ):
 
 
 
 
 
 
 
 
 
 
 
 
44
  config = ConfigReader.load(config_path, **kwargs)
45
  check_embedding(config.DEFAULT.embedding)
46
  check_env()
@@ -82,8 +97,8 @@ def retrieve(
82
  logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
83
  # 1. Get Background
84
  paper = paper_client.get_paper_by_id(paper["hash_id"])
85
- if "motivation" in paper.keys():
86
- bg = paper["motivation"]
87
  else:
88
  logger.error(f"paper hash_id {paper['hash_id']} doesn't have background...")
89
  continue
 
22
  print("Mode:", ctx.invoked_subcommand)
23
 
24
 
25
+ @main.command(context_settings=dict(
26
+ ignore_unknown_options=True,
27
+ allow_extra_args=True,
28
+ ))
29
  @click.option(
30
  "-c",
31
  "--config-path",
 
41
  required=True,
42
  help="Dataset configuration file in YAML",
43
  )
44
+ @click.pass_context
45
+ def retrieve(ctx,
46
+ config_path, ids_path
47
+ ):
48
+ initial_kwargs={ctx.args[i][2:]: ctx.args[i+1] for i in range(0, len(ctx.args), 2)}
49
+ kwargs = {"RETRIEVE": {}, "DEFAULT": {}}
50
+ for k, v in initial_kwargs.items():
51
+ if "num" in k:
52
+ kwargs["RETRIEVE"][k] = int(v)
53
+ elif "s_" in k:
54
+ kwargs["RETRIEVE"][k] = float(v)
55
+ elif "use_cocite" in k:
56
+ kwargs["RETRIEVE"][k] = bool(int(v))
57
+ else:
58
+ kwargs["RETRIEVE"][k] = v
59
  config = ConfigReader.load(config_path, **kwargs)
60
  check_embedding(config.DEFAULT.embedding)
61
  check_env()
 
97
  logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
98
  # 1. Get Background
99
  paper = paper_client.get_paper_by_id(paper["hash_id"])
100
+ if "background" in paper.keys():
101
+ bg = paper["background"]
102
  else:
103
  logger.error(f"paper hash_id {paper['hash_id']} doesn't have background...")
104
  continue
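`retrieve` now accepts arbitrary `--key value` pairs through `ctx.args` and casts each value according to a substring of its key. The sketch below isolates that conversion so it can be tested without click. `parse_extra_args` and the option names used with it are hypothetical, and the even-length check is an addition that the commit does not have.

```python
def parse_extra_args(args):
    """Turn ['--key', 'value', ...] into the nested kwargs dict built in `retrieve`."""
    if len(args) % 2 != 0:
        # assumption of this sketch: the commit indexes args[i + 1] without this guard
        raise ValueError("extra options must come as '--key value' pairs")
    raw = {args[i][2:]: args[i + 1] for i in range(0, len(args), 2)}
    kwargs = {"RETRIEVE": {}, "DEFAULT": {}}
    for key, value in raw.items():
        if "num" in key:
            kwargs["RETRIEVE"][key] = int(value)
        elif "s_" in key:
            kwargs["RETRIEVE"][key] = float(value)
        elif "use_cocite" in key:
            kwargs["RETRIEVE"][key] = bool(int(value))
        else:
            kwargs["RETRIEVE"][key] = value
    return kwargs
```

Because the casts are chosen by substring, the order of the checks matters: a key that contains both `num` and `s_` is treated as an integer.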
src/utils/api/base_helper.py CHANGED
@@ -128,7 +128,7 @@ class BaseHelper:
128
  response = r.json()["data"]["output"]
129
  return response # or return something else as needed
130
  else:
131
- print("Service request failed, response status code:", response.status_code)
132
  except RequestException as e:
133
  print("Request error:", e)
134
 
 
128
  response = r.json()["data"]["output"]
129
  return response # or return something else as needed
130
  else:
131
+ print("Service request failed, response status code:", r.status_code)
132
  except RequestException as e:
133
  print("Request error:", e)
134
 
src/utils/hash.py CHANGED
@@ -21,7 +21,7 @@ def check_embedding(repo_id):
21
  if repo_id in [
22
  "sentence-transformers/all-MiniLM-L6-v2",
23
  "BAAI/bge-small-en-v1.5",
24
- "BAAAI/llm_embedder",
25
  ]:
26
  # repo_id = "sentence-transformers/all-MiniLM-L6-v2"
27
  # repo_id = "BAAI/bge-small-en-v1.5"
@@ -31,6 +31,18 @@ def check_embedding(repo_id):
31
  "tokenizer_config.json",
32
  "vocab.txt",
33
  ]
 
 
 
 
 
 
 
 
 
 
 
 
34
  elif repo_id in ["Alibaba-NLP/gte-base-en-v1.5"]:
35
  files_to_download = [
36
  "config.json",
@@ -89,6 +101,8 @@ class EmbeddingModel:
89
  device=device,
90
  trust_remote_code=True,
91
  )
 
 
92
  print(f"==== using device {device} ====")
93
  return cls._instance
94
 
 
21
  if repo_id in [
22
  "sentence-transformers/all-MiniLM-L6-v2",
23
  "BAAI/bge-small-en-v1.5",
24
+ "BAAI/llm-embedder",
25
  ]:
26
  # repo_id = "sentence-transformers/all-MiniLM-L6-v2"
27
  # repo_id = "BAAI/bge-small-en-v1.5"
 
31
  "tokenizer_config.json",
32
  "vocab.txt",
33
  ]
34
+ elif repo_id in [
35
+ "jina-embeddings-v3",
36
+ ]:
37
+ files_to_download = [
38
+ "model.safetensors",
39
+ "modules.json",
40
+ "tokenizer.json",
41
+ "config_sentence_transformers.json",
42
+ "tokenizer_config.json",
43
+ "1_Pooling/config.json",
44
+ "config.json",
45
+ ]
46
  elif repo_id in ["Alibaba-NLP/gte-base-en-v1.5"]:
47
  files_to_download = [
48
  "config.json",
 
101
  device=device,
102
  trust_remote_code=True,
103
  )
104
+ if "jina-embeddings-v3" in config.DEFAULT.embedding:
105
+ cls._instance.embedding_model[0].default_task = config.DEFAULT.embedding_task
106
  print(f"==== using device {device} ====")
107
  return cls._instance
108
 
src/utils/llms_api.py CHANGED
@@ -62,15 +62,15 @@ class APIHelper(object):
62
  pass
63
 
64
  def __call__(self, title: str, abstract: str, introduction: str) -> dict:
65
- if os.environ["MODEL_NAME"] not in [
66
- "glm4",
67
- "glm4-air",
68
- "qwen-max",
69
- "qwen-plus",
70
- "gpt-4o-mini",
71
- "local",
72
- ]:
73
- raise ValueError(f"Check model name...")
74
 
75
  if title is None or abstract is None or introduction is None:
76
  return None
@@ -102,6 +102,25 @@ class APIHelper(object):
102
  return None
103
  return result
104
 
105
  def generate_entity_list(self, abstract: str, max_num: int = 5) -> list:
106
  prompt = get_prompt()
107
 
@@ -163,6 +182,46 @@ class APIHelper(object):
163
  return None
164
 
165
  return brainstorming_ideas
166
 
167
  def generate_problem(self, background: str, related_papers: list[dict]):
168
  prompt = get_prompt()
@@ -238,6 +297,27 @@ class APIHelper(object):
238
  traceback.print_exc()
239
  return None
240
  return inspiration
241
 
242
  def generate_inspiration_with_cue_words(
243
  self, problem: str, related_paper: dict, cue_words: list
@@ -325,10 +405,10 @@ class APIHelper(object):
325
  return None
326
  return idea
327
 
328
- def generate_idea_by_inspiration(self, problem: str, inspirations: list[str]):
329
  prompt = get_prompt()
330
 
331
- if problem is None or inspirations is None:
332
  return None
333
  try:
334
  inspirations_text = "".join(
@@ -340,7 +420,7 @@ class APIHelper(object):
340
 
341
  message = [
342
  prompt[0][0](),
343
- prompt[1][0](problem=problem, inspirations_text=inspirations_text),
344
  ]
345
  response = self.generator.create(
346
  messages=message,
 
62
  pass
63
 
64
  def __call__(self, title: str, abstract: str, introduction: str) -> dict:
65
+ # if os.environ["MODEL_NAME"] not in [
66
+ # "glm4",
67
+ # "glm4-air",
68
+ # "qwen-max",
69
+ # "qwen-plus",
70
+ # "gpt-4o-mini",
71
+ # "local",
72
+ # ]:
73
+ # raise ValueError(f"Check model name...")
74
 
75
  if title is None or abstract is None or introduction is None:
76
  return None
 
102
  return None
103
  return result
104
 
105
+ def generate_concise_method(self, methodology: str):
106
+ prompt = get_prompt()
107
+ if methodology is None:
108
+ return None
109
+ try:
110
+ message = [
111
+ prompt[0][0](),
112
+ prompt[1][0](
113
+ methodology=methodology
114
+ ),
115
+ ]
116
+ detail_method = self.generator.create(
117
+ messages=message,
118
+ )
119
+ except Exception:
120
+ traceback.print_exc()
121
+ return None
122
+ return detail_method
123
+
124
  def generate_entity_list(self, abstract: str, max_num: int = 5) -> list:
125
  prompt = get_prompt()
126
 
 
182
  return None
183
 
184
  return brainstorming_ideas
185
+
186
+ def expand_idea(self, background: str, idea: str) -> str:
187
+ prompt = get_prompt()
188
+
189
+ if background is None:
190
+ print("Input background is empty ...")
191
+ return None
192
+ try:
193
+ # Build the prompt messages for expanding the idea
194
+ message = [prompt[0][0](), prompt[1][0](background=background, brief_idea=idea)]
195
+ # Call the API to generate the detailed idea
196
+ detail_ideas = self.generator.create(
197
+ messages=message,
198
+ )
199
+
200
+ except Exception:
201
+ traceback.print_exc()
202
+ return None
203
+
204
+ return detail_ideas
205
+
206
+ def expand_background(self, brief_background: str, keywords: str) -> str:
207
+ prompt = get_prompt()
208
+
209
+ if brief_background is None:
210
+ print("Input brief background is empty ...")
211
+ return None
212
+ try:
213
+ # Build the prompt messages for expanding the background
214
+ message = [prompt[0][0](), prompt[1][0](brief_background=brief_background, keywords=keywords)]
215
+ # Call the API to generate the detailed background
216
+ detail_background = self.generator.create(
217
+ messages=message,
218
+ )
219
+
220
+ except Exception:
221
+ traceback.print_exc()
222
+ return None
223
+
224
+ return detail_background
225
 
226
  def generate_problem(self, background: str, related_papers: list[dict]):
227
  prompt = get_prompt()
 
297
  traceback.print_exc()
298
  return None
299
  return inspiration
300
+
301
+
302
+ def generate_inspiration_with_detail_method(self, background: str, detail_method: str):
303
+ prompt = get_prompt()
304
+ if background is None or detail_method is None:
305
+ return None
306
+ try:
307
+ message = [
308
+ prompt[0][0](),
309
+ prompt[1][0](
310
+ background=background, detail_method=detail_method
311
+ ),
312
+ ]
313
+ response = self.generator.create(
314
+ messages=message,
315
+ )
316
+ inspiration = response
317
+ except Exception:
318
+ traceback.print_exc()
319
+ return None
320
+ return inspiration
321
 
322
  def generate_inspiration_with_cue_words(
323
  self, problem: str, related_paper: dict, cue_words: list
 
405
  return None
406
  return idea
407
 
408
+ def generate_idea_by_inspiration(self, background: str, inspirations: list[str]):
409
  prompt = get_prompt()
410
 
411
+ if background is None or inspirations is None:
412
  return None
413
  try:
414
  inspirations_text = "".join(
 
420
 
421
  message = [
422
  prompt[0][0](),
423
+ prompt[1][0](background=background, inspirations=inspirations_text),
424
  ]
425
  response = self.generator.create(
426
  messages=message,
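The helpers added to `APIHelper` in this file (`generate_concise_method`, `expand_idea`, `expand_background`, `generate_inspiration_with_detail_method`) share one shape: return `None` on missing input, build a two-message prompt, call `self.generator.create(messages=...)`, and return `None` if anything raises. A minimal sketch of that shape follows. The two generator classes are hypothetical stand-ins for the real LLM client, and the message dictionaries are an assumed format, since the prompt templates are not part of this diff.

```python
import traceback


class EchoGenerator:
    """Hypothetical stand-in for the LLM client; it just echoes the prompt."""

    def create(self, messages):
        return " | ".join(m["content"] for m in messages)


class FailingGenerator:
    """Hypothetical client that always fails, to exercise the error path."""

    def create(self, messages):
        raise RuntimeError("simulated API failure")


def expand_idea(generator, background, idea):
    """Return the expanded idea, or None on missing input or API failure."""
    if background is None:
        return None
    try:
        # assumed message format; the real code builds these from prompt templates
        messages = [
            {"role": "system", "content": "You expand research ideas."},
            {"role": "user", "content": f"Background: {background}\nIdea: {idea}"},
        ]
        return generator.create(messages=messages)
    except Exception:
        traceback.print_exc()
        return None
```

Since failures surface only as `None`, callers have to check the return value before passing it on to the next step.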
src/utils/paper_client.py CHANGED
@@ -1,3 +1,18 @@
1
  import os
2
  import re
3
  import json
@@ -34,7 +49,13 @@ class PaperClient:
34
  return driver
35
 
36
  def update_paper_from_client(self, paper):
37
- paper_id = paper["hash_id"]
 
 
 
 
 
 
38
  if paper_id is None:
39
  return None
40
  query = f"""
@@ -49,6 +70,12 @@ class PaperClient:
49
  paper.update(paper_from_client)
50
 
51
  def update_papers_from_client(self, paper_id_list):
 
 
 
 
 
 
52
  query = """
53
  UNWIND $papers AS paper
54
  MATCH (p:Paper {hash_id: paper.hash_id})
@@ -67,6 +94,13 @@ class PaperClient:
67
  return [r["result"] for r in result]
68
 
69
  def get_paper_attribute(self, paper_id, attribute_name):
 
 
 
 
 
 
 
70
  query = f"""
71
  MATCH (p:Paper {{hash_id: {paper_id}}})
72
  RETURN p.{attribute_name} AS attributeValue
@@ -80,6 +114,13 @@ class PaperClient:
80
  return None
81
 
82
  def get_papers_attribute(self, paper_id_list, attribute_name):
 
 
 
 
 
 
 
83
  query = """
84
  UNWIND $paper_ids AS paper_id
85
  MATCH (p:Paper {hash_id: paper_id})
@@ -95,6 +136,13 @@ class PaperClient:
95
  return paper_attributes
96
 
97
  def get_paper_by_attribute(self, attribute_name, anttribute_value):
 
 
 
 
 
 
 
98
  query = f"""
99
  MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
100
  RETURN p
@@ -107,6 +155,13 @@ class PaperClient:
107
  return None
108
 
109
  def get_paper_from_term(self, entity):
 
 
 
 
 
 
 
110
  if entity is None:
111
  return None
112
  query = """
@@ -126,6 +181,14 @@ class PaperClient:
126
  def find_related_entities_by_entity_list(
127
  self, entity_names, n=1, k=3, relation_name="related"
128
  ):
 
 
 
 
 
 
 
 
129
  related_entities = set()
130
  query = """
131
  UNWIND $batch_entities AS entity_name
@@ -145,6 +208,12 @@ class PaperClient:
145
  return list(related_entities)
146
 
147
  def find_entities_by_paper_list(self, hash_ids: list):
 
 
 
 
 
 
148
  query = """
149
  UNWIND $hash_ids AS hash_id
150
  MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: hash_id})
@@ -161,6 +230,12 @@ class PaperClient:
161
  return entity_list
162
 
163
  def find_paper_by_entity(self, entity_name):
 
 
 
 
 
 
164
  query = """
165
  MATCH (e1:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
166
  RETURN p.hash_id AS hash_id
@@ -181,6 +256,19 @@ class PaperClient:
181
  return []
182
 
183
  def find_sentences_by_entity(self, entity_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  query = """
185
  MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
186
  WHERE p.abstract CONTAINS $entity_name OR
@@ -215,6 +303,13 @@ class PaperClient:
215
  return sentences
216
 
217
  def select_paper(self, venue_name, year):
 
 
 
 
 
 
 
218
  query = """
219
  MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
220
  """
@@ -228,6 +323,12 @@ class PaperClient:
228
  return []
229
 
230
  def add_paper_node(self, paper: dict):
 
 
 
 
 
 
231
  if "summary" not in paper.keys():
232
  paper["summary"] = None
233
  if "abstract" not in paper.keys():
@@ -279,6 +380,12 @@ class PaperClient:
279
  )
280
 
281
  def check_entity_node_count(self, hash_id: int):
 
 
 
 
 
 
282
  query_check_count = """
283
  MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: $hash_id})
284
  RETURN count(e) AS entity_count
@@ -293,6 +400,13 @@ class PaperClient:
293
  return True
294
 
295
  def add_entity_node(self, hash_id: int, entities: list):
 
 
 
 
 
 
 
296
  query = """
297
  MERGE (e:Entity {name: $entity_name})
298
  WITH e
@@ -309,6 +423,14 @@ class PaperClient:
309
  )
310
 
311
  def add_paper_citation(self, paper: dict):
 
 
 
 
 
 
 
 
312
  query = """
313
  MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
314
  """
@@ -323,9 +445,122 @@ class PaperClient:
323
  ).data()
324
  )
325
 
 
326
  def add_paper_abstract_embedding(
327
  self, embedding_model, hash_id=None, batch_size=512
328
  ):
 
 
 
 
 
 
 
329
  if hash_id is not None:
330
  query = """
331
  MATCH (p:Paper {hash_id: $hash_id})
@@ -401,6 +636,13 @@ class PaperClient:
401
  logger.info(f"== Processed batch starting at offset {offset} ==")
402
 
403
  def add_paper_bg_embedding(self, embedding_model, hash_id=None, batch_size=512):
 
 
 
 
 
 
 
404
  if hash_id is not None:
405
  query = """
406
  MATCH (p:Paper {hash_id: $hash_id})
@@ -478,6 +720,13 @@ class PaperClient:
478
  def add_paper_contribution_embedding(
479
  self, embedding_model, hash_id=None, batch_size=512
480
  ):
 
 
 
 
 
 
 
481
  if hash_id is not None:
482
  query = """
483
  MATCH (p:Paper {hash_id: $hash_id})
@@ -555,6 +804,13 @@ class PaperClient:
555
  def add_paper_summary_embedding(
556
  self, embedding_model, hash_id=None, batch_size=512
557
  ):
 
 
 
 
 
 
 
558
  if hash_id is not None:
559
  query = """
560
  MATCH (p:Paper {hash_id: $hash_id})
@@ -567,6 +823,7 @@ class PaperClient:
567
  )
568
  contexts = [result["context"] for result in results]
569
  paper_ids = [result["hash_id"] for result in results]
 
570
  context_embeddings = embedding_model.encode(
571
  contexts, convert_to_tensor=True, device=self.device
572
  )
@@ -630,6 +887,15 @@ class PaperClient:
630
  logger.info(f"== Processed batch starting at offset {offset} ==")
631
 
632
  def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
 
 
 
 
 
 
 
 
 
633
  query = f"""
634
  MATCH (paper:Paper)
635
  WITH paper,
@@ -649,7 +915,7 @@ class PaperClient:
649
 
650
  def create_vector_index(self):
651
  """
652
- Applies to Paper nodes
653
  Indexes the `embedding` property on Paper nodes
654
  The indexed vectors have 384 dimensions
655
  Uses cosine similarity as the similarity measure
@@ -666,6 +932,13 @@ class PaperClient:
666
  session.execute_write(lambda tx: tx.run(query).data())
667
 
668
  def filter_paper_id_list(self, paper_id_list, year="2024"):
 
 
 
 
 
 
 
669
  if not paper_id_list:
670
  return []
671
  # WHERE p.year < "2024" AND p.venue_name <> "acl"
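
As a rough illustration of what `cosine_similarity_search` asks the database to do, the sketch below ranks a few in-memory vectors by cosine similarity against a query, keeps only positive scores, and returns the top `k` ids in descending order. It is plain Python, not the Cypher query used by `PaperClient`; the helper names `cosine` and `top_k_by_cosine` and the toy 3-dimensional vectors are assumptions made for the example (the index described in `create_vector_index` uses 384 dimensions).

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_by_cosine(
    query: Sequence[float], embeddings: Dict[str, Sequence[float]], k: int = 1
) -> List[str]:
    """Return up to k ids whose vectors have positive similarity to the query, best first."""
    scored = [(cosine(query, vec), hash_id) for hash_id, vec in embeddings.items()]
    positive = [(score, hash_id) for score, hash_id in scored if score > 0]
    positive.sort(key=lambda item: item[0], reverse=True)
    return [hash_id for _, hash_id in positive[:k]]


# Hypothetical toy data: ids and vectors are made up for the example.
papers = {
    "p1": [1.0, 0.0, 0.0],
    "p2": [0.6, 0.8, 0.0],
    "p3": [-1.0, 0.0, 0.0],
}
ranked = top_k_by_cosine([1.0, 0.0, 0.0], papers, k=2)
print(ranked)
```

A linear scan like this is only practical for small collections; a vector index in the database is what keeps the same ranking feasible at scale.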
 
1
+ r"""_summary_
2
+ -*- coding: utf-8 -*-
3
+
4
+ Module : prompt.utils
5
+
6
+ File Name : utils.paper_client
7
+
8
+ Description : paper client, all operations about neo4j database are in PaperClient
9
+
10
+ Creation Date : 2024-11-09
11
+
12
+ Modification Date : 2024-12-17
13
+
14
+ Author : Lihui Gu (code), Wenxiao Wang (comments)
15
+ """
16
  import os
17
  import re
18
  import json
 
49
  return driver
50
 
51
  def update_paper_from_client(self, paper):
52
+ """Read a paper from the database (client) and merge its info into `paper`
53
+ Args:
54
+ paper (dict): a paper record; its `hash_id` is used to look up the node
55
+ Returns:
56
+ None
57
+ """
58
+ paper_id = paper.get("hash_id", None)
59
  if paper_id is None:
60
  return None
61
  query = f"""
 
70
  paper.update(paper_from_client)
71
 
72
  def update_papers_from_client(self, paper_id_list):
73
+ """Read papers from the database (client)
74
+ Args:
75
+ paper_id_list (List of str)
76
+ Returns:
77
+ List of papers read from the database
78
+ """
79
  query = """
80
  UNWIND $papers AS paper
81
  MATCH (p:Paper {hash_id: paper.hash_id})
 
94
  return [r["result"] for r in result]
95
 
96
  def get_paper_attribute(self, paper_id, attribute_name):
97
+ """Get some attribute of a certain paper
98
+ Args:
99
+ paper_id (str):
100
+ attribute_name (str):
101
+ Returns:
102
+ The value of that attribute, or None if it cannot be read
103
+ """
104
  query = f"""
105
  MATCH (p:Paper {{hash_id: {paper_id}}})
106
  RETURN p.{attribute_name} AS attributeValue
 
114
  return None
115
 
116
  def get_papers_attribute(self, paper_id_list, attribute_name):
117
+ """Get some attribute of a list of papers
118
+ Args:
119
+ paper_id_list (List of str):
120
+ attribute_name (str):
121
+ Returns:
122
+ List of certain attribute
123
+ """
124
  query = """
125
  UNWIND $paper_ids AS paper_id
126
  MATCH (p:Paper {hash_id: paper_id})
 
136
  return paper_attributes
137
 
138
  def get_paper_by_attribute(self, attribute_name, anttribute_value):
139
+ """Get some paper whose `attribute_name` is exactly equal to `anttribute_value`
140
+ Args:
141
+ attribute_name
142
+ anttribute_value
143
+ Returns:
144
+ The first exact match paper object or None
145
+ """
146
  query = f"""
147
  MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
148
  RETURN p
 
155
  return None
156
 
157
  def get_paper_from_term(self, entity):
158
+ """Get paper from entity. The method is so strict that paper.entities must be
159
+ exactly equal to entity. The method is not used now.
160
+ Args:
161
+ entity:
162
+ Returns:
163
+
164
+ """
165
  if entity is None:
166
  return None
167
  query = """
 
181
  def find_related_entities_by_entity_list(
182
  self, entity_names, n=1, k=3, relation_name="related"
183
  ):
184
+ """Find all entities related to an entity name
185
+ Args:
186
+ entity_names (List): list of entities
187
+ n: not used
188
+ k: entities a and b are related if they co-occur in at least `k` papers
189
+ Returns:
190
+ related_entities (List): entities that are related to any entity in `entity_names`
191
+ """
192
  related_entities = set()
193
  query = """
194
  UNWIND $batch_entities AS entity_name
 
208
  return list(related_entities)
209
 
210
  def find_entities_by_paper_list(self, hash_ids: list):
211
+ """Retrieve entities for a list of papers:
212
+ Args:
213
+ hash_ids (List of paper hash_ids):
214
+ Returns:
215
+ entity_list (List of List of entities): each item is itself a list holding all entities of one paper
216
+ """
217
  query = """
218
  UNWIND $hash_ids AS hash_id
219
  MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: hash_id})
 
230
  return entity_list
231
 
232
  def find_paper_by_entity(self, entity_name):
233
+ """Find all papers with `entity_name`
234
+ Args:
235
+ entity_name (str)
236
+ Returns:
237
+ res (List of hash_ids): papers with `entity_name`
238
+ """
239
  query = """
240
  MATCH (e1:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
241
  RETURN p.hash_id AS hash_id
 
256
  return []
257
 
258
  def find_sentences_by_entity(self, entity_name):
259
+ """Find all sentences with a certain `entity_name`
260
+ Args:
261
+ entity_name (str)
262
+ Return:
263
+ sentences (List of strs): One str concatenates all sentences with `entity_name` in a section
264
+ E.g. [
265
+ "abstract sentence 1 from paper 1.abstract sentence 2 from paper 1",
266
+ "introduction sentence 1 from paper 1.introduction sentence 2 from paper 1",
267
+ "methodology sentence 1 from paper 1.",
268
+ "abstract sentence 1 from paper 2.abstract sentence 2 from paper 2",
269
+ ...
270
+ ]
271
+ """
272
  query = """
273
  MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
274
  WHERE p.abstract CONTAINS $entity_name OR
 
303
  return sentences
304
 
305
  def select_paper(self, venue_name, year):
306
+ """Retrieve the papers published at `venue_name` in `year`
307
+ Args:
308
+ venue_name (str)
309
+ year (int?)
310
+ Returns:
311
+ result (List of paper node)
312
+ """
313
  query = """
314
  MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
315
  """
 
323
  return []
324
 
325
  def add_paper_node(self, paper: dict):
326
+ """Add a paper node
327
+ Args:
328
+ paper (Dict)
329
+ Returns:
330
+ None
331
+ """
332
  if "summary" not in paper.keys():
333
  paper["summary"] = None
334
  if "abstract" not in paper.keys():
 
380
  )
381
 
382
  def check_entity_node_count(self, hash_id: int):
383
+ """Whether a paper has more than `3` entities
384
+ Args:
385
+ hash_id: a paper's hash_id
386
+ Returns:
387
+ True if the paper has <= 2 entities, False otherwise
388
+ """
389
  query_check_count = """
390
  MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: $hash_id})
391
  RETURN count(e) AS entity_count
 
400
  return True
401
 
402
  def add_entity_node(self, hash_id: int, entities: list):
403
+ """Add entity nodes and link them to their paper
404
+ Args:
405
+ hash_id: a paper's id
406
+ entities: all of the paper's entities
407
+ Returns:
408
+ None
409
+ """
410
  query = """
411
  MERGE (e:Entity {name: $entity_name})
412
  WITH e
 
423
  )
424
 
425
  def add_paper_citation(self, paper: dict):
426
+ """Add citations for the paper node, set its cite_id_list, entities, and all_cite_id_list
427
+ `cite_id_list` means citations in the Introduction section
428
+ `all_cite_id_list` means all citations
429
+ Args:
430
+ paper (Dict of a paper)
431
+ Returns:
432
+ None
433
+ """
434
  query = """
435
  MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
436
  """
 
445
  ).data()
446
  )
447
 
448
+ def insert_new_field(self, hash_id: str, field_name: str, content):
449
+ if hash_id is not None:
450
+ query = f"""
451
+ MATCH (n {{hash_id: $hash_id}})
452
+ SET n.{field_name} = $content
453
+ RETURN n
454
+ """
455
+ with self.driver.session() as session:
456
+ result = session.execute_write(
457
+ lambda tx: tx.run(
458
+ query, hash_id=hash_id, content=content
459
+ ).data()
460
+ )
461
+ return result
462
+ else:
463
+ return None
464
+
465
+ def update_paper_embedding(
466
+ self, embedding_model, hash_id=None, batch_size=512, name="abstract", postfix=""
467
+ ):
468
+ """Extract paper embedding and store in the database
469
+ Args:
470
+ embedding_model (TODO: what model?): a PyTorch embedding model
471
+ hash_id (str): add embedding for a paper if hash_id is not None.
472
+ Otherwise, all papers will be handled with a batch size of 512
473
+ batch_size: if hash_id is None, all papers will be processed with `batch_size`
474
+ """
475
+ if hash_id is not None:
476
+ query = f"""
477
+ MATCH (p:Paper {{hash_id: $hash_id}})
478
+ WHERE p.{name} IS NOT NULL
479
+ RETURN p.{name} AS context, p.hash_id AS hash_id, p.title AS title
480
+ """
481
+ with self.driver.session() as session:
482
+ results = session.execute_write(
483
+ lambda tx: tx.run(query, hash_id=hash_id).data()
484
+ )
485
+ # contexts = [result["title"] + result["context"] for result in results]
486
+ if name == "abstract":
487
+ contexts = [result["title"] + result["context"] for result in results]
488
+ else:
489
+ contexts = [result["context"] for result in results]
490
+ paper_ids = [result["hash_id"] for result in results]
491
+ context_embeddings = embedding_model.encode(
492
+ contexts, convert_to_tensor=True, device=self.device
493
+ )
494
+ query = f"""
495
+ MERGE (p:Paper {{hash_id: $hash_id}})
496
+ ON CREATE SET p.{name}_embedding{postfix} = $embedding
497
+ ON MATCH SET p.{name}_embedding{postfix} = $embedding
498
+ """
499
+ for idx, hash_id in tqdm(enumerate(paper_ids)):
500
+ embedding = (
501
+ context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
502
+ )
503
+ with self.driver.session() as session:
504
+ results = session.execute_write(
505
+ lambda tx: tx.run(
506
+ query, hash_id=hash_id, embedding=embedding
507
+ ).data()
508
+ )
509
+ return
510
+ offset = 0
511
+ while True:
512
+ query = f"""
513
+ MATCH (p:Paper)
514
+ WHERE p.{name} IS NOT NULL
515
+ RETURN p.{name} AS context, p.hash_id AS hash_id, p.title AS title
516
+ SKIP $offset LIMIT $batch_size
517
+ """
518
+ with self.driver.session() as session:
519
+ results = session.execute_write(
520
+ lambda tx: tx.run(
521
+ query, offset=offset, batch_size=batch_size
522
+ ).data()
523
+ )
524
+ if not results:
525
+ break
526
+ if name == "abstract":
527
+ contexts = [result["title"] + result["context"] for result in results]
528
+ else:
529
+ contexts = [result["context"] for result in results]
530
+ paper_ids = [result["hash_id"] for result in results]
531
+ context_embeddings = embedding_model.encode(
532
+ contexts,
533
+ batch_size=batch_size,
534
+ convert_to_tensor=True,
535
+ device=self.device,
536
+ )
537
+ write_query = f"""
538
+ UNWIND $data AS row
539
+ MERGE (p:Paper {{hash_id: row.hash_id}})
540
+ ON CREATE SET p.{name}_embedding{postfix} = row.embedding
541
+ ON MATCH SET p.{name}_embedding{postfix} = row.embedding
542
+ """
543
+ data_to_write = []
544
+ context_embeddings = context_embeddings.detach().cpu().tolist()
545
+ for idx, hash_id in enumerate(paper_ids):
546
+ data_to_write.append({"hash_id": hash_id, "embedding": context_embeddings[idx]})
547
+ with self.driver.session() as session:
548
+ session.execute_write(
549
+ lambda tx: tx.run(write_query, data=data_to_write)
550
+ )
551
+ offset += batch_size
552
+ logger.info(f"== Processed batch starting at offset {offset} ==")
553
+
554
  def add_paper_abstract_embedding(
555
  self, embedding_model, hash_id=None, batch_size=512
556
  ):
557
+ """Extract paper abstract embedding and store in the database
558
+ Args:
559
+ embedding_model (TODO: what model?): a PyTorch embedding model
560
+ hash_id (str): add abstract embedding for a paper if hash_id is not None.
561
+ Otherwise, all papers will be handled with a batch size of 512
562
+ batch_size: if hash_id is None, all papers will be processed with `batch_size`
563
+ """
564
  if hash_id is not None:
565
  query = """
566
  MATCH (p:Paper {hash_id: $hash_id})
 
636
  logger.info(f"== Processed batch starting at offset {offset} ==")
637
 
638
  def add_paper_bg_embedding(self, embedding_model, hash_id=None, batch_size=512):
639
+ """Extract paper background embedding and store in the database
640
+ Args:
641
+ embedding_model (TODO: what model?): a PyTorch embedding model
642
+ hash_id (str): add background embedding for a paper if hash_id is not None.
643
+ Otherwise, all papers will be handled with a batch size of 512
644
+ batch_size: if hash_id is None, all papers will be processed with `batch_size`
645
+ """
646
  if hash_id is not None:
647
  query = """
648
  MATCH (p:Paper {hash_id: $hash_id})
 
720
  def add_paper_contribution_embedding(
721
  self, embedding_model, hash_id=None, batch_size=512
722
  ):
723
+ """Extract paper contribution embedding and store in the database
724
+ Args:
725
+ embedding_model (TODO: what model?): a PyTorch embedding model
726
+ hash_id (str): add contribution embedding for a paper if hash_id is not None.
727
+ Otherwise, all papers will be handled with a batch size of 512
728
+ batch_size: if hash_id is None, all papers will be processed with `batch_size`
729
+ """
730
  if hash_id is not None:
731
  query = """
732
  MATCH (p:Paper {hash_id: $hash_id})
 
804
  def add_paper_summary_embedding(
805
  self, embedding_model, hash_id=None, batch_size=512
806
  ):
807
+ """Extract paper summary embedding and store in the database
808
+ Args:
809
+ embedding_model (TODO: what model?): a PyTorch embedding model
810
+ hash_id (str): add summary embedding for a paper if hash_id is not None.
811
+ Otherwise, all papers will be handled with a batch size of 512
812
+ batch_size: if hash_id is None, all papers will be processed with `batch_size`
813
+ """
814
  if hash_id is not None:
815
  query = """
816
  MATCH (p:Paper {hash_id: $hash_id})
 
823
  )
824
  contexts = [result["context"] for result in results]
825
  paper_ids = [result["hash_id"] for result in results]
826
+ # context_embeddings is a torch.Tensor
827
  context_embeddings = embedding_model.encode(
828
  contexts, convert_to_tensor=True, device=self.device
829
  )
 
887
  logger.info(f"== Processed batch starting at offset {offset} ==")
888
 
889
  def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
890
+ """Retrieve all papers whose `type_name` embedding is similar to `embedding`
891
+ (cosine_sim > 0 and return in a descending order)
892
+ Args:
893
+ embedding (TODO: type): the embedding to be checked
894
+ k: only return topk papers with highest similarities
895
+ type_name: "abstract_embedding", "summary_embedding", etc.
896
+ Returns:
897
+ related_paper (List of str): hash_id of retrieved papers
898
+ """
899
  query = f"""
900
  MATCH (paper:Paper)
901
  WITH paper,
 
915
 
916
  def create_vector_index(self):
917
  """
918
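+ Applies to Paper nodes; the statement here should index all papers in the database
919
  Indexes the `embedding` property on Paper nodes
920
  The indexed vectors have 384 dimensions
921
  Uses cosine similarity as the similarity measure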
 
932
  session.execute_write(lambda tx: tx.run(query).data())
933
 
934
  def filter_paper_id_list(self, paper_id_list, year="2024"):
935
+ """Return the ids of papers that were released before `year` (exclusive) and exist in the database
936
+ Args:
937
+ paper_id_list (List of str): a list of paper ids
938
+ year: only papers released before this year are kept
939
+ Returns:
940
+ existing_paper_ids (List of str): paper_ids that satisfy the conditions
941
+ """
942
  if not paper_id_list:
943
  return []
944
  # WHERE p.year < "2024" AND p.venue_name <> "acl"
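
Several methods in this file (`get_paper_attribute`, `insert_new_field`, `update_paper_embedding`) splice a property name into the Cypher text with an f-string, since query parameters carry values rather than a property key written as `n.key`. A minimal sketch of guarding that interpolation follows. `safe_property_name` and `build_set_query` are hypothetical helpers, not part of `PaperClient`, and the allowed pattern (letters, digits, underscores, not starting with a digit) is an assumption.

```python
import re

# Assumed whitelist: identifier-like names only, e.g. "abstract_embedding".
_PROPERTY_NAME = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def safe_property_name(name: str) -> str:
    """Return the name unchanged if it looks like a plain identifier, else raise."""
    if not isinstance(name, str) or not _PROPERTY_NAME.fullmatch(name):
        raise ValueError(f"Unsafe property name: {name!r}")
    return name


def build_set_query(field_name: str) -> str:
    """Build a query shaped like the one in insert_new_field, with a checked field name."""
    field = safe_property_name(field_name)
    return (
        "MATCH (n {hash_id: $hash_id}) "
        f"SET n.{field} = $content "
        "RETURN n"
    )


query = build_set_query("abstract_embedding")
print(query)
```

The values (`$hash_id`, `$content`) still travel as parameters; only the property name is checked and interpolated.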
src/utils/paper_crawling.py CHANGED
@@ -142,6 +142,19 @@ class PaperCrawling:
142
  print(e)
143
 
144
  def crawling(self, year, venue_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  paper_list = []
146
  paper_html_list = []
147
 
 
142
  print(e)
143
 
144
  def crawling(self, year, venue_name):
145
+ """
146
+ Args:
147
+ Returns:
148
+ paper_list (List of Dict):[
149
+ {
150
+ "hash_id": hash_id, hash id of the paper
151
+ "year": year, published year
152
+ "venue_name": venue_name, venue name
153
+ "title": title, paper title
154
+ "pdf_url": pdf_url, paper url
155
+ }
156
+ ]
157
+ """
158
  paper_list = []
159
  paper_html_list = []
160
 
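
The record shape documented for `crawling` can be written down as a typed dictionary so that code consuming the list can check for missing keys. This is only a sketch: the class name `CrawledPaper`, the helper `missing_keys`, the field types, and the sample values are all assumptions, since the docstring lists the keys but not their types.

```python
from typing import List, TypedDict


class CrawledPaper(TypedDict):
    """Keys taken from the crawling docstring; the types are assumed."""

    hash_id: str      # hash id of the paper (type assumed)
    year: str         # published year (type assumed)
    venue_name: str   # venue name
    title: str        # paper title
    pdf_url: str      # paper url


REQUIRED_KEYS = set(CrawledPaper.__annotations__)


def missing_keys(paper: dict) -> List[str]:
    """Return the documented keys that are absent from a record, sorted by name."""
    return sorted(REQUIRED_KEYS - set(paper))


# Hypothetical sample record; none of these values come from the repository.
sample: CrawledPaper = {
    "hash_id": "0001",
    "year": "2024",
    "venue_name": "acl",
    "title": "An Example Title",
    "pdf_url": "https://example.org/paper.pdf",
}
print(missing_keys(sample))
```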
src/utils/paper_retriever.py CHANGED
@@ -36,6 +36,9 @@ class UnionFind:
36
 
37
 
38
  def can_merge(uf, similarity_matrix, i, j, threshold):
 
 
 
39
  root_i = uf.find(i)
40
  root_j = uf.find(j)
41
  for k in range(len(similarity_matrix)):
@@ -72,6 +75,8 @@ class CoCite:
72
  CoCite._initialized = True
73
 
74
  def get_cocite_ids(self, id_, k=1):
 
 
75
  sorted_items = sorted(self.comap[id_].items(), key=lambda x: x[1], reverse=True)
76
  top_k = sorted_items[:k]
77
  paper_ids = []
@@ -82,6 +87,12 @@ class CoCite:
82
 
83
 
84
  class Retriever(object):
 
 
 
 
 
 
85
  __metaclass__ = ABCMeta
86
  retriever_name = "BASE"
87
 
@@ -95,12 +106,38 @@ class Retriever(object):
95
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
96
  self.embedding_model = get_embedding_model(config)
97
  self.paper_crawling = PaperCrawling(config=config)
98
-
 
 
 
 
 
 
 
 
 
 
 
99
  @abstractmethod
100
  def retrieve(self, bg, entities, use_evaluate):
 
 
 
 
 
 
101
  pass
102
 
103
  def retrieve_entities_by_enties(self, entities):
 
 
 
 
 
 
 
 
 
104
  # TODO: KG
105
  expand_entities = self.paper_client.find_related_entities_by_entity_list(
106
  entities,
@@ -135,14 +172,16 @@ class Retriever(object):
135
  def update_related_paper(self, paper_id_list):
136
  """
137
  Args:
138
- paper_id_list: list
139
  Return:
140
- related_paper: list(dict)
141
  """
142
  related_paper = self.paper_client.update_papers_from_client(paper_id_list)
143
  return related_paper
144
 
145
  def calculate_similarity(self, entities, related_entities_list, use_weight=False):
 
 
146
  if use_weight:
147
  vec1 = self.vectorizer.transform([" ".join(entities)]).toarray()[0]
148
  weighted_vec1 = np.array(
@@ -181,8 +220,21 @@ class Retriever(object):
181
  return similarity
182
 
183
  def cal_related_score(
184
- self, embedding, related_paper_id_list, type_name="embedding"
185
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  score_1 = np.zeros((len(related_paper_id_list)))
187
  # score_2 = np.zeros((len(related_paper_id_list)))
188
  origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
@@ -211,6 +263,15 @@ class Retriever(object):
211
  return {}, {}, score_all_dict
212
 
213
  def filter_related_paper(self, score_dict, top_k):
 
 
 
 
 
 
 
 
 
214
  if len(score_dict) <= top_k:
215
  return list(score_dict.keys())
216
  if not self.use_cluster_to_filter:
@@ -221,33 +282,49 @@ class Retriever(object):
221
  )
222
  return paper_id_list
223
  else:
 
 
224
  # clustering filter: make sure the highest-scoring paper of each cluster is kept first
 
225
  paper_id_list = list(score_dict.keys())
226
  paper_embedding_list = [
227
- self.paper_client.get_paper_attribute(paper_id, "embedding")
228
  for paper_id in paper_id_list
229
  ]
230
  paper_embedding = np.array(paper_embedding_list)
 
231
  paper_embedding_list = [
232
  self.paper_client.get_paper_attribute(
233
- paper_id, "contribution_embedding"
234
  )
235
  for paper_id in paper_id_list
236
  ]
237
  paper_contribution_embedding = np.array(paper_embedding_list)
 
238
  paper_embedding_list = [
239
- self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
240
  for paper_id in paper_id_list
241
  ]
242
  paper_summary_embedding = np.array(paper_embedding_list)
243
- weight_embedding = self.config.RETRIEVE.s_bg
 
 
 
 
 
 
 
244
  weight_contribution = self.config.RETRIEVE.s_contribution
245
  weight_summary = self.config.RETRIEVE.s_summary
 
246
  paper_embedding = (
247
- weight_embedding * paper_embedding
248
  + weight_contribution * paper_contribution_embedding
249
  + weight_summary * paper_summary_embedding
 
250
  )
 
 
251
  similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
252
  related_labels = self.cluster_algorithm(paper_id_list, similarity_matrix)
253
  related_paper_label_dict = dict(zip(paper_id_list, related_labels))
@@ -257,6 +334,7 @@ class Retriever(object):
257
  label_group[label] = []
258
  label_group[label].append(paper_id)
259
  paper_id_list = []
 
260
  while len(paper_id_list) < top_k:
261
  for label, papers in label_group.items():
262
  if papers:
@@ -265,9 +343,12 @@ class Retriever(object):
265
  break
266
  return paper_id_list
267
 
268
- def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
269
- """
270
- return related paper: list
 
 
 
271
  """
272
  result = self.paper_client.cosine_similarity_search(
273
  embedding, k, type_name=type_name
@@ -277,6 +358,8 @@ class Retriever(object):
277
  return result
278
 
279
  def cluster_algorithm(self, paper_id_list, similarity_matrix):
 
 
280
  threshold = self.config.RETRIEVE.similarity_threshold
281
  uf = UnionFind(len(paper_id_list))
282
  # merge
@@ -298,34 +381,51 @@ class Retriever(object):
298
  for k in self.config.RETRIEVE.top_k_list:
299
  result[k] = {"recall": 0, "precision": 0}
300
  return result, 0, 0, 0
 
 
 
301
  all_paper_id_set = set(related_paper_id_list)
302
  all_paper_id_set.update(target_paper_id_list)
303
  all_paper_id_list = list(all_paper_id_set)
 
304
  paper_embedding_list = [
305
- self.paper_client.get_paper_attribute(paper_id, "embedding")
306
  for paper_id in target_paper_id_list
307
  ]
308
  paper_embedding = np.array(paper_embedding_list)
 
309
  paper_embedding_list = [
310
- self.paper_client.get_paper_attribute(paper_id, "contribution_embedding")
311
  for paper_id in target_paper_id_list
312
  ]
313
  paper_contribution_embedding = np.array(paper_embedding_list)
 
 
 
 
 
 
314
  paper_embedding_list = [
315
- self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
316
  for paper_id in target_paper_id_list
317
  ]
 
 
318
  paper_summary_embedding = np.array(paper_embedding_list)
319
- weight_embedding = self.config.RETRIEVE.s_bg
320
  weight_contribution = self.config.RETRIEVE.s_contribution
321
  weight_summary = self.config.RETRIEVE.s_summary
 
 
322
  target_paper_embedding = (
323
- weight_embedding * paper_embedding
324
  + weight_contribution * paper_contribution_embedding
325
  + weight_summary * paper_summary_embedding
 
326
  )
327
  similarity_threshold = self.config.RETRIEVE.similarity_threshold
328
  similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
 
329
  target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
330
  target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
331
  logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
@@ -335,28 +435,41 @@ class Retriever(object):
335
  for paper_id in target_paper_label_dict.keys()
336
  }
337
  )
338
-
 
339
  all_labels = []
340
  for paper_id in all_paper_id_list:
 
341
  paper_bg_embedding = [
342
- self.paper_client.get_paper_attribute(paper_id, "embedding")
343
  ]
344
  paper_bg_embedding = np.array(paper_bg_embedding)
 
345
  paper_contribution_embedding = [
346
  self.paper_client.get_paper_attribute(
347
- paper_id, "contribution_embedding"
348
  )
349
  ]
350
  paper_contribution_embedding = np.array(paper_contribution_embedding)
 
351
  paper_summary_embedding = [
352
- self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
353
  ]
354
  paper_summary_embedding = np.array(paper_summary_embedding)
 
 
 
 
 
 
355
  paper_embedding = (
356
- weight_embedding * paper_bg_embedding
357
  + weight_contribution * paper_contribution_embedding
358
  + weight_summary * paper_summary_embedding
 
359
  )
 
 
360
  similarities = cosine_similarity(paper_embedding, target_paper_embedding)[0]
361
  if np.any(similarities >= similarity_threshold):
362
  all_labels.append(target_labels[np.argmax(similarities)])
@@ -364,14 +477,15 @@ class Retriever(object):
364
  all_labels.append(-1) # other class: -1
365
  all_paper_label_dict = dict(zip(all_paper_id_list, all_labels))
366
  all_label_counts = Counter(all_paper_label_dict.values())
367
- logger.debug(f"all label counts : {all_label_counts}")
368
  target_label_counts = Counter(target_paper_label_dict.values())
369
- logger.debug(f"target label counts : {target_label_counts}")
370
  target_label_list = list(target_label_counts.keys())
371
  max_k = max(self.config.RETRIEVE.top_k_list)
372
  logger.info("=== Begin filter related paper ===")
373
  max_k_paper_id_list = self.filter_related_paper(score_all_dict, top_k=max_k)
374
  logger.info("=== End filter related paper ===")
 
375
  for k in self.config.RETRIEVE.top_k_list:
376
  # the top-k papers
377
  top_k = min(k, len(max_k_paper_id_list))
@@ -380,7 +494,7 @@ class Retriever(object):
380
  for paper_id in top_k_paper_id_list:
381
  top_k_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
382
  logger.debug(
383
- "=== top k {} paper id list : {}".format(k, top_k_paper_label_dict)
384
  )
385
  logger.debug(
386
  {
@@ -389,7 +503,7 @@ class Retriever(object):
389
  }
390
  )
391
  top_k_label_counts = Counter(top_k_paper_label_dict.values())
392
- logger.debug(f"top K label counts : {top_k_label_counts}")
393
  top_k_label_list = list(top_k_label_counts.keys())
394
  match_label_list = list(set(target_label_list) & set(top_k_label_list))
395
  logger.debug(f"match label list : {match_label_list}")
@@ -403,6 +517,7 @@ class Retriever(object):
403
  precision /= len(top_k_paper_id_list)
404
  result[k] = {"recall": recall, "precision": precision}
405
 
 
406
  related_paper_id_list = list(score_all_dict.keys())
407
  related_paper_label_dict = {}
408
  for paper_id in related_paper_id_list:
@@ -419,11 +534,19 @@ class Retriever(object):
419
  precision += related_label_counts[label]
420
  recall /= len(target_paper_id_list)
421
  precision /= len(related_paper_id_list)
 
422
  logger.debug(result)
423
  return result, len(target_label_counts), recall, precision
424
 
425
 
426
  class RetrieverFactory(object):
 
 
 
 
 
 
 
427
  _instance = None
428
  _lock = threading.Lock()
429
 
@@ -441,11 +564,24 @@ class RetrieverFactory(object):
441
 
442
  @staticmethod
443
  def get_retriever_factory():
 
 
 
 
 
 
444
  if RetrieverFactory._instance is None:
445
  RetrieverFactory._instance = RetrieverFactory()
446
  return RetrieverFactory._instance
447
 
448
  def register_retriever(self, retriever_name, retriever_class) -> bool:
 
 
 
 
 
 
 
449
  if retriever_name not in self.retriever_classes:
450
  self.retriever_classes[retriever_name] = retriever_class
451
  return True
@@ -467,8 +603,14 @@ class RetrieverFactory(object):
467
  return len(self.retriever_classes)
468
 
469
  def create_retriever(self, retriever_name, *args, **kwargs) -> Retriever:
 
 
 
 
 
 
470
  if retriever_name not in self.retriever_classes:
471
- raise ValueError(f"Unknown retriever type: {retriever_name}")
472
  else:
473
  return self.retriever_classes[retriever_name](*args, **kwargs)
474
 
@@ -493,11 +635,24 @@ class SNRetriever(Retriever):
493
  super().__init__(config)
494
 
495
  def retrieve_paper(self, bg):
 
 
 
 
 
 
 
 
 
 
 
 
496
  entities = []
497
  embedding = self.embedding_model.encode(bg, device=self.device)
498
  sn_paper_id_list = self.cosine_similarity_search(
499
  embedding=embedding,
500
  k=self.config.RETRIEVE.sn_retrieve_paper_num,
 
501
  )
502
  related_paper = set()
503
  related_paper.update(sn_paper_id_list)
@@ -513,7 +668,7 @@ class SNRetriever(Retriever):
513
  related_paper = list(related_paper)
514
  logger.debug(f"paper num before filter: {len(related_paper)}")
515
  result = {
516
- "embedding": embedding,
517
  "paper": related_paper,
518
  "entities": entities,
519
  "cocite_paper": list(cocite_id_set),
@@ -523,9 +678,21 @@ class SNRetriever(Retriever):
523
  def retrieve(self, bg, entities, need_evaluate=True, target_paper_id_list=[]):
524
  """
525
  Args:
526
- context: string
527
  Return:
528
- list(dict)
 
 
 
 
 
 
 
 
 
 
 
 
529
  """
530
  if need_evaluate:
531
  if target_paper_id_list is None or len(target_paper_id_list) == 0:
@@ -537,23 +704,27 @@ class SNRetriever(Retriever):
537
  retrieve_result = self.retrieve_paper(bg)
538
  related_paper_id_list = retrieve_result["paper"]
539
  retrieve_paper_num = len(related_paper_id_list)
 
540
  _, _, score_all_dict = self.cal_related_score(
541
- retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
 
542
  )
543
  top_k_matrix = {}
544
  recall = 0
545
  precision = 0
546
  filtered_recall = 0
 
547
  filtered_precision = 0
548
  if need_evaluate:
549
  top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
550
  score_all_dict, target_paper_id_list
551
  )
552
- logger.debug("Top P matrix:{}".format(top_k_matrix))
553
  logger.debug("before filter:")
554
  logger.debug(f"Recall: {recall:.3f}")
555
  logger.debug(f"Precision: {precision:.3f}")
556
- related_paper = self.filter_related_paper(score_all_dict, top_k=10)
 
557
  related_paper = self.update_related_paper(related_paper)
558
  result = {
559
  "recall": recall,
@@ -578,6 +749,8 @@ class KGRetriever(Retriever):
578
  super().__init__(config)
579
 
580
  def retrieve_paper(self, entities):
 
 
581
  new_entities = self.retrieve_entities_by_enties(entities)
582
  logger.debug("KG entities for retriever: {}".format(new_entities))
583
  related_paper = set()
@@ -618,12 +791,14 @@ class KGRetriever(Retriever):
618
  retrieve_paper_num = len(related_paper_id_list)
619
  embedding = self.embedding_model.encode(bg, device=self.device)
620
  _, _, score_all_dict = self.cal_related_score(
621
- embedding, related_paper_id_list=related_paper_id_list
 
622
  )
623
  top_k_matrix = {}
624
  recall = 0
625
  precision = 0
626
  filtered_recall = 0
 
627
  filtered_precision = 0
628
  if need_evaluate:
629
  top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
@@ -633,7 +808,7 @@ class KGRetriever(Retriever):
633
  logger.debug("before filter:")
634
  logger.debug(f"Recall: {recall:.3f}")
635
  logger.debug(f"Precision: {precision:.3f}")
636
- related_paper = self.filter_related_paper(score_all_dict, top_k=10)
637
  related_paper = self.update_related_paper(related_paper)
638
  result = {
639
  "recall": recall,
@@ -659,28 +834,44 @@ class SNKGRetriever(Retriever):
659
 
660
  def retrieve_paper(self, bg, entities):
661
  sn_entities = []
 
662
  embedding = self.embedding_model.encode(bg, device=self.device)
663
  sn_paper_id_list = self.cosine_similarity_search(
664
- embedding, k=self.config.RETRIEVE.sn_num_for_entity
 
665
  )
666
  related_paper = set()
667
  related_paper.update(sn_paper_id_list)
 
 
 
 
668
  sn_entities += self.paper_client.find_entities_by_paper_list(sn_paper_id_list)
669
  logger.debug("SN entities for retriever: {}".format(sn_entities))
670
  entities = list(set(entities + sn_entities))
 
671
  new_entities = self.retrieve_entities_by_enties(entities)
672
  logger.debug("SNKG entities for retriever: {}".format(new_entities))
 
673
  for entity in new_entities:
674
- paper_id_set = set(self.paper_client.find_paper_by_entity(entity))
675
- related_paper = related_paper.union(paper_id_set)
 
 
 
 
676
  cocite_id_set = set()
677
  if self.use_cocite:
678
  for paper_id in related_paper:
679
  cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
680
  related_paper = related_paper.union(cocite_id_set)
 
 
 
 
681
  related_paper = list(related_paper)
682
  result = {
683
- "embedding": embedding,
684
  "paper": related_paper,
685
  "entities": entities,
686
  "cocite_paper": list(cocite_id_set),
@@ -688,7 +879,7 @@ class SNKGRetriever(Retriever):
688
  return result
689
 
690
  def retrieve(
691
- self, bg, entities, need_evaluate=True, target_paper_id_list=[], top_k=10
692
  ):
693
  """
694
  Args:
@@ -709,7 +900,8 @@ class SNKGRetriever(Retriever):
709
  retrieve_paper_num = len(related_paper_id_list)
710
  logger.info("=== Begin cal related paper score ===")
711
  _, _, score_all_dict = self.cal_related_score(
712
- retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
 
713
  )
714
  logger.info("=== End cal related paper score ===")
715
  top_k_matrix = {}
@@ -727,7 +919,7 @@ class SNKGRetriever(Retriever):
727
  logger.debug(f"Recall: {recall:.3f}")
728
  logger.debug(f"Precision: {precision:.3f}")
729
  logger.info("=== Begin filter related paper score ===")
730
- related_paper = self.filter_related_paper(score_all_dict, top_k)
731
  logger.info("=== End filter related paper score ===")
732
  related_paper = self.update_related_paper(related_paper)
733
  result = {
 
36
 
37
 
38
  def can_merge(uf, similarity_matrix, i, j, threshold):
39
+ """i and j can be merged only if, after merging, the similarity of every pair of nodes
40
+ taken from the clusters of root_i and root_j is larger than the threshold
41
+ """
42
  root_i = uf.find(i)
43
  root_j = uf.find(j)
44
  for k in range(len(similarity_matrix)):
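The merge condition described in the docstring can be sketched as follows. This is a minimal illustration, not the repository's implementation: `UnionFindSketch`, `can_merge_sketch`, and `cluster_sketch` are hypothetical names, and the body of the loop in the real `can_merge` is not visible in this hunk.

```python
class UnionFindSketch:
    """Hypothetical stand-in for the UnionFind class used in the diff."""

    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        # Path halving keeps the trees shallow.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x != root_y:
            self.parent[root_y] = root_x


def can_merge_sketch(uf, similarity_matrix, i, j, threshold):
    """True only if every cross-cluster pair is more similar than threshold."""
    root_i, root_j = uf.find(i), uf.find(j)
    n = len(similarity_matrix)
    members_i = [k for k in range(n) if uf.find(k) == root_i]
    members_j = [k for k in range(n) if uf.find(k) == root_j]
    return all(
        similarity_matrix[a][b] > threshold for a in members_i for b in members_j
    )


def cluster_sketch(similarity_matrix, threshold):
    """Greedily merge pairs whose clusters pass the all-pairs check."""
    n = len(similarity_matrix)
    uf = UnionFindSketch(n)
    for i in range(n):
        for j in range(i + 1, n):
            if can_merge_sketch(uf, similarity_matrix, i, j, threshold):
                uf.union(i, j)
    return [uf.find(k) for k in range(n)]
```

Because every cross-cluster pair must pass the check, a paper that is close to only one member of a cluster does not pull the two clusters together.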
 
75
  CoCite._initialized = True
76
 
77
  def get_cocite_ids(self, id_, k=1):
78
+ """
79
+ """
80
  sorted_items = sorted(self.comap[id_].items(), key=lambda x: x[1], reverse=True)
81
  top_k = sorted_items[:k]
82
  paper_ids = []
 
87
 
88
 
89
  class Retriever(object):
90
+ """The superclass of all retrievers
91
+ Args:
92
+ config:
93
+ Returns:
94
+ A Retriever instance
95
+ """
96
  __metaclass__ = ABCMeta
97
  retriever_name = "BASE"
98
 
 
106
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107
  self.embedding_model = get_embedding_model(config)
108
  self.paper_crawling = PaperCrawling(config=config)
109
+ if self.config.DEFAULT.embedding == "sentence-transformers/all-MiniLM-L6-v2":
110
+ self.embedding_postfix = ""
111
+ elif self.config.DEFAULT.embedding == "BAAI/llm-embedder":
112
+ self.embedding_postfix = "_llm_embedder"
113
+ elif self.config.DEFAULT.embedding == "jina-embeddings-v3":
114
+ self.embedding_postfix = "_jina_v3"
115
+ if self.config.DEFAULT.embedding_database == "text-matching":
116
+ self.embedding_postfix += "_text_matching"
117
+ elif self.config.DEFAULT.embedding_database == "retrieval.query":
118
+ self.embedding_postfix += "_query"
119
+ elif self.config.DEFAULT.embedding_database == "retrieval.passage":
120
+ self.embedding_postfix += "_passage"
121
  @abstractmethod
122
  def retrieve(self, bg, entities, use_evaluate):
123
+ """Retrieve papers; must be implemented by each subclass
124
+ Args:
125
+ None
126
+ Returns:
127
+ None
128
+ """
129
  pass
130
 
131
  def retrieve_entities_by_enties(self, entities):
132
+ """The method does three things:
133
+ 1. Expand the entities according to entity co-occurrence
134
+ 2. Count the number of papers related to each expanded entity, then sort the entities by occurrence count in ascending order
135
+ 3. Initialize the new entity list, then add entities one by one until the number of related papers reaches a threshold
136
+ Args:
137
+ entities: A List of entities, e.g., [str, str, ...]
138
+ Returns:
139
+ new_entities: A List of entities after expansion, e.g., [str, str, ...]
140
+ """
141
  # TODO: KG
142
  expand_entities = self.paper_client.find_related_entities_by_entity_list(
143
  entities,
 
172
  def update_related_paper(self, paper_id_list):
173
  """
174
  Args:
175
+ paper_id_list (List of hash_id): e.g., [1231214, 46345]
176
  Return:
177
+ related_paper (List of dict):
178
  """
179
  related_paper = self.paper_client.update_papers_from_client(paper_id_list)
180
  return related_paper
181
 
182
  def calculate_similarity(self, entities, related_entities_list, use_weight=False):
183
+ """[Deprecated] Calculate the similarities between two lists of entities
184
+ """
185
  if use_weight:
186
  vec1 = self.vectorizer.transform([" ".join(entities)]).toarray()[0]
187
  weighted_vec1 = np.array(
 
220
  return similarity
221
 
222
  def cal_related_score(
223
+ self, embedding, related_paper_id_list, type_name="background_embedding"
224
  ):
225
+ """Calculate the cosine similarity between the input background's embedding and
226
+ the given list of papers
227
+ Args:
228
+ embedding: the embedding of the input background
229
+ related_paper_id_list (List of int): the paper ids in the database
230
+ Returns:
231
+ Empty dict: {}
232
+ Empty dict: {}
233
+ score_all_dict:
234
+ paper_id1: score1,
235
+ paper_id2: score2,
236
+ ...
237
+ """
238
  score_1 = np.zeros((len(related_paper_id_list)))
239
  # score_2 = np.zeros((len(related_paper_id_list)))
240
  origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
 
263
  return {}, {}, score_all_dict
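As a rough illustration of what `score_all_dict` holds, the sketch below scores each paper by cosine similarity against the query embedding. It assumes plain Python lists instead of the torch tensors used in the real method, and `cosine` and `score_papers` are hypothetical helper names.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)


def score_papers(query_embedding, paper_embeddings):
    """Map each paper id to its similarity with the query embedding."""
    return {
        paper_id: cosine(query_embedding, embedding)
        for paper_id, embedding in paper_embeddings.items()
    }
```

The real method returns two empty dicts ahead of this mapping, which is why callers unpack it as `_, _, score_all_dict`.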
264
 
265
  def filter_related_paper(self, score_dict, top_k):
266
+ """Pick top_k papers from all retrieved papers in terms of score_dict. If clustering
267
+ is not used, top_k papers with highest scores will be picked. If clustering is used,
268
+ we will pick papers from each cluster in turn until top_k papers are chosen.
269
+ Args:
270
+ score_dict (dict): dict of (paper_id, similarity with user input background)
271
+ top_k (int): pick top_k papers
272
+ Returns:
273
+
274
+ """
275
  if len(score_dict) <= top_k:
276
  return list(score_dict.keys())
277
  if not self.use_cluster_to_filter:
 
282
  )
283
  return paper_id_list
284
  else:
285
+ ## Calculate the final embedding for each paper, which is the weighted sum of the
286
+ ## background, contribution, summary, and abstract embeddings.
287
  # clustering filter: within each cluster, keep the highest-scoring paper first
288
+ # background embedding
289
  paper_id_list = list(score_dict.keys())
290
  paper_embedding_list = [
291
+ self.paper_client.get_paper_attribute(paper_id, f"background_embedding{self.embedding_postfix}")
292
  for paper_id in paper_id_list
293
  ]
294
  paper_embedding = np.array(paper_embedding_list)
295
+ # contribution embedding
296
  paper_embedding_list = [
297
  self.paper_client.get_paper_attribute(
298
+ paper_id, f"contribution_embedding{self.embedding_postfix}"
299
  )
300
  for paper_id in paper_id_list
301
  ]
302
  paper_contribution_embedding = np.array(paper_embedding_list)
303
+ # summary embedding
304
  paper_embedding_list = [
305
+ self.paper_client.get_paper_attribute(paper_id, f"summary_embedding{self.embedding_postfix}")
306
  for paper_id in paper_id_list
307
  ]
308
  paper_summary_embedding = np.array(paper_embedding_list)
309
+ # abstract embedding
310
+ paper_embedding_list = [
311
+ self.paper_client.get_paper_attribute(paper_id, f"abstract_embedding{self.embedding_postfix}")
312
+ for paper_id in paper_id_list
313
+ ]
314
+ paper_abstract_embedding = np.array(paper_embedding_list)
315
+
316
+ weight_background = self.config.RETRIEVE.s_bg
317
  weight_contribution = self.config.RETRIEVE.s_contribution
318
  weight_summary = self.config.RETRIEVE.s_summary
319
+ weight_abstract = self.config.RETRIEVE.s_abstract
320
  paper_embedding = (
321
+ weight_background * paper_embedding
322
  + weight_contribution * paper_contribution_embedding
323
  + weight_summary * paper_summary_embedding
324
+ + weight_abstract * paper_abstract_embedding
325
  )
326
+
327
+ ## similarity_matrix of all retrieved papers
328
  similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
329
  related_labels = self.cluster_algorithm(paper_id_list, similarity_matrix)
330
  related_paper_label_dict = dict(zip(paper_id_list, related_labels))
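The weighted combination and the similarity matrix built above can be sketched in plain Python. The field names and weights below are illustrative only; the real weights come from `config.RETRIEVE.s_bg`, `s_contribution`, `s_summary`, and `s_abstract`, and `combine` and `gram` are hypothetical helpers.

```python
def combine(embeddings, weights):
    """Weighted sum of one paper's per-field embeddings."""
    dim = len(next(iter(embeddings.values())))
    return [
        sum(weights[name] * embeddings[name][d] for name in weights)
        for d in range(dim)
    ]


def gram(vectors):
    """Pairwise dot products, the plain-Python analogue of np.dot(X, X.T)."""
    return [[sum(a * b for a, b in zip(u, v)) for v in vectors] for u in vectors]
```

One caveat: a dot product equals cosine similarity only for unit-length vectors, and a weighted sum of unit vectors is generally not unit length, so the values compared against `similarity_threshold` depend on the scale of the weights.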
 
334
  label_group[label] = []
335
  label_group[label].append(paper_id)
336
  paper_id_list = []
337
+ # randomly pick a paper from each cluster in turn until top_k papers are chosen
338
  while len(paper_id_list) < top_k:
339
  for label, papers in label_group.items():
340
  if papers:
 
343
  break
344
  return paper_id_list
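The cluster-aware selection can be pictured as a round-robin over clusters. The sketch below assumes the highest-scoring paper of each cluster is taken first; the comment in the diff mentions a random pick and the loop body is not shown in this hunk, so treat the ordering as an assumption. `round_robin_pick` is a hypothetical name.

```python
def round_robin_pick(label_group, score_dict, top_k):
    """Take one paper per cluster in turn until top_k papers are chosen."""
    # Sort each cluster so the best-scoring paper is taken first.
    queues = {
        label: sorted(papers, key=lambda p: score_dict[p], reverse=True)
        for label, papers in label_group.items()
    }
    picked = []
    while len(picked) < top_k and any(queues.values()):
        for papers in queues.values():
            if papers:
                picked.append(papers.pop(0))
                if len(picked) >= top_k:
                    break
    return picked
```

Compared with a plain top-k by score, this trades some relevance for diversity: a low-scoring paper from a small cluster can displace a higher-scoring paper from a crowded one.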
345
 
346
+ def cosine_similarity_search(self, embedding, k=1, type_name="background_embedding"):
347
+ """Retrieve papers through embedding
348
+ Args:
349
+ embedding: the input embedding
350
+ Returns:
351
+ result (List of Papers): return related papers with the least embedding distance
352
  """
353
  result = self.paper_client.cosine_similarity_search(
354
  embedding, k, type_name=type_name
 
358
  return result
359
 
360
  def cluster_algorithm(self, paper_id_list, similarity_matrix):
361
+ """
362
+ """
363
  threshold = self.config.RETRIEVE.similarity_threshold
364
  uf = UnionFind(len(paper_id_list))
365
  # merge
 
381
  for k in self.config.RETRIEVE.top_k_list:
382
  result[k] = {"recall": 0, "precision": 0}
383
  return result, 0, 0, 0
384
+
385
+ ## merge the retrieved papers and the target papers, then cluster
386
+ ## clustering uses the weighted combination of the background, contribution, summary, and abstract embeddings
387
  all_paper_id_set = set(related_paper_id_list)
388
  all_paper_id_set.update(target_paper_id_list)
389
  all_paper_id_list = list(all_paper_id_set)
390
+ # get all target papers' background_embedding
391
  paper_embedding_list = [
392
+ self.paper_client.get_paper_attribute(paper_id, f"background_embedding{self.embedding_postfix}")
393
  for paper_id in target_paper_id_list
394
  ]
395
  paper_embedding = np.array(paper_embedding_list)
396
+ # get all target papers' contribution_embedding
397
  paper_embedding_list = [
398
+ self.paper_client.get_paper_attribute(paper_id, f"contribution_embedding{self.embedding_postfix}")
399
  for paper_id in target_paper_id_list
400
  ]
401
  paper_contribution_embedding = np.array(paper_embedding_list)
402
+ # get all target papers' summary_embedding
403
+ paper_embedding_list = [
404
+ self.paper_client.get_paper_attribute(paper_id, f"summary_embedding{self.embedding_postfix}")
405
+ for paper_id in target_paper_id_list
406
+ ]
407
+ # build the summary matrix before paper_embedding_list is reused for abstracts
408
+ paper_summary_embedding = np.array(paper_embedding_list)
409
+ # abstract embedding
410
  paper_embedding_list = [
411
+ self.paper_client.get_paper_attribute(paper_id, f"abstract_embedding{self.embedding_postfix}")
412
  for paper_id in target_paper_id_list
413
  ]
414
+ paper_abstract_embedding = np.array(paper_embedding_list)
415
+ weight_background = self.config.RETRIEVE.s_bg
416
  weight_contribution = self.config.RETRIEVE.s_contribution
417
  weight_summary = self.config.RETRIEVE.s_summary
418
+ weight_abstract = self.config.RETRIEVE.s_abstract
419
+ # 2D matrix of size [# of target papers, embedding dimension]
420
  target_paper_embedding = (
421
+ weight_background * paper_embedding
422
  + weight_contribution * paper_contribution_embedding
423
  + weight_summary * paper_summary_embedding
424
+ + weight_abstract * paper_abstract_embedding
425
  )
426
  similarity_threshold = self.config.RETRIEVE.similarity_threshold
427
  similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
428
+ # return each target_paper's cluster label
429
  target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
430
  target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
431
  logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
 
435
  for paper_id in target_paper_label_dict.keys()
436
  }
437
  )
438
+
439
+ ## calculate the similarity between each paper and the target papers
440
  all_labels = []
441
  for paper_id in all_paper_id_list:
442
+ # for each paper, get its background_embedding
443
  paper_bg_embedding = [
444
+ self.paper_client.get_paper_attribute(paper_id, f"background_embedding{self.embedding_postfix}")
445
  ]
446
  paper_bg_embedding = np.array(paper_bg_embedding)
447
+ # for each paper, get its contribution_embedding
448
  paper_contribution_embedding = [
449
  self.paper_client.get_paper_attribute(
450
+ paper_id, f"contribution_embedding{self.embedding_postfix}"
451
  )
452
  ]
453
  paper_contribution_embedding = np.array(paper_contribution_embedding)
454
+ # for each paper, get its summary_embedding
455
  paper_summary_embedding = [
456
+ self.paper_client.get_paper_attribute(paper_id, f"summary_embedding{self.embedding_postfix}")
457
  ]
458
  paper_summary_embedding = np.array(paper_summary_embedding)
459
+ # for each paper, get its abstract_embedding
460
+ paper_abstract_embedding = [
461
+ self.paper_client.get_paper_attribute(paper_id, f"abstract_embedding{self.embedding_postfix}")
462
+ ]
463
+ paper_abstract_embedding = np.array(paper_abstract_embedding)
464
+
465
  paper_embedding = (
466
+ weight_background * paper_bg_embedding
467
  + weight_contribution * paper_contribution_embedding
468
  + weight_summary * paper_summary_embedding
469
+ + weight_abstract * paper_abstract_embedding
470
  )
471
+
472
+ # vector of size [# of target papers]
473
  similarities = cosine_similarity(paper_embedding, target_paper_embedding)[0]
474
  if np.any(similarities >= similarity_threshold):
475
  all_labels.append(target_labels[np.argmax(similarities)])
 
477
  all_labels.append(-1) # other class: -1
478
  all_paper_label_dict = dict(zip(all_paper_id_list, all_labels))
479
  all_label_counts = Counter(all_paper_label_dict.values())
480
+ logger.debug(f"All labels and the number of papers of each label: {all_label_counts}")
481
  target_label_counts = Counter(target_paper_label_dict.values())
482
+ logger.debug(f"All labels and the number of target papers of each label : {target_label_counts}")
483
  target_label_list = list(target_label_counts.keys())
484
  max_k = max(self.config.RETRIEVE.top_k_list)
485
  logger.info("=== Begin filter related paper ===")
486
  max_k_paper_id_list = self.filter_related_paper(score_all_dict, top_k=max_k)
487
  logger.info("=== End filter related paper ===")
488
+ ## calculate recall and precision of first {10, 20, 30, ...} papers
489
  for k in self.config.RETRIEVE.top_k_list:
490
  # the top-k papers
491
  top_k = min(k, len(max_k_paper_id_list))
 
494
  for paper_id in top_k_paper_id_list:
495
  top_k_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
496
  logger.debug(
497
+ "=== ideal top {}, real top {} paper id list : {}".format(k, top_k, top_k_paper_label_dict)
498
  )
499
  logger.debug(
500
  {
 
503
  }
504
  )
505
  top_k_label_counts = Counter(top_k_paper_label_dict.values())
506
+ logger.debug(f"Retrieved {top_k} papers have K different label: {top_k_label_counts}")
507
  top_k_label_list = list(top_k_label_counts.keys())
508
  match_label_list = list(set(target_label_list) & set(top_k_label_list))
509
  logger.debug(f"match label list : {match_label_list}")
 
517
  precision /= len(top_k_paper_id_list)
518
  result[k] = {"recall": recall, "precision": precision}
519
 
520
+ ## calculate recall and precision of all retrieved papers
521
  related_paper_id_list = list(score_all_dict.keys())
522
  related_paper_label_dict = {}
523
  for paper_id in related_paper_id_list:
 
534
  precision += related_label_counts[label]
535
  recall /= len(target_paper_id_list)
536
  precision /= len(related_paper_id_list)
537
+
538
  logger.debug(result)
539
  return result, len(target_label_counts), recall, precision
540
 
541
 
542
  class RetrieverFactory(object):
543
+ """RetrieverFactory is a singleton class: it returns cls._instance if one has already been
544
+ created. It stores all registered Retriever classes.
545
+ Args:
546
+ None
547
+ Returns:
548
+ The singleton instance of the RetrieverFactory
549
+ """
550
  _instance = None
551
  _lock = threading.Lock()
552
 
 
564
 
565
  @staticmethod
566
  def get_retriever_factory():
567
+ """The method can also return the singleton instance of the RetrieverFactory
568
+ Args:
569
+ None
570
+ Returns:
571
+ The singleton instance of the RetrieverFactory
572
+ """
573
  if RetrieverFactory._instance is None:
574
  RetrieverFactory._instance = RetrieverFactory()
575
  return RetrieverFactory._instance
576
 
577
  def register_retriever(self, retriever_name, retriever_class) -> bool:
578
+ """Register a new retriever class (not instance) to the RetrieverFactory
579
+ Args:
580
+ retriever_name: str
581
+ retriever_class: a class object (not instance)
582
+ Returns:
583
+ True if added successfully, False otherwise
584
+ """
585
  if retriever_name not in self.retriever_classes:
586
  self.retriever_classes[retriever_name] = retriever_class
587
  return True
 
603
  return len(self.retriever_classes)
604
 
605
  def create_retriever(self, retriever_name, *args, **kwargs) -> Retriever:
606
+ """Return a retriever instance
607
+ Args:
608
+ retriever_name: str
609
+ Returns:
610
+ The retriever
611
+ """
612
  if retriever_name not in self.retriever_classes:
613
+ raise ValueError(f"Unknown retriever type: {retriever_name}. retriever_name should be one of {self.retriever_classes.keys()}")
614
  else:
615
  return self.retriever_classes[retriever_name](*args, **kwargs)
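The register-then-create flow of the factory can be sketched in isolation. `RegistrySketch` and `EchoRetriever` are hypothetical names, and the singleton is built in `__new__` here as an assumption, since the real constructor is not shown in this hunk.

```python
import threading


class RegistrySketch:
    """Singleton registry mapping a name to a class (not an instance)."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # The lock prevents two threads from creating two instances.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance.classes = {}
        return cls._instance

    def register(self, name, klass):
        """Return True if the class was added, False if the name is taken."""
        if name in self.classes:
            return False
        self.classes[name] = klass
        return True

    def create(self, name, *args, **kwargs):
        """Instantiate the class registered under name."""
        if name not in self.classes:
            raise ValueError(
                f"Unknown retriever type: {name}. Expected one of {sorted(self.classes)}"
            )
        return self.classes[name](*args, **kwargs)


class EchoRetriever:
    """Toy class used only to exercise the registry."""

    def __init__(self, config):
        self.config = config
```

Storing classes rather than instances means each `create` call builds a fresh retriever from the arguments it is given.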
616
 
 
635
  super().__init__(config)
636
 
637
  def retrieve_paper(self, bg):
638
+ """Retrieve papers P (a set) according to embeddings' similarity between the input
639
+ background and the backgrounds from the database. Optionally, you can also retrieve
640
+ papers co-cited with P.
641
+ Args:
642
+ bg (str): the input background
643
+ Returns:
644
+ result (dict):
645
+ "background_embedding{embedding_postfix}": embedding of the input background,
646
+ "paper" (List of int): all retrieved related_papers' ids,
647
+ "entities" (List): An empty list (TODO: remove),
648
+ "cocite_paper" (List of int): all papers cocited with embedding-retrieved papers
649
+ """
650
  entities = []
651
  embedding = self.embedding_model.encode(bg, device=self.device)
652
  sn_paper_id_list = self.cosine_similarity_search(
653
  embedding=embedding,
654
  k=self.config.RETRIEVE.sn_retrieve_paper_num,
655
+ type_name=f"{self.config.RETRIEVE.SN_field_name}_embedding{self.embedding_postfix}"
656
  )
657
  related_paper = set()
658
  related_paper.update(sn_paper_id_list)
 
668
  related_paper = list(related_paper)
669
  logger.debug(f"paper num before filter: {len(related_paper)}")
670
  result = {
671
+ f"background_embedding{self.embedding_postfix}": embedding,
672
  "paper": related_paper,
673
  "entities": entities,
674
  "cocite_paper": list(cocite_id_set),
 
678
  def retrieve(self, bg, entities, need_evaluate=True, target_paper_id_list=[]):
679
  """
680
  Args:
681
+ bg (str): The user input background
682
  Return:
683
+ result (dict):
684
+ "recall": recall of paper retrieval, 0 if need_evaluate==False,
685
+ "precision": precision of paper retrieval, 0 if need_evaluate==False,
686
+ "filtered_recall": recall of paper retrieval after filtering, 0 if need_evaluate==False,
687
+ "filtered_precision": precision of paper retrieval after filtering, 0 if need_evaluate==False,
688
+ "related_paper": all retrieved related_papers. !!! [ The most important item ]
689
+ "related_paper_id_list": all retrieved related_papers' ids. !!! [ The most important item ]
690
+ "cocite_paper_id_list": retrieve_result["cocite_paper"],
691
+ "entities": retrieve_result["entities"], always empty
692
+ "top_k_matrix": top_k_matrix, {} if need_evaluate==False
693
+ "gt_reference_num": len(target_paper_id_list)
694
+ "retrieve_paper_num": len(related_paper_id_list),
695
+ "label_num": TODO,
696
  """
697
  if need_evaluate:
698
  if target_paper_id_list is None or len(target_paper_id_list) == 0:
 
704
  retrieve_result = self.retrieve_paper(bg)
705
  related_paper_id_list = retrieve_result["paper"]
706
  retrieve_paper_num = len(related_paper_id_list)
707
+ # scores between the input background and all retrieved papers
708
  _, _, score_all_dict = self.cal_related_score(
709
+ retrieve_result[f"background_embedding{self.embedding_postfix}"], related_paper_id_list=related_paper_id_list,
710
+ type_name=f"{self.config.RETRIEVE.SN_field_name}_embedding{self.embedding_postfix}"
711
  )
712
  top_k_matrix = {}
713
  recall = 0
714
  precision = 0
715
  filtered_recall = 0
716
+ label_num = 0
717
  filtered_precision = 0
718
  if need_evaluate:
719
  top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
720
  score_all_dict, target_paper_id_list
721
  )
722
+ logger.debug("Top K matrix:{}".format(top_k_matrix))
723
  logger.debug("before filter:")
724
  logger.debug(f"Recall: {recall:.3f}")
725
  logger.debug(f"Precision: {precision:.3f}")
726
+ ## For idea generation, only the top all_retrieve_paper_num papers are used; this is independent of retrieval evaluation
727
+ related_paper = self.filter_related_paper(score_all_dict, top_k=self.config.RETRIEVE.all_retrieve_paper_num)
728
  related_paper = self.update_related_paper(related_paper)
729
  result = {
730
  "recall": recall,
 
749
  super().__init__(config)
750
 
751
  def retrieve_paper(self, entities):
752
+ """Retrieve according to entities
753
+ """
754
  new_entities = self.retrieve_entities_by_enties(entities)
755
  logger.debug("KG entities for retriever: {}".format(new_entities))
756
  related_paper = set()
 
791
  retrieve_paper_num = len(related_paper_id_list)
792
  embedding = self.embedding_model.encode(bg, device=self.device)
793
  _, _, score_all_dict = self.cal_related_score(
794
+ embedding, related_paper_id_list=related_paper_id_list,
795
+ type_name=f"background_embedding{self.embedding_postfix}"
796
  )
797
  top_k_matrix = {}
798
  recall = 0
799
  precision = 0
800
  filtered_recall = 0
801
+ label_num = 0
802
  filtered_precision = 0
803
  if need_evaluate:
804
  top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
 
808
  logger.debug("before filter:")
809
  logger.debug(f"Recall: {recall:.3f}")
810
  logger.debug(f"Precision: {precision:.3f}")
811
+ related_paper = self.filter_related_paper(score_all_dict, top_k=self.config.RETRIEVE.all_retrieve_paper_num)
812
  related_paper = self.update_related_paper(related_paper)
813
  result = {
814
  "recall": recall,
 
834
 
835
  def retrieve_paper(self, bg, entities):
836
  sn_entities = []
837
+ ## 1. Retrieve papers according to the embeddings of input background
838
  embedding = self.embedding_model.encode(bg, device=self.device)
839
  sn_paper_id_list = self.cosine_similarity_search(
840
+ embedding, k=self.config.RETRIEVE.sn_num_for_entity,
841
+ type_name=f"{self.config.RETRIEVE.SN_field_name}_embedding{self.embedding_postfix}"
842
  )
843
  related_paper = set()
844
  related_paper.update(sn_paper_id_list)
845
+ logger.debug(f"SN retrieve {len(related_paper)} papers")
846
+
847
+ ## 2. Retrieve papers according to entities
848
+ # Fetch all entities from embedding-retrieved papers
849
  sn_entities += self.paper_client.find_entities_by_paper_list(sn_paper_id_list)
850
  logger.debug("SN entities for retriever: {}".format(sn_entities))
851
  entities = list(set(entities + sn_entities))
852
+ # Expand entity list through synonyms
853
  new_entities = self.retrieve_entities_by_enties(entities)
854
  logger.debug("SNKG entities for retriever: {}".format(new_entities))
855
+ paper_id_set = set()
856
  for entity in new_entities:
857
+ paper_id_set.update(self.paper_client.find_paper_by_entity(entity))
858
+ related_paper = related_paper.union(paper_id_set)
859
+ logger.debug(f"Entity retrieve {len(paper_id_set)} papers")
860
+ logger.debug(f"SN+entity retrieve {len(related_paper)} papers")
861
+
862
+ ## 3. Retrieve papers according to citation co-occurrence
863
  cocite_id_set = set()
864
  if self.use_cocite:
865
  for paper_id in related_paper:
866
  cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
867
  related_paper = related_paper.union(cocite_id_set)
868
+ logger.debug(f"Cocite retrieve {len(cocite_id_set)} papers")
869
+ logger.debug(f"SN+entity+cocite retrieve {len(related_paper)} papers")
870
+
871
+ ## 4. Return retrieval results
872
  related_paper = list(related_paper)
873
  result = {
874
+ f"background_embedding{self.embedding_postfix}": embedding,
875
  "paper": related_paper,
876
  "entities": entities,
877
  "cocite_paper": list(cocite_id_set),
 
879
  return result
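The three retrieval stages in `SNKGRetriever.retrieve_paper` amount to successive set unions. The sketch below replaces the Neo4j client and the co-citation map with plain dicts; `snkg_retrieve_sketch` and its parameters are hypothetical, and entity expansion is assumed to have already happened.

```python
def snkg_retrieve_sketch(sn_paper_ids, entities, papers_by_entity, cocited, use_cocite=True):
    """Union of embedding hits, entity hits, and co-cited papers."""
    # 1. Papers found by embedding similarity to the input background.
    related = set(sn_paper_ids)

    # 2. Papers linked to any of the (already expanded) entities.
    entity_papers = set()
    for entity in entities:
        entity_papers.update(papers_by_entity.get(entity, []))
    related |= entity_papers

    # 3. Papers co-cited with anything retrieved so far.
    cocite_ids = set()
    if use_cocite:
        for paper_id in related:
            cocite_ids.update(cocited.get(paper_id, []))
        related |= cocite_ids

    return {"paper": sorted(related), "cocite_paper": sorted(cocite_ids)}
```

Collecting the co-cited ids in a separate set, as the diff does, avoids mutating `related` while iterating over it and keeps the co-citation contribution reportable on its own.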
880
 
881
  def retrieve(
882
+ self, bg, entities, need_evaluate=True, target_paper_id_list=[]
883
  ):
884
  """
885
  Args:
 
900
  retrieve_paper_num = len(related_paper_id_list)
901
  logger.info("=== Begin cal related paper score ===")
902
  _, _, score_all_dict = self.cal_related_score(
903
+ retrieve_result[f"background_embedding{self.embedding_postfix}"], related_paper_id_list=related_paper_id_list,
904
+ type_name=f"background_embedding{self.embedding_postfix}"
905
  )
906
  logger.info("=== End cal related paper score ===")
907
  top_k_matrix = {}
 
919
  logger.debug(f"Recall: {recall:.3f}")
920
  logger.debug(f"Precision: {precision:.3f}")
921
  logger.info("=== Begin filter related paper score ===")
922
+ related_paper = self.filter_related_paper(score_all_dict, self.config.RETRIEVE.all_retrieve_paper_num)
923
  logger.info("=== End filter related paper score ===")
924
  related_paper = self.update_related_paper(related_paper)
925
  result = {